diff --git a/spring-core/src/main/java/org/springframework/core/io/buffer/DataBufferUtils.java b/spring-core/src/main/java/org/springframework/core/io/buffer/DataBufferUtils.java
index be9914a3875..a85c033cedc 100644
--- a/spring-core/src/main/java/org/springframework/core/io/buffer/DataBufferUtils.java
+++ b/spring-core/src/main/java/org/springframework/core/io/buffer/DataBufferUtils.java
@@ -456,6 +456,35 @@ public abstract class DataBufferUtils {
consumer::accept, new DataBufferMapper(bufferFactory), executor, chunkSize);
}
+ /**
+ * Subscribes to the given {@link Publisher} and returns the subscription
+ * as an {@link InputStream} that allows reading all propagated {@link DataBuffer} messages via its imperative API.
+ * The returned {@link InputStream} buffers messages as configured.
+ * It is considered terminated once the given {@link Publisher} has signaled one of the
+ * terminal signals ({@link Subscriber#onComplete()} or {@link Subscriber#onError(Throwable)})
+ * and all stored {@link DataBuffer} messages have been polled from the internal buffer.
+ * The returned {@link InputStream} calls {@link Subscription#cancel()} and releases all stored
+ * {@link DataBuffer} messages when {@link InputStream#close()} is called.
+ *
+ * <p>Note: the returned {@link InputStream} disallows concurrent calls to
+ * any of its {@link InputStream#read} methods.
+ *
+ * <p>Note: {@link Subscription#request(long)} happens eagerly upon subscription
+ * and is repeated every time {@code bufferSize - (bufferSize >> 2)} messages have been consumed.
+ *
+ * @param publisher the source of {@link DataBuffer} messages to expose as an {@link InputStream}
+ * @param bufferSize the maximum number of {@link DataBuffer} messages prefetched in advance and stored inside the {@link InputStream}
+ * @return an {@link InputStream} instance representing the given {@link Publisher} messages
+ */
+ public static InputStream subscribeAsInputStream(Publisher<DataBuffer> publisher, int bufferSize) {
+ Assert.notNull(publisher, "Publisher must not be null");
+ Assert.isTrue(bufferSize > 0, "Buffer size must be > 0");
+
+ InputStreamSubscriber inputStreamSubscriber = new InputStreamSubscriber(bufferSize);
+ publisher.subscribe(inputStreamSubscriber);
+ return inputStreamSubscriber;
+ }
+
//---------------------------------------------------------------------
// Various
diff --git a/spring-core/src/main/java/org/springframework/core/io/buffer/InputStreamSubscriber.java b/spring-core/src/main/java/org/springframework/core/io/buffer/InputStreamSubscriber.java
new file mode 100644
index 00000000000..b364927d953
--- /dev/null
+++ b/spring-core/src/main/java/org/springframework/core/io/buffer/InputStreamSubscriber.java
@@ -0,0 +1,355 @@
+package org.springframework.core.io.buffer;
+
+import org.reactivestreams.Publisher;
+import org.reactivestreams.Subscriber;
+import org.reactivestreams.Subscription;
+import org.springframework.lang.Nullable;
+import reactor.core.Exceptions;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.ByteBuffer;
+import java.util.ConcurrentModificationException;
+import java.util.Objects;
+import java.util.Queue;
+import java.util.concurrent.ArrayBlockingQueue;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.concurrent.locks.LockSupport;
+import java.util.concurrent.locks.ReentrantLock;
+
+/**
+ * Bridges between {@link Publisher Publisher<DataBuffer>} and {@link InputStream}.
+ *
+ * Note that this class has a near duplicate in
+ * {@link org.springframework.http.client.InputStreamSubscriber}.
+ *
+ * @author Oleh Dokuka
+ * @since 6.1
+ */
+final class InputStreamSubscriber extends InputStream implements Subscriber<DataBuffer> {
+
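+ // READY marks parkedThread as "resume already signaled"; DONE and CLOSED are
+ // sentinel buffers returned by getBytesOrAwait() for source termination and
+ // reader-initiated close(), respectively.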
+ static final Object READY = new Object();
+ static final DataBuffer DONE = DefaultDataBuffer.fromEmptyByteBuffer(DefaultDataBufferFactory.sharedInstance, ByteBuffer.allocate(0));
+ static final DataBuffer CLOSED = DefaultDataBuffer.fromEmptyByteBuffer(DefaultDataBufferFactory.sharedInstance, ByteBuffer.allocate(0));
+
+ final int prefetch;
+ final int limit;
+ final ReentrantLock lock;
+ final Queue<DataBuffer> queue;
+
+ final AtomicReference<Object> parkedThread = new AtomicReference<>();
+ final AtomicInteger workAmount = new AtomicInteger();
+
+ volatile boolean closed;
+ int consumed;
+
+ @Nullable
+ DataBuffer available;
+
+ @Nullable
+ Subscription s;
+ boolean done;
+ @Nullable
+ Throwable error;
+
+ InputStreamSubscriber(int prefetch) {
+ this.prefetch = prefetch;
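+ // Re-request upstream once 75% of the prefetch has been consumed (Reactor-style limit)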
+ this.limit = prefetch == Integer.MAX_VALUE ? Integer.MAX_VALUE : prefetch - (prefetch >> 2);
+ this.queue = new ArrayBlockingQueue<>(prefetch);
+ this.lock = new ReentrantLock(false);
+ }
+
+ @Override
+ public void onSubscribe(Subscription subscription) {
+ if (this.s != null) {
+ subscription.cancel();
+ return;
+ }
+
+ this.s = subscription;
+ subscription.request(prefetch == Integer.MAX_VALUE ? Long.MAX_VALUE : prefetch);
+ }
+
+ @Override
+ public void onNext(DataBuffer t) {
+ if (this.done) {
+ discard(t);
+ return;
+ }
+
+ if (!queue.offer(t)) {
+ discard(t);
+ error = new RuntimeException("Buffer overflow");
+ done = true;
+ }
+
+ int previousWorkState = addWork();
+ if (previousWorkState == Integer.MIN_VALUE) {
+ DataBuffer value = queue.poll();
+ if (value != null) {
+ discard(value);
+ }
+ return;
+ }
+
+ if (previousWorkState == 0) {
+ resume();
+ }
+ }
+
+ @Override
+ public void onError(Throwable throwable) {
+ if (this.done) {
+ return;
+ }
+ this.error = throwable;
+ this.done = true;
+
+ if (addWork() == 0) {
+ resume();
+ }
+ }
+
+ @Override
+ public void onComplete() {
+ if (this.done) {
+ return;
+ }
+
+ this.done = true;
+
+ if (addWork() == 0) {
+ resume();
+ }
+ }
+
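+ /**
+ * Records one unit of pending work (an onNext or terminal signal) in the work-amount
+ * counter and returns its previous value: {@code 0} means the reader may be parked and
+ * needs {@code resume()}; {@code Integer.MIN_VALUE} means the stream is already closed.
+ */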
+ int addWork() {
+ for (;;) {
+ int produced = this.workAmount.getPlain();
+
+ if (produced == Integer.MIN_VALUE) {
+ return Integer.MIN_VALUE;
+ }
+
+ int nextProduced = produced == Integer.MAX_VALUE ? 1 : produced + 1;
+
+ if (workAmount.weakCompareAndSetRelease(produced, nextProduced)) {
+ return produced;
+ }
+ }
+ }
+
+ @Override
+ public int read() throws IOException {
+ if (!lock.tryLock()) {
+ if (this.closed) {
+ return -1;
+ }
+ throw new ConcurrentModificationException("concurrent access is disallowed");
+ }
+
+ try {
+ DataBuffer bytes = getBytesOrAwait();
+
+ if (bytes == DONE) {
+ this.closed = true;
+ cleanAndFinalize();
+ if (this.error == null) {
+ return -1;
+ }
+ else {
+ throw Exceptions.propagate(error);
+ }
+ }
+ else if (bytes == CLOSED) {
+ this.s.cancel();
+ cleanAndFinalize();
+ return -1;
+ }
+
+ return bytes.read() & 0xFF;
+ }
+ catch (Throwable t) {
+ this.closed = true;
+ this.s.cancel();
+ cleanAndFinalize();
+ throw Exceptions.propagate(t);
+ }
+ finally {
+ lock.unlock();
+ }
+ }
+
+ @Override
+ public int read(byte[] b, int off, int len) throws IOException {
+ Objects.checkFromIndexSize(off, len, b.length);
+ if (len == 0) {
+ return 0;
+ }
+
+ if (!lock.tryLock()) {
+ if (this.closed) {
+ return -1;
+ }
+ throw new ConcurrentModificationException("concurrent access is disallowed");
+ }
+
+ try {
+ for (int j = 0; j < len;) {
+ DataBuffer bytes = getBytesOrAwait();
+
+ if (bytes == DONE) {
+ cleanAndFinalize();
+ if (this.error == null) {
+ this.closed = true;
+ return j == 0 ? -1 : j;
+ }
+ else {
+ if (j == 0) {
+ this.closed = true;
+ throw Exceptions.propagate(error);
+ }
+
+ return j;
+ }
+ }
+ else if (bytes == CLOSED) {
+ this.s.cancel();
+ cleanAndFinalize();
+ return -1;
+ }
+ int initialReadPosition = bytes.readPosition();
+ bytes.read(b, off + j, Math.min(len - j, bytes.readableByteCount()));
+ j += bytes.readPosition() - initialReadPosition;
+ }
+
+ return len;
+ }
+ catch (Throwable t) {
+ this.closed = true;
+ this.s.cancel();
+ cleanAndFinalize();
+ throw Exceptions.propagate(t);
+ }
+ finally {
+ lock.unlock();
+ }
+ }
+
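+ /**
+ * Returns a buffer with readable bytes, parking the reader thread until data
+ * arrives, or returns the DONE or CLOSED sentinel on termination or close.
+ */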
+ DataBuffer getBytesOrAwait() {
+ if (this.available == null || this.available.readableByteCount() == 0) {
+
+ discard(this.available);
+ this.available = null;
+
+ int actualWorkAmount = this.workAmount.getAcquire();
+ for (;;) {
+ if (this.closed) {
+ return CLOSED;
+ }
+
+ boolean d = this.done;
+ DataBuffer t = this.queue.poll();
+ if (t != null) {
+ int consumed = ++this.consumed;
+ this.available = t;
+ if (consumed == this.limit) {
+ this.consumed = 0;
+ this.s.request(this.limit);
+ }
+ break;
+ }
+
+ if (d) {
+ return DONE;
+ }
+
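+ // Deduct the work observed so far; if the counter drops to zero, no new
+ // signals arrived in the meantime, so park until resume() is called.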
+ actualWorkAmount = workAmount.addAndGet(-actualWorkAmount);
+ if (actualWorkAmount == 0) {
+ await();
+ }
+ }
+ }
+
+ return this.available;
+ }
+
+ void cleanAndFinalize() {
+ discard(this.available);
+ this.available = null;
+
+ for (;;) {
+ int workAmount = this.workAmount.getPlain();
+ DataBuffer value;
+
+ while((value = queue.poll()) != null) {
+ discard(value);
+ }
+
+ if (this.workAmount.weakCompareAndSetPlain(workAmount, Integer.MIN_VALUE)) {
+ return;
+ }
+ }
+ }
+
+ void discard(@Nullable DataBuffer value) {
+ DataBufferUtils.release(value);
+ }
+
+ @Override
+ public void close() throws IOException {
+ if (this.closed) {
+ return;
+ }
+
+ this.closed = true;
+
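+ // If a read is in progress, defer cleanup to the reader: after resume() it
+ // observes the closed flag, gets CLOSED from getBytesOrAwait(), and cleans up.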
+ if (!this.lock.tryLock()) {
+ if (addWork() == 0) {
+ resume();
+ }
+ return;
+ }
+
+ try {
+ this.s.cancel();
+ cleanAndFinalize();
+ }
+ finally {
+ this.lock.unlock();
+ }
+ }
+
+ private void await() {
+ Thread toUnpark = Thread.currentThread();
+
+ while (true) {
+ Object current = this.parkedThread.get();
+ if (current == READY) {
+ break;
+ }
+
+ if (current != null && current != toUnpark) {
+ throw new IllegalStateException("Only one (Virtual)Thread can await!");
+ }
+
+ if (parkedThread.compareAndSet(null, toUnpark)) {
+ LockSupport.park();
+ // we don't just break here because park() can wake up spuriously
+ // if we got a proper resume, get() == READY and the loop will quit above
+ }
+ }
+ // clear the resume indicator so that the next await call will park without a resume()
+ this.parkedThread.lazySet(null);
+ }
+
+ private void resume() {
+ if (this.parkedThread.get() != READY) {
+ Object old = parkedThread.getAndSet(READY);
+ if (old != READY) {
+ LockSupport.unpark((Thread)old);
+ }
+ }
+ }
+
+}
diff --git a/spring-core/src/test/java/org/springframework/core/io/buffer/DataBufferUtilsTests.java b/spring-core/src/test/java/org/springframework/core/io/buffer/DataBufferUtilsTests.java
index 9ea04e339c6..d0fe3c54466 100644
--- a/spring-core/src/test/java/org/springframework/core/io/buffer/DataBufferUtilsTests.java
+++ b/spring-core/src/test/java/org/springframework/core/io/buffer/DataBufferUtilsTests.java
@@ -17,6 +17,7 @@
package org.springframework.core.io.buffer;
import java.io.IOException;
+import java.io.InputStream;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
import java.net.URI;
@@ -27,15 +28,18 @@ import java.nio.channels.FileChannel;
import java.nio.channels.ReadableByteChannel;
import java.nio.channels.SeekableByteChannel;
import java.nio.channels.WritableByteChannel;
+import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardOpenOption;
import java.time.Duration;
+import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executors;
+import java.util.concurrent.ThreadLocalRandom;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.PooledByteBufAllocator;
@@ -688,6 +692,189 @@ class DataBufferUtilsTests extends AbstractDataBufferAllocatingTests {
latch.await();
}
+
+
+ @ParameterizedDataBufferAllocatingTest
+ void inputStreamSubscriberChunkSize(DataBufferFactory bufferFactory) {
+ genericInputStreamSubscriberTest(bufferFactory, 3, 3, 64, List.of("foo", "bar", "baz"), List.of("foo", "bar", "baz"));
+ }
+
+ @ParameterizedDataBufferAllocatingTest
+ void inputStreamSubscriberChunkSize2(DataBufferFactory bufferFactory) {
+ genericInputStreamSubscriberTest(bufferFactory, 3, 3, 1, List.of("foo", "bar", "baz"), List.of("foo", "bar", "baz"));
+ }
+
+ @ParameterizedDataBufferAllocatingTest
+ void inputStreamSubscriberChunkSize3(DataBufferFactory bufferFactory) {
+ genericInputStreamSubscriberTest(bufferFactory, 3, 12, 1, List.of("foo", "bar", "baz"), List.of("foobarbaz"));
+ }
+
+ @ParameterizedDataBufferAllocatingTest
+ void inputStreamSubscriberChunkSize4(DataBufferFactory bufferFactory) {
+ genericInputStreamSubscriberTest(bufferFactory, 3, 1, 1, List.of("foo", "bar", "baz"), List.of("f", "o", "o", "b", "a", "r", "b", "a", "z"));
+ }
+
+ @ParameterizedDataBufferAllocatingTest
+ void inputStreamSubscriberChunkSize5(DataBufferFactory bufferFactory) {
+ genericInputStreamSubscriberTest(bufferFactory, 3, 2, 1, List.of("foo", "bar", "baz"), List.of("fo", "ob", "ar", "ba", "z"));
+ }
+
+ @ParameterizedDataBufferAllocatingTest
+ void inputStreamSubscriberChunkSize6(DataBufferFactory bufferFactory) {
+ genericInputStreamSubscriberTest(bufferFactory, 1, 3, 1, List.of("foo", "bar", "baz"), List.of("foo", "bar", "baz"));
+ }
+
+ @ParameterizedDataBufferAllocatingTest
+ void inputStreamSubscriberChunkSize7(DataBufferFactory bufferFactory) {
+ genericInputStreamSubscriberTest(bufferFactory, 1, 3, 64, List.of("foo", "bar", "baz"), List.of("foo", "bar", "baz"));
+ }
+
+ void genericInputStreamSubscriberTest(DataBufferFactory bufferFactory, int writeChunkSize, int readChunkSize, int bufferSize, List<String> input, List<String> expectedOutput) {
+ super.bufferFactory = bufferFactory;
+
+ Publisher<DataBuffer> publisher = DataBufferUtils.outputStreamPublisher(outputStream -> {
+ try {
+ for (String word : input) {
+ outputStream.write(word.getBytes(StandardCharsets.UTF_8));
+ }
+ }
+ catch (IOException ex) {
+ fail(ex.getMessage(), ex);
+ }
+ }, super.bufferFactory, Executors.newSingleThreadExecutor(), writeChunkSize);
+
+ byte[] chunk = new byte[readChunkSize];
+ List<String> words = new ArrayList<>();
+
+ try (InputStream inputStream = DataBufferUtils.subscribeAsInputStream(publisher, bufferSize)) {
+ int read;
+ while ((read = inputStream.read(chunk)) > -1) {
+ String word = new String(chunk, 0, read, StandardCharsets.UTF_8);
+ words.add(word);
+ }
+ }
+ catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ assertThat(words).containsExactlyElementsOf(expectedOutput);
+ }
+
+ @ParameterizedDataBufferAllocatingTest
+ void inputStreamSubscriberError(DataBufferFactory bufferFactory) {
+ super.bufferFactory = bufferFactory;
+
+ var input = List.of("foo ", "bar ", "baz");
+
+ Publisher<DataBuffer> publisher = DataBufferUtils.outputStreamPublisher(outputStream -> {
+ try {
+ for (String word : input) {
+ outputStream.write(word.getBytes(StandardCharsets.UTF_8));
+ }
+ throw new RuntimeException("boom");
+ }
+ catch (IOException ex) {
+ fail(ex.getMessage(), ex);
+ }
+ }, super.bufferFactory, Executors.newSingleThreadExecutor(), 1);
+
+ RuntimeException error = null;
+ byte[] chunk = new byte[4];
+ List<String> words = new ArrayList<>();
+
+ try (InputStream inputStream = DataBufferUtils.subscribeAsInputStream(publisher, 1)) {
+ int read;
+ while ((read = inputStream.read(chunk)) > -1) {
+ String word = new String(chunk, 0, read, StandardCharsets.UTF_8);
+ words.add(word);
+ }
+ }
+ catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ catch (RuntimeException e) {
+ error = e;
+ }
+ assertThat(words).containsExactlyElementsOf(List.of("foo ", "bar ", "baz"));
+ assertThat(error).hasMessage("boom");
+ }
+
+ @ParameterizedDataBufferAllocatingTest
+ void inputStreamSubscriberMixedReadMode(DataBufferFactory bufferFactory) {
+ super.bufferFactory = bufferFactory;
+
+ var input = List.of("foo ", "bar ", "baz");
+
+ Publisher<DataBuffer> publisher = DataBufferUtils.outputStreamPublisher(outputStream -> {
+ try {
+ for (String word : input) {
+ outputStream.write(word.getBytes(StandardCharsets.UTF_8));
+ }
+ }
+ catch (IOException ex) {
+ fail(ex.getMessage(), ex);
+ }
+ }, super.bufferFactory, Executors.newSingleThreadExecutor(), 1);
+
+ byte[] chunk = new byte[3];
+ List<String> words = new ArrayList<>();
+
+ try (InputStream inputStream = DataBufferUtils.subscribeAsInputStream(publisher, 1)) {
+ words.add(new String(chunk, 0, inputStream.read(chunk), StandardCharsets.UTF_8));
+ assertThat(inputStream.read()).isEqualTo(' ' & 0xFF);
+ words.add(new String(chunk, 0, inputStream.read(chunk), StandardCharsets.UTF_8));
+ assertThat(inputStream.read()).isEqualTo(' ' & 0xFF);
+ words.add(new String(chunk, 0, inputStream.read(chunk), StandardCharsets.UTF_8));
+ assertThat(inputStream.read()).isEqualTo(-1);
+ }
+ catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ assertThat(words).containsExactlyElementsOf(List.of("foo", "bar", "baz"));
+ }
+
+ @ParameterizedDataBufferAllocatingTest
+ void inputStreamSubscriberClose(DataBufferFactory bufferFactory) throws InterruptedException {
+ for (int i = 1; i < 100; i++) {
+ CountDownLatch latch = new CountDownLatch(1);
+ super.bufferFactory = bufferFactory;
+
+ var input = List.of("foo", "bar", "baz");
+
+ Publisher<DataBuffer> publisher = DataBufferUtils.outputStreamPublisher(outputStream -> {
+ try {
+ assertThatIOException()
+ .isThrownBy(() -> {
+ for (String word : input) {
+ outputStream.write(word.getBytes(StandardCharsets.UTF_8));
+ outputStream.flush();
+ }
+ })
+ .withMessage("Subscription has been terminated");
+ }
+ finally {
+ latch.countDown();
+ }
+ }, super.bufferFactory, Executors.newSingleThreadExecutor(), 1);
+
+ byte[] chunk = new byte[3];
+ List<String> words = new ArrayList<>();
+
+ try (InputStream inputStream = DataBufferUtils.subscribeAsInputStream(publisher, ThreadLocalRandom.current().nextInt(1, 4))) {
+ inputStream.read(chunk);
+ String word = new String(chunk, StandardCharsets.UTF_8);
+ words.add(word);
+ }
+ catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ assertThat(words).containsExactlyElementsOf(List.of("foo"));
+ latch.await();
+ }
+ }
+
@ParameterizedDataBufferAllocatingTest
void readAndWriteByteChannel(DataBufferFactory bufferFactory) throws Exception {
super.bufferFactory = bufferFactory;
diff --git a/spring-web/src/main/java/org/springframework/http/client/InputStreamSubscriber.java b/spring-web/src/main/java/org/springframework/http/client/InputStreamSubscriber.java
new file mode 100644
index 00000000000..606527044fe
--- /dev/null
+++ b/spring-web/src/main/java/org/springframework/http/client/InputStreamSubscriber.java
@@ -0,0 +1,405 @@
+package org.springframework.http.client;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.reactivestreams.Publisher;
+import org.reactivestreams.Subscriber;
+import org.reactivestreams.Subscription;
+import org.springframework.core.io.buffer.DataBuffer;
+import org.springframework.lang.Nullable;
+import org.springframework.util.Assert;
+import reactor.core.Exceptions;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.util.ConcurrentModificationException;
+import java.util.Objects;
+import java.util.Queue;
+import java.util.concurrent.ArrayBlockingQueue;
+import java.util.concurrent.Flow;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.concurrent.locks.LockSupport;
+import java.util.concurrent.locks.ReentrantLock;
+import java.util.function.Consumer;
+import java.util.function.Function;
+
+/**
+ * Bridges between {@link Flow.Publisher Flow.Publisher<T>} and {@link InputStream}.
+ *
+ * Note that this class has a near duplicate in
+ * {@link org.springframework.core.io.buffer.InputStreamSubscriber}.
+ *
+ * @author Oleh Dokuka
+ * @since 6.1
+ */
+final class InputStreamSubscriber<T> extends InputStream implements Flow.Subscriber<T> {
+
+ private static final Log logger = LogFactory.getLog(InputStreamSubscriber.class);
+
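+ // READY marks parkedThread as "resume already signaled"; DONE and CLOSED are
+ // sentinel values returned by getBytesOrAwait() for source termination and
+ // reader-initiated close(), respectively.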
+ static final Object READY = new Object();
+ static final byte[] DONE = new byte[0];
+ static final byte[] CLOSED = new byte[0];
+
+ final int prefetch;
+ final int limit;
+ final Function<T, byte[]> mapper;
+ final Consumer<T> onDiscardHandler;
+ final ReentrantLock lock;
+ final Queue<T> queue;
+
+ final AtomicReference<Object> parkedThread = new AtomicReference<>();
+ final AtomicInteger workAmount = new AtomicInteger();
+
+ volatile boolean closed;
+ int consumed;
+
+ @Nullable
+ byte[] available;
+ int position;
+
+ @Nullable
+ Flow.Subscription s;
+ boolean done;
+ @Nullable
+ Throwable error;
+
+ private InputStreamSubscriber(Function<T, byte[]> mapper, Consumer<T> onDiscardHandler, int prefetch) {
+ this.prefetch = prefetch;
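+ // Re-request upstream once 75% of the prefetch has been consumed (Reactor-style limit)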
+ this.limit = prefetch == Integer.MAX_VALUE ? Integer.MAX_VALUE : prefetch - (prefetch >> 2);
+ this.mapper = mapper;
+ this.onDiscardHandler = onDiscardHandler;
+ this.queue = new ArrayBlockingQueue<>(prefetch);
+ this.lock = new ReentrantLock(false);
+ }
+
+ /**
+ * Subscribes to the given {@link Flow.Publisher} and returns the subscription
+ * as an {@link InputStream} that allows reading all propagated elements via its imperative API.
+ * The returned {@link InputStream} buffers elements as configured.
+ * It is considered terminated once the given {@link Flow.Publisher} has signaled one of the
+ * terminal signals ({@link Flow.Subscriber#onComplete()} or {@link Flow.Subscriber#onError(Throwable)})
+ * and all stored elements have been polled from the internal buffer.
+ * The returned {@link InputStream} calls {@link Flow.Subscription#cancel()} and discards all stored
+ * elements when {@link InputStream#close()} is called.
+ *
+ * <p>Note: the returned {@link InputStream} disallows concurrent calls to
+ * any of its {@link InputStream#read} methods.
+ *
+ * <p>Note: {@link Flow.Subscription#request(long)} happens eagerly upon subscription
+ * and is repeated every time {@code bufferSize - (bufferSize >> 2)} elements have been consumed.
+ *
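+ * <p>For example (an illustrative sketch; {@code publisher} is assumed to emit {@code byte[]} chunks):
+ * <pre>{@code
+ * try (InputStream is = InputStreamSubscriber.subscribeTo(publisher, bytes -> bytes, bytes -> {}, 8)) {
+ *     int b;
+ *     while ((b = is.read()) != -1) {
+ *         // process one byte
+ *     }
+ * }
+ * }</pre>
+ *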
+ * @param publisher the source of elements to expose as an {@link InputStream}
+ * @param mapper a function to transform each element into a {@code byte[]}; the element should be released during mapping if needed
+ * @param onDiscardHandler a consumer for elements discarded when the returned {@link InputStream} is closed prematurely
+ * @param bufferSize the maximum number of elements prefetched in advance and stored inside the {@link InputStream}
+ * @return an {@link InputStream} instance representing the given {@link Flow.Publisher} elements
+ */
+ public static <T> InputStream subscribeTo(Flow.Publisher<T> publisher, Function<T, byte[]> mapper, Consumer<T> onDiscardHandler, int bufferSize) {
+
+ Assert.notNull(publisher, "Flow.Publisher must not be null");
+ Assert.notNull(mapper, "mapper must not be null");
+ Assert.notNull(onDiscardHandler, "onDiscardHandler must not be null");
+ Assert.isTrue(bufferSize > 0, "bufferSize must be greater than 0");
+
+ InputStreamSubscriber<T> iss = new InputStreamSubscriber<>(mapper, onDiscardHandler, bufferSize);
+ publisher.subscribe(iss);
+ return iss;
+ }
+
+ @Override
+ public void onSubscribe(Flow.Subscription subscription) {
+ if (this.s != null) {
+ subscription.cancel();
+ return;
+ }
+
+ this.s = subscription;
+ subscription.request(prefetch == Integer.MAX_VALUE ? Long.MAX_VALUE : prefetch);
+ }
+
+ @Override
+ public void onNext(T t) {
+ Assert.notNull(t, "T value must not be null");
+
+ if (this.done) {
+ discard(t);
+ return;
+ }
+
+ if (!queue.offer(t)) {
+ discard(t);
+ error = new RuntimeException("Buffer overflow");
+ done = true;
+ }
+
+ int previousWorkState = addWork();
+ if (previousWorkState == Integer.MIN_VALUE) {
+ T value = queue.poll();
+ if (value != null) {
+ discard(value);
+ }
+ return;
+ }
+
+ if (previousWorkState == 0) {
+ resume();
+ }
+ }
+
+ @Override
+ public void onError(Throwable throwable) {
+ if (this.done) {
+ return;
+ }
+ this.error = throwable;
+ this.done = true;
+
+ if (addWork() == 0) {
+ resume();
+ }
+ }
+
+ @Override
+ public void onComplete() {
+ if (this.done) {
+ return;
+ }
+
+ this.done = true;
+
+ if (addWork() == 0) {
+ resume();
+ }
+ }
+
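+ /**
+ * Records one unit of pending work (an onNext or terminal signal) in the work-amount
+ * counter and returns its previous value: {@code 0} means the reader may be parked and
+ * needs {@code resume()}; {@code Integer.MIN_VALUE} means the stream is already closed.
+ */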
+ int addWork() {
+ for (;;) {
+ int produced = this.workAmount.getPlain();
+
+ if (produced == Integer.MIN_VALUE) {
+ return Integer.MIN_VALUE;
+ }
+
+ int nextProduced = produced == Integer.MAX_VALUE ? 1 : produced + 1;
+
+ if (workAmount.weakCompareAndSetRelease(produced, nextProduced)) {
+ return produced;
+ }
+ }
+ }
+
+ @Override
+ public int read() throws IOException {
+ if (!lock.tryLock()) {
+ if (this.closed) {
+ return -1;
+ }
+ throw new ConcurrentModificationException("concurrent access is disallowed");
+ }
+
+ try {
+ byte[] bytes = getBytesOrAwait();
+
+ if (bytes == DONE) {
+ this.closed = true;
+ cleanAndFinalize();
+ if (this.error == null) {
+ return -1;
+ }
+ else {
+ throw Exceptions.propagate(error);
+ }
+ }
+ else if (bytes == CLOSED) {
+ this.s.cancel();
+ cleanAndFinalize();
+ return -1;
+ }
+
+ return bytes[this.position++] & 0xFF;
+ }
+ catch (Throwable t) {
+ this.closed = true;
+ this.s.cancel();
+ cleanAndFinalize();
+ throw Exceptions.propagate(t);
+ }
+ finally {
+ lock.unlock();
+ }
+ }
+
+ @Override
+ public int read(byte[] b, int off, int len) throws IOException {
+ Objects.checkFromIndexSize(off, len, b.length);
+ if (len == 0) {
+ return 0;
+ }
+
+ if (!lock.tryLock()) {
+ if (this.closed) {
+ return -1;
+ }
+ throw new ConcurrentModificationException("concurrent access is disallowed");
+ }
+
+ try {
+ for (int j = 0; j < len;) {
+ byte[] bytes = getBytesOrAwait();
+
+ if (bytes == DONE) {
+ this.closed = true;
+ cleanAndFinalize();
+ if (this.error == null) {
+ return j == 0 ? -1 : j;
+ }
+ else {
+ throw Exceptions.propagate(error);
+ }
+ }
+ else if (bytes == CLOSED) {
+ this.s.cancel();
+ cleanAndFinalize();
+ return -1;
+ }
+
+ int i = this.position;
+ for (; i < bytes.length && j < len; i++, j++) {
+ b[off + j] = bytes[i];
+ }
+ this.position = i;
+ }
+
+ return len;
+ }
+ catch (Throwable t) {
+ this.closed = true;
+ this.s.cancel();
+ cleanAndFinalize();
+ throw Exceptions.propagate(t);
+ }
+ finally {
+ lock.unlock();
+ }
+ }
+
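+ /**
+ * Returns the current chunk with readable bytes, parking the reader thread until
+ * data arrives, or returns the DONE or CLOSED sentinel on termination or close.
+ */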
+ byte[] getBytesOrAwait() {
+ if (this.available == null || this.available.length - this.position == 0) {
+ this.available = null;
+
+ int actualWorkAmount = this.workAmount.getAcquire();
+ for (;;) {
+ if (this.closed) {
+ return CLOSED;
+ }
+
+ boolean d = this.done;
+ T t = this.queue.poll();
+ if (t != null) {
+ int consumed = ++this.consumed;
+ this.position = 0;
+ this.available = Objects.requireNonNull(this.mapper.apply(t));
+ if (consumed == this.limit) {
+ this.consumed = 0;
+ this.s.request(this.limit);
+ }
+ break;
+ }
+
+ if (d) {
+ return DONE;
+ }
+
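+ // Deduct the work observed so far; if the counter drops to zero, no new
+ // signals arrived in the meantime, so park until resume() is called.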
+ actualWorkAmount = workAmount.addAndGet(-actualWorkAmount);
+ if (actualWorkAmount == 0) {
+ await();
+ }
+ }
+ }
+
+ return this.available;
+ }
+
+ void cleanAndFinalize() {
+ this.available = null;
+
+ for (;;) {
+ int workAmount = this.workAmount.getPlain();
+ T value;
+
+ while((value = queue.poll()) != null) {
+ discard(value);
+ }
+
+ if (this.workAmount.weakCompareAndSetPlain(workAmount, Integer.MIN_VALUE)) {
+ return;
+ }
+ }
+ }
+
+ void discard(T value) {
+ try {
+ this.onDiscardHandler.accept(value);
+ }
+ catch (Throwable t) {
+ if (logger.isDebugEnabled()) {
+ logger.debug("Failed to release " + value.getClass().getSimpleName() + ": " + value, t);
+ }
+ }
+ }
+
+ @Override
+ public void close() throws IOException {
+ if (this.closed) {
+ return;
+ }
+
+ this.closed = true;
+
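+ // If a read is in progress, defer cleanup to the reader: after resume() it
+ // observes the closed flag, gets CLOSED from getBytesOrAwait(), and cleans up.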
+ if (!this.lock.tryLock()) {
+ if (addWork() == 0) {
+ resume();
+ }
+ return;
+ }
+
+ try {
+ this.s.cancel();
+ cleanAndFinalize();
+ }
+ finally {
+ this.lock.unlock();
+ }
+ }
+
+ private void await() {
+ Thread toUnpark = Thread.currentThread();
+
+ while (true) {
+ Object current = this.parkedThread.get();
+ if (current == READY) {
+ break;
+ }
+
+ if (current != null && current != toUnpark) {
+ throw new IllegalStateException("Only one (Virtual)Thread can await!");
+ }
+
+ if (parkedThread.compareAndSet(null, toUnpark)) {
+ LockSupport.park();
+ // we don't just break here because park() can wake up spuriously
+ // if we got a proper resume, get() == READY and the loop will quit above
+ }
+ }
+ // clear the resume indicator so that the next await call will park without a resume()
+ this.parkedThread.lazySet(null);
+ }
+
+ private void resume() {
+ if (this.parkedThread.get() != READY) {
+ Object old = parkedThread.getAndSet(READY);
+ if (old != READY) {
+ LockSupport.unpark((Thread)old);
+ }
+ }
+ }
+
+}
diff --git a/spring-web/src/test/java/org/springframework/http/client/InputStreamSubscriberTests.java b/spring-web/src/test/java/org/springframework/http/client/InputStreamSubscriberTests.java
new file mode 100644
index 00000000000..9dd635dffba
--- /dev/null
+++ b/spring-web/src/test/java/org/springframework/http/client/InputStreamSubscriberTests.java
@@ -0,0 +1,259 @@
+/*
+ * Copyright 2002-2023 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.http.client;
+
+import org.junit.jupiter.api.Test;
+import org.reactivestreams.FlowAdapters;
+import reactor.core.publisher.Flux;
+import reactor.test.StepVerifier;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStreamWriter;
+import java.nio.charset.StandardCharsets;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.Executor;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Flow;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatIOException;
+
+/**
+ * @author Arjen Poutsma
+ * @author Oleh Dokuka
+ */
+class InputStreamSubscriberTests {
+
+ private static final byte[] FOO = "foo".getBytes(StandardCharsets.UTF_8);
+
+ private static final byte[] BAR = "bar".getBytes(StandardCharsets.UTF_8);
+
+ private static final byte[] BAZ = "baz".getBytes(StandardCharsets.UTF_8);
+
+
+ private final Executor executor = Executors.newSingleThreadExecutor();
+
+ private final OutputStreamPublisher.ByteMapper<byte[]> byteMapper =
+ new OutputStreamPublisher.ByteMapper<>() {
+ @Override
+ public byte[] map(int b) {
+ return new byte[]{(byte) b};
+ }
+
+ @Override
+ public byte[] map(byte[] b, int off, int len) {
+ byte[] result = new byte[len];
+ System.arraycopy(b, off, result, 0, len);
+ return result;
+ }
+ };
+
+
+ @Test
+ void basic() {
+ Flow.Publisher<byte[]> flowPublisher = OutputStreamPublisher.create(outputStream -> {
+ outputStream.write(FOO);
+ outputStream.write(BAR);
+ outputStream.write(BAZ);
+ }, this.byteMapper, this.executor);
+ Flux<String> flux = toString(flowPublisher);
+
+ StepVerifier.create(flux)
+ .assertNext(s -> assertThat(s).isEqualTo("foobarbaz"))
+ .verifyComplete();
+ }
+
+ @Test
+ void flush() {
+ Flow.Publisher<byte[]> flowPublisher = OutputStreamPublisher.create(outputStream -> {
+ outputStream.write(FOO);
+ outputStream.flush();
+ outputStream.write(BAR);
+ outputStream.flush();
+ outputStream.write(BAZ);
+ outputStream.flush();
+ }, this.byteMapper, this.executor);
+ Flux<String> flux = toString(flowPublisher);
+
+ try (InputStream is = InputStreamSubscriber.subscribeTo(FlowAdapters.toFlowPublisher(flux), (s) -> s.getBytes(StandardCharsets.UTF_8), (ignore) -> {}, 1)) {
+ byte[] chunk = new byte[3];
+
+ assertThat(is.read(chunk)).isEqualTo(3);
+ assertThat(chunk).containsExactly(FOO);
+ assertThat(is.read(chunk)).isEqualTo(3);
+ assertThat(chunk).containsExactly(BAR);
+ assertThat(is.read(chunk)).isEqualTo(3);
+ assertThat(chunk).containsExactly(BAZ);
+ assertThat(is.read(chunk)).isEqualTo(-1);
+ }
+ catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ @Test
+ void chunkSize() {
+ Flow.Publisher<byte[]> flowPublisher = OutputStreamPublisher.create(outputStream -> {
+ outputStream.write(FOO);
+ outputStream.write(BAR);
+ outputStream.write(BAZ);
+ }, this.byteMapper, this.executor, 2);
+ Flux<String> flux = toString(flowPublisher);
+
+ try (InputStream is = InputStreamSubscriber.subscribeTo(FlowAdapters.toFlowPublisher(flux), (s) -> s.getBytes(StandardCharsets.UTF_8), (ignore) -> {}, 1)) {
+ StringBuilder stringBuilder = new StringBuilder();
+ byte[] chunk = new byte[3];
+
+ stringBuilder.append(new String(new byte[] {(byte) is.read()}, StandardCharsets.UTF_8));
+ assertThat(is.read(chunk)).isEqualTo(3);
+ stringBuilder.append(new String(chunk, StandardCharsets.UTF_8));
+ assertThat(is.read(chunk)).isEqualTo(3);
+ stringBuilder.append(new String(chunk, StandardCharsets.UTF_8));
+ assertThat(is.read(chunk)).isEqualTo(2);
+ stringBuilder.append(new String(chunk, 0, 2, StandardCharsets.UTF_8));
+ assertThat(is.read()).isEqualTo(-1);
+
+ assertThat(stringBuilder.toString()).isEqualTo("foobarbaz");
+ }
+ catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ @Test
+ void cancel() throws InterruptedException {
+ CountDownLatch latch = new CountDownLatch(1);
+
+ Flow.Publisher<byte[]> flowPublisher = OutputStreamPublisher.create(outputStream -> {
+ assertThatIOException()
+ .isThrownBy(() -> {
+ outputStream.write(FOO);
+ outputStream.flush();
+ outputStream.write(BAR);
+ outputStream.flush();
+ outputStream.write(BAZ);
+ outputStream.flush();
+ })
+ .withMessage("Subscription has been terminated");
+ latch.countDown();
+
+ }, this.byteMapper, this.executor);
+ Flux<String> flux = toString(flowPublisher);
+ List<String> discarded = new ArrayList<>();
+
+ try (InputStream is = InputStreamSubscriber.subscribeTo(FlowAdapters.toFlowPublisher(flux), (s) -> s.getBytes(StandardCharsets.UTF_8), discarded::add, 1)) {
+ byte[] chunk = new byte[3];
+
+ assertThat(is.read(chunk)).isEqualTo(3);
+ assertThat(chunk).containsExactly(FOO);
+ }
+ catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+
+ latch.await();
+
+ assertThat(discarded).containsExactly("bar");
+ }
+
+ @Test
+ void closed() throws InterruptedException {
+ CountDownLatch latch = new CountDownLatch(1);
+
+ Flow.Publisher<byte[]> flowPublisher = OutputStreamPublisher.create(outputStream -> {
+ OutputStreamWriter writer = new OutputStreamWriter(outputStream, StandardCharsets.UTF_8);
+ writer.write("foo");
+ writer.close();
+ assertThatIOException().isThrownBy(() -> writer.write("bar"))
+ .withMessage("Stream closed");
+ latch.countDown();
+ }, this.byteMapper, this.executor);
+ Flux<String> flux = toString(flowPublisher);
+
+ try (InputStream is = InputStreamSubscriber.subscribeTo(FlowAdapters.toFlowPublisher(flux), (s) -> s.getBytes(StandardCharsets.UTF_8), ig -> {}, 1)) {
+ byte[] chunk = new byte[3];
+
+ assertThat(is.read(chunk)).isEqualTo(3);
+ assertThat(chunk).containsExactly(FOO);
+
+ assertThat(is.read(chunk)).isEqualTo(-1);
+ }
+ catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+
+ latch.await();
+ }
+
+ @Test
+ void mapperThrowsException() throws InterruptedException {
+ CountDownLatch latch = new CountDownLatch(1);
+
+ Flow.Publisher<byte[]> flowPublisher = OutputStreamPublisher.create(outputStream -> {
+ outputStream.write(FOO);
+ outputStream.flush();
+ assertThatIOException().isThrownBy(() -> {
+ outputStream.write(BAR);
+ outputStream.flush();
+ }).withMessage("Subscription has been terminated");
+ latch.countDown();
+ }, this.byteMapper, this.executor);
+ Throwable ex = null;
+
+ StringBuilder stringBuilder = new StringBuilder();
+ try (InputStream is = InputStreamSubscriber.subscribeTo(flowPublisher, (s) -> {
+ throw new NullPointerException("boom");
+ }, ig -> {}, 1)) {
+ byte[] chunk = new byte[3];
+
+ stringBuilder.append(new String(new byte[] {(byte) is.read()}, StandardCharsets.UTF_8));
+ assertThat(is.read(chunk)).isEqualTo(3);
+ stringBuilder.append(new String(chunk, StandardCharsets.UTF_8));
+ assertThat(is.read(chunk)).isEqualTo(3);
+ stringBuilder.append(new String(chunk, StandardCharsets.UTF_8));
+ assertThat(is.read(chunk)).isEqualTo(2);
+ stringBuilder.append(new String(chunk, 0, 2, StandardCharsets.UTF_8));
+ assertThat(is.read()).isEqualTo(-1);
+ }
+ catch (Throwable e) {
+ ex = e;
+ }
+
+ latch.await();
+
+ assertThat(stringBuilder.toString()).isEqualTo("");
+ assertThat(ex).hasMessage("boom");
+ }
+
+ private static Flux<String> toString(Flow.Publisher<byte[]> flowPublisher) {
+ return Flux.from(FlowAdapters.toPublisher(flowPublisher))
+ .map(bytes -> new String(bytes, StandardCharsets.UTF_8));
+ }
+
+}