Support Publisher to InputStream conversion

See gh-31677
This commit is contained in:
OlegDokuka 2023-11-24 16:00:35 +02:00 committed by rstoyanchev
parent de2c10abcd
commit 37622a7f90
5 changed files with 1235 additions and 0 deletions

View File

@ -456,6 +456,35 @@ public abstract class DataBufferUtils {
consumer::accept, new DataBufferMapper(bufferFactory), executor, chunkSize);
}
/**
* Subscribes to given {@link Publisher} and returns subscription
* as {@link InputStream} that allows reading all propagated {@link DataBuffer} messages via its imperative API.
* Given the {@link InputStream} implementation buffers messages as per configuration.
* The returned {@link InputStream} is considered terminated when the given {@link Publisher} signaled one of the
* terminal signal ({@link Subscriber#onComplete() or {@link Subscriber#onError(Throwable)}})
* and all the stored {@link DataBuffer} polled from the internal buffer.
* The returned {@link InputStream} will call {@link Subscription#cancel()} and release all stored {@link DataBuffer}
* when {@link InputStream#close()} is called.
* <p>
* Note: The implementation of the returned {@link InputStream} disallow concurrent call on
* any of the {@link InputStream#read} methods
* <p>
* Note: {@link Subscription#request(long)} happens eagerly for the first time upon subscription
* and then repeats every time {@code bufferSize - (bufferSize >> 2)} consumed
*
* @param publisher the source of {@link DataBuffer} which should be represented as an {@link InputStream}
* @param bufferSize the maximum amount of {@link DataBuffer} prefetched in advance and stored inside {@link InputStream}
* @return an {@link InputStream} instance representing given {@link Publisher} messages
*/
public static <T extends DataBuffer> InputStream subscribeAsInputStream(Publisher<T> publisher, int bufferSize) {
Assert.notNull(publisher, "Publisher must not be null");
Assert.isTrue(bufferSize > 0, "Buffer size must be > 0");
InputStreamSubscriber inputStreamSubscriber = new InputStreamSubscriber(bufferSize);
publisher.subscribe(inputStreamSubscriber);
return inputStreamSubscriber;
}
//---------------------------------------------------------------------
// Various

View File

@ -0,0 +1,355 @@
package org.springframework.core.io.buffer;
import org.reactivestreams.Publisher;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import org.springframework.lang.Nullable;
import reactor.core.Exceptions;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.util.ConcurrentModificationException;
import java.util.Objects;
import java.util.Queue;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.locks.LockSupport;
import java.util.concurrent.locks.ReentrantLock;
/**
* Bridges between {@link Publisher Publisher&lt;DataBuffer&gt;} and {@link InputStream}.
*
* <p>Note that this class has a near duplicate in
* {@link org.springframework.http.client.InputStreamSubscriber}.
*
* @author Oleh Dokuka
* @since 6.1
*/
final class InputStreamSubscriber extends InputStream implements Subscriber<DataBuffer> {
static final Object READY = new Object();
static final DataBuffer DONE = DefaultDataBuffer.fromEmptyByteBuffer(DefaultDataBufferFactory.sharedInstance, ByteBuffer.allocate(0));
static final DataBuffer CLOSED = DefaultDataBuffer.fromEmptyByteBuffer(DefaultDataBufferFactory.sharedInstance, ByteBuffer.allocate(0));
final int prefetch;
final int limit;
final ReentrantLock lock;
final Queue<DataBuffer> queue;
final AtomicReference<Object> parkedThread = new AtomicReference<>();
final AtomicInteger workAmount = new AtomicInteger();
volatile boolean closed;
int consumed;
@Nullable
DataBuffer available;
@Nullable
Subscription s;
boolean done;
@Nullable
Throwable error;
InputStreamSubscriber(int prefetch) {
this.prefetch = prefetch;
this.limit = prefetch == Integer.MAX_VALUE ? Integer.MAX_VALUE : prefetch - (prefetch >> 2);
this.queue = new ArrayBlockingQueue<>(prefetch);
this.lock = new ReentrantLock(false);
}
@Override
public void onSubscribe(Subscription subscription) {
if (this.s != null) {
subscription.cancel();
return;
}
this.s = subscription;
subscription.request(prefetch == Integer.MAX_VALUE ? Long.MAX_VALUE : prefetch);
}
@Override
public void onNext(DataBuffer t) {
if (this.done) {
discard(t);
return;
}
if (!queue.offer(t)) {
discard(t);
error = new RuntimeException("Buffer overflow");
done = true;
}
int previousWorkState = addWork();
if (previousWorkState == Integer.MIN_VALUE) {
DataBuffer value = queue.poll();
if (value != null) {
discard(value);
}
return;
}
if (previousWorkState == 0) {
resume();
}
}
@Override
public void onError(Throwable throwable) {
if (this.done) {
return;
}
this.error = throwable;
this.done = true;
if (addWork() == 0) {
resume();
}
}
@Override
public void onComplete() {
if (this.done) {
return;
}
this.done = true;
if (addWork() == 0) {
resume();
}
}
int addWork() {
for (;;) {
int produced = this.workAmount.getPlain();
if (produced == Integer.MIN_VALUE) {
return Integer.MIN_VALUE;
}
int nextProduced = produced == Integer.MAX_VALUE ? 1 : produced + 1;
if (workAmount.weakCompareAndSetRelease(produced, nextProduced)) {
return produced;
}
}
}
@Override
public int read() throws IOException {
if (!lock.tryLock()) {
if (this.closed) {
return -1;
}
throw new ConcurrentModificationException("concurrent access is disallowed");
}
try {
DataBuffer bytes = getBytesOrAwait();
if (bytes == DONE) {
this.closed = true;
cleanAndFinalize();
if (this.error == null) {
return -1;
}
else {
throw Exceptions.propagate(error);
}
} else if (bytes == CLOSED) {
cleanAndFinalize();
return -1;
}
return bytes.read() & 0xFF;
}
catch (Throwable t) {
this.closed = true;
this.s.cancel();
cleanAndFinalize();
throw Exceptions.propagate(t);
}
finally {
lock.unlock();
}
}
@Override
public int read(byte[] b, int off, int len) throws IOException {
Objects.checkFromIndexSize(off, len, b.length);
if (len == 0) {
return 0;
}
if (!lock.tryLock()) {
if (this.closed) {
return -1;
}
throw new ConcurrentModificationException("concurrent access is disallowed");
}
try {
for (int j = 0; j < len;) {
DataBuffer bytes = getBytesOrAwait();
if (bytes == DONE) {
cleanAndFinalize();
if (this.error == null) {
this.closed = true;
return j == 0 ? -1 : j;
}
else {
if (j == 0) {
this.closed = true;
throw Exceptions.propagate(error);
}
return j;
}
} else if (bytes == CLOSED) {
this.s.cancel();
cleanAndFinalize();
return -1;
}
int initialReadPosition = bytes.readPosition();
bytes.read(b, off + j, Math.min(len - j, bytes.readableByteCount()));
j += bytes.readPosition() - initialReadPosition;
}
return len;
}
catch (Throwable t) {
this.closed = true;
this.s.cancel();
cleanAndFinalize();
throw Exceptions.propagate(t);
}
finally {
lock.unlock();
}
}
DataBuffer getBytesOrAwait() {
if (this.available == null || this.available.readableByteCount() == 0) {
discard(this.available);
this.available = null;
int actualWorkAmount = this.workAmount.getAcquire();
for (;;) {
if (this.closed) {
return CLOSED;
}
boolean d = this.done;
DataBuffer t = this.queue.poll();
if (t != null) {
int consumed = ++this.consumed;
this.available = t;
if (consumed == this.limit) {
this.consumed = 0;
this.s.request(this.limit);
}
break;
}
if (d) {
return DONE;
}
actualWorkAmount = workAmount.addAndGet(-actualWorkAmount);
if (actualWorkAmount == 0) {
await();
}
}
}
return this.available;
}
void cleanAndFinalize() {
discard(this.available);
this.available = null;
for (;;) {
int workAmount = this.workAmount.getPlain();
DataBuffer value;
while((value = queue.poll()) != null) {
discard(value);
}
if (this.workAmount.weakCompareAndSetPlain(workAmount, Integer.MIN_VALUE)) {
return;
}
}
}
void discard(@Nullable DataBuffer value) {
DataBufferUtils.release(value);
}
@Override
public void close() throws IOException {
if (this.closed) {
return;
}
this.closed = true;
if (!this.lock.tryLock()) {
if (addWork() == 0) {
resume();
}
return;
}
try {
this.s.cancel();
cleanAndFinalize();
}
finally {
this.lock.unlock();
}
}
private void await() {
Thread toUnpark = Thread.currentThread();
while (true) {
Object current = this.parkedThread.get();
if (current == READY) {
break;
}
if (current != null && current != toUnpark) {
throw new IllegalStateException("Only one (Virtual)Thread can await!");
}
if (parkedThread.compareAndSet( null, toUnpark)) {
LockSupport.park();
// we don't just break here because park() can wake up spuriously
// if we got a proper resume, get() == READY and the loop will quit above
}
}
// clear the resume indicator so that the next await call will park without a resume()
this.parkedThread.lazySet(null);
}
private void resume() {
if (this.parkedThread != READY) {
Object old = parkedThread.getAndSet(READY);
if (old != READY) {
LockSupport.unpark((Thread)old);
}
}
}
}

View File

@ -17,6 +17,7 @@
package org.springframework.core.io.buffer;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
import java.net.URI;
@ -27,15 +28,18 @@ import java.nio.channels.FileChannel;
import java.nio.channels.ReadableByteChannel;
import java.nio.channels.SeekableByteChannel;
import java.nio.channels.WritableByteChannel;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardOpenOption;
import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadLocalRandom;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.PooledByteBufAllocator;
@ -688,6 +692,189 @@ class DataBufferUtilsTests extends AbstractDataBufferAllocatingTests {
latch.await();
}
@ParameterizedDataBufferAllocatingTest
void inputStreamSubscriberChunkSize(DataBufferFactory bufferFactory) {
genericInputStreamSubscriberTest(bufferFactory, 3, 3, 64, List.of("foo", "bar", "baz"), List.of("foo", "bar", "baz"));
}
@ParameterizedDataBufferAllocatingTest
void inputStreamSubscriberChunkSize2(DataBufferFactory bufferFactory) {
genericInputStreamSubscriberTest(bufferFactory, 3, 3, 1, List.of("foo", "bar", "baz"), List.of("foo", "bar", "baz"));
}
@ParameterizedDataBufferAllocatingTest
void inputStreamSubscriberChunkSize3(DataBufferFactory bufferFactory) {
genericInputStreamSubscriberTest(bufferFactory, 3, 12, 1, List.of("foo", "bar", "baz"), List.of("foobarbaz"));
}
@ParameterizedDataBufferAllocatingTest
void inputStreamSubscriberChunkSize4(DataBufferFactory bufferFactory) {
genericInputStreamSubscriberTest(bufferFactory, 3, 1, 1, List.of("foo", "bar", "baz"), List.of("f", "o", "o", "b", "a", "r", "b", "a", "z"));
}
@ParameterizedDataBufferAllocatingTest
void inputStreamSubscriberChunkSize5(DataBufferFactory bufferFactory) {
genericInputStreamSubscriberTest(bufferFactory, 3, 2, 1, List.of("foo", "bar", "baz"), List.of("fo", "ob", "ar", "ba", "z"));
}
@ParameterizedDataBufferAllocatingTest
void inputStreamSubscriberChunkSize6(DataBufferFactory bufferFactory) {
genericInputStreamSubscriberTest(bufferFactory, 1, 3, 1, List.of("foo", "bar", "baz"), List.of("foo", "bar", "baz"));
}
@ParameterizedDataBufferAllocatingTest
void inputStreamSubscriberChunkSize7(DataBufferFactory bufferFactory) {
genericInputStreamSubscriberTest(bufferFactory, 1, 3, 64, List.of("foo", "bar", "baz"), List.of("foo", "bar", "baz"));
}
void genericInputStreamSubscriberTest(DataBufferFactory bufferFactory, int writeChunkSize, int readChunkSize, int bufferSize, List<String> input, List<String> expectedOutput) {
super.bufferFactory = bufferFactory;
Publisher<DataBuffer> publisher = DataBufferUtils.outputStreamPublisher(outputStream -> {
try {
for (String word : input) {
outputStream.write(word.getBytes(StandardCharsets.UTF_8));
}
}
catch (IOException ex) {
fail(ex.getMessage(), ex);
}
}, super.bufferFactory, Executors.newSingleThreadExecutor(), writeChunkSize);
byte[] chunk = new byte[readChunkSize];
ArrayList<String> words = new ArrayList<>();
try (InputStream inputStream = DataBufferUtils.subscribeAsInputStream(publisher, bufferSize)) {
int read;
while((read = inputStream.read(chunk)) > -1) {
String word = new String(chunk, 0, read, StandardCharsets.UTF_8);
words.add(word);
}
}
catch (IOException e) {
throw new RuntimeException(e);
}
assertThat(words).containsExactlyElementsOf(expectedOutput);
}
@ParameterizedDataBufferAllocatingTest
void inputStreamSubscriberError(DataBufferFactory bufferFactory) throws InterruptedException {
super.bufferFactory = bufferFactory;
var input = List.of("foo ", "bar ", "baz");
Publisher<DataBuffer> publisher = DataBufferUtils.outputStreamPublisher(outputStream -> {
try {
for (String word : input) {
outputStream.write(word.getBytes(StandardCharsets.UTF_8));
}
throw new RuntimeException("boom");
}
catch (IOException ex) {
fail(ex.getMessage(), ex);
}
}, super.bufferFactory, Executors.newSingleThreadExecutor(), 1);
RuntimeException error = null;
byte[] chunk = new byte[4];
ArrayList<String> words = new ArrayList<>();
try (InputStream inputStream = DataBufferUtils.subscribeAsInputStream(publisher, 1)) {
int read;
while((read = inputStream.read(chunk)) > -1) {
String word = new String(chunk, 0, read, StandardCharsets.UTF_8);
words.add(word);
}
}
catch (IOException e) {
throw new RuntimeException(e);
}
catch (RuntimeException e) {
error = e;
}
assertThat(words).containsExactlyElementsOf(List.of("foo ", "bar ", "baz"));
assertThat(error).hasMessage("boom");
}
@ParameterizedDataBufferAllocatingTest
void inputStreamSubscriberMixedReadMode(DataBufferFactory bufferFactory) throws InterruptedException {
super.bufferFactory = bufferFactory;
var input = List.of("foo ", "bar ", "baz");
Publisher<DataBuffer> publisher = DataBufferUtils.outputStreamPublisher(outputStream -> {
try {
for (String word : input) {
outputStream.write(word.getBytes(StandardCharsets.UTF_8));
}
}
catch (IOException ex) {
fail(ex.getMessage(), ex);
}
}, super.bufferFactory, Executors.newSingleThreadExecutor(), 1);
byte[] chunk = new byte[3];
ArrayList<String> words = new ArrayList<>();
try (InputStream inputStream = DataBufferUtils.subscribeAsInputStream(publisher, 1)) {
words.add(new String(chunk,0, inputStream.read(chunk), StandardCharsets.UTF_8));
assertThat(inputStream.read()).isEqualTo(' ' & 0xFF);
words.add(new String(chunk,0, inputStream.read(chunk), StandardCharsets.UTF_8));
assertThat(inputStream.read()).isEqualTo(' ' & 0xFF);
words.add(new String(chunk,0, inputStream.read(chunk), StandardCharsets.UTF_8));
assertThat(inputStream.read()).isEqualTo(-1);
}
catch (IOException e) {
throw new RuntimeException(e);
}
assertThat(words).containsExactlyElementsOf(List.of("foo", "bar", "baz"));
}
@ParameterizedDataBufferAllocatingTest
void inputStreamSubscriberClose(DataBufferFactory bufferFactory) throws InterruptedException {
for (int i = 1; i < 100; i++) {
CountDownLatch latch = new CountDownLatch(1);
super.bufferFactory = bufferFactory;
var input = List.of("foo", "bar", "baz");
Publisher<DataBuffer> publisher = DataBufferUtils.outputStreamPublisher(outputStream -> {
try {
assertThatIOException()
.isThrownBy(() -> {
for (String word : input) {
outputStream.write(word.getBytes(StandardCharsets.UTF_8));
outputStream.flush();
}
})
.withMessage("Subscription has been terminated");
} finally {
latch.countDown();
}
}, super.bufferFactory, Executors.newSingleThreadExecutor(), 1);
byte[] chunk = new byte[3];
ArrayList<String> words = new ArrayList<>();
try (InputStream inputStream = DataBufferUtils.subscribeAsInputStream(publisher, ThreadLocalRandom.current().nextInt(1, 4))) {
inputStream.read(chunk);
String word = new String(chunk, StandardCharsets.UTF_8);
words.add(word);
} catch (IOException e) {
throw new RuntimeException(e);
}
assertThat(words).containsExactlyElementsOf(List.of("foo"));
latch.await();
}
}
@ParameterizedDataBufferAllocatingTest
void readAndWriteByteChannel(DataBufferFactory bufferFactory) throws Exception {
super.bufferFactory = bufferFactory;

View File

@ -0,0 +1,405 @@
package org.springframework.http.client;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.reactivestreams.Publisher;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import reactor.core.Exceptions;
import java.io.IOException;
import java.io.InputStream;
import java.util.ConcurrentModificationException;
import java.util.Objects;
import java.util.Queue;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.Flow;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.locks.LockSupport;
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.Consumer;
import java.util.function.Function;
/**
* Bridges between {@link Flow.Publisher Flow.Publisher&lt;T&gt;} and {@link InputStream}.
*
* <p>Note that this class has a near duplicate in
* {@link org.springframework.core.io.buffer.InputStreamSubscriber}.
*
* @author Oleh Dokuka
* @since 6.1
*/
final class InputStreamSubscriber<T> extends InputStream implements Flow.Subscriber<T> {
private static final Log logger = LogFactory.getLog(InputStreamSubscriber.class);
static final Object READY = new Object();
static final byte[] DONE = new byte[0];
static final byte[] CLOSED = new byte[0];
final int prefetch;
final int limit;
final Function<T, byte[]> mapper;
final Consumer<T> onDiscardHandler;
final ReentrantLock lock;
final Queue<T> queue;
final AtomicReference<Object> parkedThread = new AtomicReference<>();
final AtomicInteger workAmount = new AtomicInteger();
volatile boolean closed;
int consumed;
@Nullable
byte[] available;
int position;
@Nullable
Flow.Subscription s;
boolean done;
@Nullable
Throwable error;
private InputStreamSubscriber(Function<T, byte[]> mapper, Consumer<T> onDiscardHandler, int prefetch) {
this.prefetch = prefetch;
this.limit = prefetch == Integer.MAX_VALUE ? Integer.MAX_VALUE : prefetch - (prefetch >> 2);
this.mapper = mapper;
this.onDiscardHandler = onDiscardHandler;
this.queue = new ArrayBlockingQueue<>(prefetch);
this.lock = new ReentrantLock(false);
}
/**
* Subscribes to given {@link Publisher} and returns subscription
* as {@link InputStream} that allows reading all propagated {@link DataBuffer} messages via its imperative API.
* Given the {@link InputStream} implementation buffers messages as per configuration.
* The returned {@link InputStream} is considered terminated when the given {@link Publisher} signaled one of the
* terminal signal ({@link Subscriber#onComplete() or {@link Subscriber#onError(Throwable)}})
* and all the stored {@link DataBuffer} polled from the internal buffer.
* The returned {@link InputStream} will call {@link Subscription#cancel()} and release all stored {@link DataBuffer}
* when {@link InputStream#close()} is called.
* <p>
* Note: The implementation of the returned {@link InputStream} disallow concurrent call on
* any of the {@link InputStream#read} methods
* <p>
* Note: {@link Subscription#request(long)} happens eagerly for the first time upon subscription
* and then repeats every time {@code bufferSize - (bufferSize >> 2)} consumed
*
* @param publisher the source of {@link DataBuffer} which should be represented as an {@link InputStream}
* @param mapper function to transform &lt;T&gt; element to {@code byte[]}. Note, &lt;T&gt; should be released during the mapping if needed.
* @param onDiscardHandler &lt;T&gt; element consumer if returned {@link InputStream} is closed prematurely.
* @param bufferSize the maximum amount of &lt;T&gt; elements prefetched in advance and stored inside {@link InputStream}
* @return an {@link InputStream} instance representing given {@link Publisher} messages
*/
public static <T> InputStream subscribeTo(Flow.Publisher<T> publisher, Function<T, byte[]> mapper, Consumer<T> onDiscardHandler, int bufferSize) {
Assert.notNull(publisher, "Flow.Publisher must not be null");
Assert.notNull(mapper, "mapper must not be null");
Assert.notNull(onDiscardHandler, "onDiscardHandler must not be null");
Assert.isTrue(bufferSize > 0, "bufferSize must be greater than 0");
InputStreamSubscriber<T> iss = new InputStreamSubscriber<>(mapper, onDiscardHandler, bufferSize);
publisher.subscribe(iss);
return iss;
}
@Override
public void onSubscribe(Flow.Subscription subscription) {
if (this.s != null) {
subscription.cancel();
return;
}
this.s = subscription;
subscription.request(prefetch == Integer.MAX_VALUE ? Long.MAX_VALUE : prefetch);
}
@Override
public void onNext(T t) {
Assert.notNull(t, "T value must not be null");
if (this.done) {
discard(t);
return;
}
if (!queue.offer(t)) {
discard(t);
error = new RuntimeException("Buffer overflow");
done = true;
}
int previousWorkState = addWork();
if (previousWorkState == Integer.MIN_VALUE) {
T value = queue.poll();
if (value != null) {
discard(value);
}
return;
}
if (previousWorkState == 0) {
resume();
}
}
@Override
public void onError(Throwable throwable) {
if (this.done) {
return;
}
this.error = throwable;
this.done = true;
if (addWork() == 0) {
resume();
}
}
@Override
public void onComplete() {
if (this.done) {
return;
}
this.done = true;
if (addWork() == 0) {
resume();
}
}
int addWork() {
for (;;) {
int produced = this.workAmount.getPlain();
if (produced == Integer.MIN_VALUE) {
return Integer.MIN_VALUE;
}
int nextProduced = produced == Integer.MAX_VALUE ? 1 : produced + 1;
if (workAmount.weakCompareAndSetRelease(produced, nextProduced)) {
return produced;
}
}
}
@Override
public int read() throws IOException {
if (!lock.tryLock()) {
if (this.closed) {
return -1;
}
throw new ConcurrentModificationException("concurrent access is disallowed");
}
try {
byte[] bytes = getBytesOrAwait();
if (bytes == DONE) {
this.closed = true;
cleanAndFinalize();
if (this.error == null) {
return -1;
}
else {
throw Exceptions.propagate(error);
}
} else if (bytes == CLOSED) {
cleanAndFinalize();
return -1;
}
return bytes[this.position++] & 0xFF;
}
catch (Throwable t) {
this.closed = true;
this.s.cancel();
cleanAndFinalize();
throw Exceptions.propagate(t);
}
finally {
lock.unlock();
}
}
@Override
public int read(byte[] b, int off, int len) throws IOException {
Objects.checkFromIndexSize(off, len, b.length);
if (len == 0) {
return 0;
}
if (!lock.tryLock()) {
if (this.closed) {
return -1;
}
throw new ConcurrentModificationException("concurrent access is disallowed");
}
try {
for (int j = 0; j < len;) {
byte[] bytes = getBytesOrAwait();
if (bytes == DONE) {
this.closed = true;
cleanAndFinalize();
if (this.error == null) {
return j == 0 ? -1 : j;
}
else {
throw Exceptions.propagate(error);
}
} else if (bytes == CLOSED) {
this.s.cancel();
cleanAndFinalize();
return -1;
}
int i = this.position;
for (; i < bytes.length && j < len; i++, j++) {
b[off + j] = bytes[i];
}
this.position = i;
}
return len;
}
catch (Throwable t) {
this.closed = true;
this.s.cancel();
cleanAndFinalize();
throw Exceptions.propagate(t);
}
finally {
lock.unlock();
}
}
byte[] getBytesOrAwait() {
if (this.available == null || this.available.length - this.position == 0) {
this.available = null;
int actualWorkAmount = this.workAmount.getAcquire();
for (;;) {
if (this.closed) {
return CLOSED;
}
boolean d = this.done;
T t = this.queue.poll();
if (t != null) {
int consumed = ++this.consumed;
this.position = 0;
this.available = Objects.requireNonNull(this.mapper.apply(t));
if (consumed == this.limit) {
this.consumed = 0;
this.s.request(this.limit);
}
break;
}
if (d) {
return DONE;
}
actualWorkAmount = workAmount.addAndGet(-actualWorkAmount);
if (actualWorkAmount == 0) {
await();
}
}
}
return this.available;
}
void cleanAndFinalize() {
this.available = null;
for (;;) {
int workAmount = this.workAmount.getPlain();
T value;
while((value = queue.poll()) != null) {
discard(value);
}
if (this.workAmount.weakCompareAndSetPlain(workAmount, Integer.MIN_VALUE)) {
return;
}
}
}
void discard(T value) {
try {
this.onDiscardHandler.accept(value);
} catch (Throwable t) {
if (logger.isDebugEnabled()) {
logger.debug("Failed to release " + value.getClass().getSimpleName() + ": " + value, t);
}
}
}
@Override
public void close() throws IOException {
if (this.closed) {
return;
}
this.closed = true;
if (!this.lock.tryLock()) {
if (addWork() == 0) {
resume();
}
return;
}
try {
this.s.cancel();
cleanAndFinalize();
}
finally {
this.lock.unlock();
}
}
private void await() {
Thread toUnpark = Thread.currentThread();
while (true) {
Object current = this.parkedThread.get();
if (current == READY) {
break;
}
if (current != null && current != toUnpark) {
throw new IllegalStateException("Only one (Virtual)Thread can await!");
}
if (parkedThread.compareAndSet( null, toUnpark)) {
LockSupport.park();
// we don't just break here because park() can wake up spuriously
// if we got a proper resume, get() == READY and the loop will quit above
}
}
// clear the resume indicator so that the next await call will park without a resume()
this.parkedThread.lazySet(null);
}
private void resume() {
if (this.parkedThread != READY) {
Object old = parkedThread.getAndSet(READY);
if (old != READY) {
LockSupport.unpark((Thread)old);
}
}
}
}

View File

@ -0,0 +1,259 @@
/*
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.http.client;
import org.junit.jupiter.api.Test;
import org.reactivestreams.FlowAdapters;
import reactor.core.publisher.Flux;
import reactor.test.StepVerifier;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStreamWriter;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import java.util.concurrent.Flow;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIOException;
/**
* @author Arjen Poutsma
* @author Oleh Dokuka
*/
class InputStreamSubscriberTests {
private static final byte[] FOO = "foo".getBytes(StandardCharsets.UTF_8);
private static final byte[] BAR = "bar".getBytes(StandardCharsets.UTF_8);
private static final byte[] BAZ = "baz".getBytes(StandardCharsets.UTF_8);
private final Executor executor = Executors.newSingleThreadExecutor();
private final OutputStreamPublisher.ByteMapper<byte[]> byteMapper =
new OutputStreamPublisher.ByteMapper<>() {
@Override
public byte[] map(int b) {
return new byte[]{(byte) b};
}
@Override
public byte[] map(byte[] b, int off, int len) {
byte[] result = new byte[len];
System.arraycopy(b, off, result, 0, len);
return result;
}
};
@Test
void basic() {
Flow.Publisher<byte[]> flowPublisher = OutputStreamPublisher.create(outputStream -> {
outputStream.write(FOO);
outputStream.write(BAR);
outputStream.write(BAZ);
}, this.byteMapper, this.executor);
Flux<String> flux = toString(flowPublisher);
StepVerifier.create(flux)
.assertNext(s -> assertThat(s).isEqualTo("foobarbaz"))
.verifyComplete();
}
@Test
void flush() {
Flow.Publisher<byte[]> flowPublisher = OutputStreamPublisher.create(outputStream -> {
outputStream.write(FOO);
outputStream.flush();
outputStream.write(BAR);
outputStream.flush();
outputStream.write(BAZ);
outputStream.flush();
}, this.byteMapper, this.executor);
Flux<String> flux = toString(flowPublisher);
try (InputStream is = InputStreamSubscriber.subscribeTo(FlowAdapters.toFlowPublisher(flux), (s) -> s.getBytes(StandardCharsets.UTF_8), (ignore) -> {}, 1)) {
byte[] chunk = new byte[3];
assertThat(is.read(chunk)).isEqualTo(3);
assertThat(chunk).containsExactly(FOO);
assertThat(is.read(chunk)).isEqualTo(3);
assertThat(chunk).containsExactly(BAR);
assertThat(is.read(chunk)).isEqualTo(3);
assertThat(chunk).containsExactly(BAZ);
assertThat(is.read(chunk)).isEqualTo(-1);
}
catch (IOException e) {
throw new RuntimeException(e);
}
}
@Test
void chunkSize() {
Flow.Publisher<byte[]> flowPublisher = OutputStreamPublisher.create(outputStream -> {
outputStream.write(FOO);
outputStream.write(BAR);
outputStream.write(BAZ);
}, this.byteMapper, this.executor, 2);
Flux<String> flux = toString(flowPublisher);
try (InputStream is = InputStreamSubscriber.subscribeTo(FlowAdapters.toFlowPublisher(flux), (s) -> s.getBytes(StandardCharsets.UTF_8), (ignore) -> {}, 1)) {
StringBuilder stringBuilder = new StringBuilder();
byte[] chunk = new byte[3];
stringBuilder
.append(new String(new byte[]{(byte)is.read()}, StandardCharsets.UTF_8));
assertThat(is.read(chunk)).isEqualTo(3);
stringBuilder
.append(new String(chunk, StandardCharsets.UTF_8));
assertThat(is.read(chunk)).isEqualTo(3);
stringBuilder
.append(new String(chunk, StandardCharsets.UTF_8));
assertThat(is.read(chunk)).isEqualTo(2);
stringBuilder
.append(new String(chunk,0, 2, StandardCharsets.UTF_8));
assertThat(is.read()).isEqualTo(-1);
assertThat(stringBuilder.toString()).isEqualTo("foobarbaz");
}
catch (IOException e) {
throw new RuntimeException(e);
}
}
@Test
void cancel() throws InterruptedException {
CountDownLatch latch = new CountDownLatch(1);
Flow.Publisher<byte[]> flowPublisher = OutputStreamPublisher.create(outputStream -> {
assertThatIOException()
.isThrownBy(() -> {
outputStream.write(FOO);
outputStream.flush();
outputStream.write(BAR);
outputStream.flush();
outputStream.write(BAZ);
outputStream.flush();
})
.withMessage("Subscription has been terminated");
latch.countDown();
}, this.byteMapper, this.executor);
Flux<String> flux = toString(flowPublisher);
List<String> discarded = new ArrayList<>();
try (InputStream is = InputStreamSubscriber.subscribeTo(FlowAdapters.toFlowPublisher(flux), (s) -> s.getBytes(StandardCharsets.UTF_8), discarded::add, 1)) {
byte[] chunk = new byte[3];
assertThat(is.read(chunk)).isEqualTo(3);
assertThat(chunk).containsExactly(FOO);
}
catch (IOException e) {
throw new RuntimeException(e);
}
latch.await();
assertThat(discarded).containsExactly("bar");
}
@Test
void closed() throws InterruptedException {
CountDownLatch latch = new CountDownLatch(1);
Flow.Publisher<byte[]> flowPublisher = OutputStreamPublisher.create(outputStream -> {
OutputStreamWriter writer = new OutputStreamWriter(outputStream, StandardCharsets.UTF_8);
writer.write("foo");
writer.close();
assertThatIOException().isThrownBy(() -> writer.write("bar"))
.withMessage("Stream closed");
latch.countDown();
}, this.byteMapper, this.executor);
Flux<String> flux = toString(flowPublisher);
try (InputStream is = InputStreamSubscriber.subscribeTo(FlowAdapters.toFlowPublisher(flux), (s) -> s.getBytes(StandardCharsets.UTF_8), ig -> {}, 1)) {
byte[] chunk = new byte[3];
assertThat(is.read(chunk)).isEqualTo(3);
assertThat(chunk).containsExactly(FOO);
assertThat(is.read(chunk)).isEqualTo(-1);
}
catch (IOException e) {
throw new RuntimeException(e);
}
latch.await();
}
@Test
void mapperThrowsException() throws InterruptedException {
CountDownLatch latch = new CountDownLatch(1);
Flow.Publisher<byte[]> flowPublisher = OutputStreamPublisher.create(outputStream -> {
outputStream.write(FOO);
outputStream.flush();
assertThatIOException().isThrownBy(() -> {
outputStream.write(BAR);
outputStream.flush();
}).withMessage("Subscription has been terminated");
latch.countDown();
}, this.byteMapper, this.executor);
Throwable ex = null;
StringBuilder stringBuilder = new StringBuilder();
try (InputStream is = InputStreamSubscriber.subscribeTo(flowPublisher, (s) -> {
throw new NullPointerException("boom");
}, ig -> {}, 1)) {
byte[] chunk = new byte[3];
stringBuilder
.append(new String(new byte[]{(byte)is.read()}, StandardCharsets.UTF_8));
assertThat(is.read(chunk)).isEqualTo(3);
stringBuilder
.append(new String(chunk, StandardCharsets.UTF_8));
assertThat(is.read(chunk)).isEqualTo(3);
stringBuilder
.append(new String(chunk, StandardCharsets.UTF_8));
assertThat(is.read(chunk)).isEqualTo(2);
stringBuilder
.append(new String(chunk,0, 2, StandardCharsets.UTF_8));
assertThat(is.read()).isEqualTo(-1);
}
catch (Throwable e) {
ex = e;
}
latch.await();
assertThat(stringBuilder.toString()).isEqualTo("");
assertThat(ex).hasMessage("boom");
}
private static Flux<String> toString(Flow.Publisher<byte[]> flowPublisher) {
return Flux.from(FlowAdapters.toPublisher(flowPublisher))
.map(bytes -> new String(bytes, StandardCharsets.UTF_8));
}
}