Introduce OutputStream BodyInserter

This commit introduces a BodyInserter that inssert any bytes written to
an output stream to the body of an output message.

Closes gh-31184
This commit is contained in:
Arjen Poutsma 2023-09-11 11:00:11 +02:00
parent 913dc86e18
commit 59d123a18e
5 changed files with 634 additions and 1 deletions

View File

@ -33,6 +33,7 @@ import java.nio.file.StandardOpenOption;
import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
@ -41,6 +42,7 @@ import java.util.function.Consumer;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.reactivestreams.Publisher;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import reactor.core.publisher.BaseSubscriber;
import reactor.core.publisher.Flux;
@ -66,6 +68,9 @@ public abstract class DataBufferUtils {
private static final Consumer<DataBuffer> RELEASE_CONSUMER = DataBufferUtils::release;
private static final int DEFAULT_CHUNK_SIZE = 1024;
//---------------------------------------------------------------------
// Reading
@ -405,6 +410,83 @@ public abstract class DataBufferUtils {
}
/**
* Create a new {@code Publisher<DataBuffer>} based on bytes written to a
* {@code OutputStream}.
* <ul>
* <li>The parameter {@code outputStreamConsumer} is invoked once per
* subscription of the returned {@code Publisher}, when the first
* item is
* {@linkplain Subscription#request(long) requested}.</li>
* <li>{@link OutputStream#write(byte[], int, int) OutputStream.write()}
* invocations made by {@code outputStreamConsumer} are buffered until they
* exceed the default chunk size of 1024, or when the stream is
* {@linkplain OutputStream#flush() flushed} and then result in a
* {@linkplain Subscriber#onNext(Object) published} item
* if there is {@linkplain Subscription#request(long) demand}.</li>
* <li>If there is <em>no demand</em>, {@code OutputStream.write()} will block
* until there is.</li>
* <li>If the subscription is {@linkplain Subscription#cancel() cancelled},
* {@code OutputStream.write()} will throw a {@code IOException}.</li>
* <li>The subscription is
* {@linkplain Subscriber#onComplete() completed} when
* {@code outputStreamHandler} completes.</li>
* <li>Any exceptions thrown from {@code outputStreamHandler} will
* be dispatched to the {@linkplain Subscriber#onError(Throwable) Subscriber}.
* </ul>
* @param outputStreamConsumer invoked when the first buffer is requested
* @param executor used to invoke the {@code outputStreamHandler}
* @return a {@code Publisher<DataBuffer>} based on bytes written by
* {@code outputStreamHandler}
*/
public static Publisher<DataBuffer> outputStreamPublisher(Consumer<OutputStream> outputStreamConsumer,
DataBufferFactory bufferFactory, Executor executor) {
return outputStreamPublisher(outputStreamConsumer, bufferFactory, executor, DEFAULT_CHUNK_SIZE);
}
/**
* Creates a new {@code Publisher<DataBuffer>} based on bytes written to a
* {@code OutputStream}.
* <ul>
* <li>The parameter {@code outputStreamConsumer} is invoked once per
* subscription of the returned {@code Publisher}, when the first
* item is
* {@linkplain Subscription#request(long) requested}.</li>
* <li>{@link OutputStream#write(byte[], int, int) OutputStream.write()}
* invocations made by {@code outputStreamHandler} are buffered until they
* reach or exceed {@code chunkSize}, or when the stream is
* {@linkplain OutputStream#flush() flushed} and then result in a
* {@linkplain Subscriber#onNext(Object) published} item
* if there is {@linkplain Subscription#request(long) demand}.</li>
* <li>If there is <em>no demand</em>, {@code OutputStream.write()} will block
* until there is.</li>
* <li>If the subscription is {@linkplain Subscription#cancel() cancelled},
* {@code OutputStream.write()} will throw a {@code IOException}.</li>
* <li>The subscription is
* {@linkplain Subscriber#onComplete() completed} when
* {@code outputStreamHandler} completes.</li>
* <li>Any exceptions thrown from {@code outputStreamHandler} will
* be dispatched to the {@linkplain Subscriber#onError(Throwable) Subscriber}.
* </ul>
* @param outputStreamConsumer invoked when the first buffer is requested
* @param executor used to invoke the {@code outputStreamHandler}
* @param chunkSize minimum size of the buffer produced by the publisher
* @return a {@code Publisher<DataBuffer>} based on bytes written by
* {@code outputStreamHandler}
*/
public static Publisher<DataBuffer> outputStreamPublisher(Consumer<OutputStream> outputStreamConsumer,
DataBufferFactory bufferFactory, Executor executor, int chunkSize) {
Assert.notNull(outputStreamConsumer, "OutputStreamConsumer must not be null");
Assert.notNull(bufferFactory, "BufferFactory must not be null");
Assert.notNull(executor, "Executor must not be null");
Assert.isTrue(chunkSize > 0, "Chunk size must be > 0");
return new OutputStreamPublisher(outputStreamConsumer, bufferFactory, executor, chunkSize);
}
//---------------------------------------------------------------------
// Various
//---------------------------------------------------------------------

View File

@ -0,0 +1,354 @@
/*
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.core.io.buffer;
import java.io.BufferedOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.util.Objects;
import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.locks.LockSupport;
import java.util.function.Consumer;
import org.reactivestreams.Publisher;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import org.springframework.lang.Nullable;
/**
* Bridges between {@link OutputStream} and
* {@link Publisher Publisher&lt;DataBuffer&gt;}.
*
* <p>Note that this class has a near duplicate in
* {@link org.springframework.http.client.OutputStreamPublisher}.
*
* @author Oleh Dokuka
* @author Arjen Poutsma
* @since 6.1
*/
final class OutputStreamPublisher implements Publisher<DataBuffer> {
private final Consumer<OutputStream> outputStreamConsumer;
private final DataBufferFactory bufferFactory;
private final Executor executor;
private final int chunkSize;
public OutputStreamPublisher(Consumer<OutputStream> outputStreamConsumer, DataBufferFactory bufferFactory,
Executor executor, int chunkSize) {
this.outputStreamConsumer = outputStreamConsumer;
this.bufferFactory = bufferFactory;
this.executor = executor;
this.chunkSize = chunkSize;
}
@Override
public void subscribe(Subscriber<? super DataBuffer> subscriber) {
Objects.requireNonNull(subscriber, "Subscriber must not be null");
OutputStreamSubscription subscription = new OutputStreamSubscription(subscriber, this.outputStreamConsumer,
this.bufferFactory, this.chunkSize);
subscriber.onSubscribe(subscription);
this.executor.execute(subscription::invokeHandler);
}
private static final class OutputStreamSubscription extends OutputStream implements Subscription {
private static final Object READY = new Object();
private final Subscriber<? super DataBuffer> actual;
private final Consumer<OutputStream> outputStreamHandler;
private final DataBufferFactory bufferFactory;
private final int chunkSize;
private final AtomicLong requested = new AtomicLong();
private final AtomicReference<Object> parkedThread = new AtomicReference<>();
@Nullable
private volatile Throwable error;
private long produced;
public OutputStreamSubscription(Subscriber<? super DataBuffer> actual,
Consumer<OutputStream> outputStreamConsumer, DataBufferFactory bufferFactory, int chunkSize) {
this.actual = actual;
this.outputStreamHandler = outputStreamConsumer;
this.bufferFactory = bufferFactory;
this.chunkSize = chunkSize;
}
@Override
public void write(int b) throws IOException {
checkDemandAndAwaitIfNeeded();
DataBuffer next = this.bufferFactory.allocateBuffer(1);
next.write((byte) b);
this.actual.onNext(next);
this.produced++;
}
@Override
public void write(byte[] b) throws IOException {
write(b, 0, b.length);
}
@Override
public void write(byte[] b, int off, int len) throws IOException {
checkDemandAndAwaitIfNeeded();
DataBuffer next = this.bufferFactory.allocateBuffer(len);
next.write(b, off, len);
this.actual.onNext(next);
this.produced++;
}
private void checkDemandAndAwaitIfNeeded() throws IOException {
long r = this.requested.get();
if (isTerminated(r) || isCancelled(r)) {
throw new IOException("Subscription has been terminated");
}
long p = this.produced;
if (p == r) {
if (p > 0) {
r = tryProduce(p);
this.produced = 0;
}
while (true) {
if (isTerminated(r) || isCancelled(r)) {
throw new IOException("Subscription has been terminated");
}
if (r != 0) {
return;
}
await();
r = this.requested.get();
}
}
}
private void invokeHandler() {
// assume sync write within try-with-resource block
// use BufferedOutputStream, so that written bytes are buffered
// before publishing as byte buffer
try (OutputStream outputStream = new BufferedOutputStream(this, this.chunkSize)) {
this.outputStreamHandler.accept(outputStream);
}
catch (Exception ex) {
long previousState = tryTerminate();
if (isCancelled(previousState)) {
return;
}
if (isTerminated(previousState)) {
// failure due to illegal requestN
this.actual.onError(this.error);
return;
}
this.actual.onError(ex);
return;
}
long previousState = tryTerminate();
if (isCancelled(previousState)) {
return;
}
if (isTerminated(previousState)) {
// failure due to illegal requestN
this.actual.onError(this.error);
return;
}
this.actual.onComplete();
}
@Override
public void request(long n) {
if (n <= 0) {
this.error = new IllegalArgumentException("request should be a positive number");
long previousState = tryTerminate();
if (isTerminated(previousState) || isCancelled(previousState)) {
return;
}
if (previousState > 0) {
// error should eventually be observed and propagated
return;
}
// resume parked thread, so it can observe error and propagate it
resume();
return;
}
if (addCap(n) == 0) {
// resume parked thread so it can continue the work
resume();
}
}
@Override
public void cancel() {
long previousState = tryCancel();
if (isCancelled(previousState) || previousState > 0) {
return;
}
// resume parked thread, so it can be unblocked and close all the resources
resume();
}
private void await() {
Thread toUnpark = Thread.currentThread();
while (true) {
Object current = this.parkedThread.get();
if (current == READY) {
break;
}
if (current != null && current != toUnpark) {
throw new IllegalStateException("Only one (Virtual)Thread can await!");
}
if (this.parkedThread.compareAndSet(null, toUnpark)) {
LockSupport.park();
// we don't just break here because park() can wake up spuriously
// if we got a proper resume, get() == READY and the loop will quit above
}
}
// clear the resume indicator so that the next await call will park without a resume()
this.parkedThread.lazySet(null);
}
private void resume() {
if (this.parkedThread.get() != READY) {
Object old = this.parkedThread.getAndSet(READY);
if (old != READY) {
LockSupport.unpark((Thread)old);
}
}
}
private long tryCancel() {
while (true) {
long r = this.requested.get();
if (isCancelled(r)) {
return r;
}
if (this.requested.compareAndSet(r, Long.MIN_VALUE)) {
return r;
}
}
}
private long tryTerminate() {
while (true) {
long r = this.requested.get();
if (isCancelled(r) || isTerminated(r)) {
return r;
}
if (this.requested.compareAndSet(r, Long.MIN_VALUE | Long.MAX_VALUE)) {
return r;
}
}
}
private long tryProduce(long n) {
while (true) {
long current = this.requested.get();
if (isTerminated(current) || isCancelled(current)) {
return current;
}
if (current == Long.MAX_VALUE) {
return Long.MAX_VALUE;
}
long update = current - n;
if (update < 0L) {
update = 0L;
}
if (this.requested.compareAndSet(current, update)) {
return update;
}
}
}
private long addCap(long n) {
while (true) {
long r = this.requested.get();
if (isTerminated(r) || isCancelled(r) || r == Long.MAX_VALUE) {
return r;
}
long u = addCap(r, n);
if (this.requested.compareAndSet(r, u)) {
return r;
}
}
}
private static boolean isTerminated(long state) {
return state == (Long.MIN_VALUE | Long.MAX_VALUE);
}
private static boolean isCancelled(long state) {
return state == Long.MIN_VALUE;
}
private static long addCap(long a, long b) {
long res = a + b;
if (res < 0L) {
return Long.MAX_VALUE;
}
return res;
}
}
}

View File

@ -18,6 +18,7 @@ package org.springframework.core.io.buffer;
import java.io.IOException;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
import java.net.URI;
import java.nio.ByteBuffer;
import java.nio.channels.AsynchronousFileChannel;
@ -34,11 +35,13 @@ import java.nio.file.StandardOpenOption;
import java.time.Duration;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executors;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.PooledByteBufAllocator;
import org.junit.jupiter.api.Test;
import org.mockito.stubbing.Answer;
import org.reactivestreams.Publisher;
import org.reactivestreams.Subscription;
import reactor.core.publisher.BaseSubscriber;
import reactor.core.publisher.Flux;
@ -52,6 +55,8 @@ import org.springframework.core.io.Resource;
import org.springframework.core.testfixture.io.buffer.AbstractDataBufferAllocatingTests;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIOException;
import static org.assertj.core.api.Assertions.fail;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.BDDMockito.given;
@ -543,6 +548,147 @@ class DataBufferUtilsTests extends AbstractDataBufferAllocatingTests {
assertThat(written).contains("foobar");
}
@ParameterizedDataBufferAllocatingTest
void outputStreamPublisher(DataBufferFactory bufferFactory) {
super.bufferFactory = bufferFactory;
byte[] foo = "foo".getBytes(StandardCharsets.UTF_8);
byte[] bar = "bar".getBytes(StandardCharsets.UTF_8);
byte[] baz = "baz".getBytes(StandardCharsets.UTF_8);
Publisher<DataBuffer> publisher = DataBufferUtils.outputStreamPublisher(outputStream -> {
try {
outputStream.write(foo);
outputStream.write(bar);
outputStream.write(baz);
}
catch (IOException ex) {
fail(ex.getMessage(), ex);
}
}, super.bufferFactory, Executors.newSingleThreadExecutor());
StepVerifier.create(publisher)
.consumeNextWith(stringConsumer("foobarbaz"))
.verifyComplete();
}
@ParameterizedDataBufferAllocatingTest
void outputStreamPublisherFlush(DataBufferFactory bufferFactory) {
super.bufferFactory = bufferFactory;
byte[] foo = "foo".getBytes(StandardCharsets.UTF_8);
byte[] bar = "bar".getBytes(StandardCharsets.UTF_8);
byte[] baz = "baz".getBytes(StandardCharsets.UTF_8);
Publisher<DataBuffer> publisher = DataBufferUtils.outputStreamPublisher(outputStream -> {
try {
outputStream.write(foo);
outputStream.flush();
outputStream.write(bar);
outputStream.flush();
outputStream.write(baz);
outputStream.flush();
}
catch (IOException ex) {
fail(ex.getMessage(), ex);
}
}, super.bufferFactory, Executors.newSingleThreadExecutor());
StepVerifier.create(publisher)
.consumeNextWith(stringConsumer("foo"))
.consumeNextWith(stringConsumer("bar"))
.consumeNextWith(stringConsumer("baz"))
.verifyComplete();
}
@ParameterizedDataBufferAllocatingTest
void outputStreamPublisherChunkSize(DataBufferFactory bufferFactory) {
super.bufferFactory = bufferFactory;
byte[] foo = "foo".getBytes(StandardCharsets.UTF_8);
byte[] bar = "bar".getBytes(StandardCharsets.UTF_8);
byte[] baz = "baz".getBytes(StandardCharsets.UTF_8);
Publisher<DataBuffer> publisher = DataBufferUtils.outputStreamPublisher(outputStream -> {
try {
outputStream.write(foo);
outputStream.write(bar);
outputStream.write(baz);
}
catch (IOException ex) {
fail(ex.getMessage(), ex);
}
}, super.bufferFactory, Executors.newSingleThreadExecutor(), 3);
StepVerifier.create(publisher)
.consumeNextWith(stringConsumer("foo"))
.consumeNextWith(stringConsumer("bar"))
.consumeNextWith(stringConsumer("baz"))
.verifyComplete();
}
@ParameterizedDataBufferAllocatingTest
void outputStreamPublisherCancel(DataBufferFactory bufferFactory) throws InterruptedException {
super.bufferFactory = bufferFactory;
byte[] foo = "foo".getBytes(StandardCharsets.UTF_8);
byte[] bar = "bar".getBytes(StandardCharsets.UTF_8);
CountDownLatch latch = new CountDownLatch(1);
Publisher<DataBuffer> publisher = DataBufferUtils.outputStreamPublisher(outputStream -> {
try {
assertThatIOException()
.isThrownBy(() -> {
outputStream.write(foo);
outputStream.flush();
outputStream.write(bar);
outputStream.flush();
})
.withMessage("Subscription has been terminated");
}
finally {
latch.countDown();
}
}, super.bufferFactory, Executors.newSingleThreadExecutor());
StepVerifier.create(publisher, 1)
.consumeNextWith(stringConsumer("foo"))
.thenCancel()
.verify();
latch.await();
}
@ParameterizedDataBufferAllocatingTest
void outputStreamPublisherClosed(DataBufferFactory bufferFactory) throws InterruptedException {
super.bufferFactory = bufferFactory;
CountDownLatch latch = new CountDownLatch(1);
Publisher<DataBuffer> publisher = DataBufferUtils.outputStreamPublisher(outputStream -> {
try {
OutputStreamWriter writer = new OutputStreamWriter(outputStream, StandardCharsets.UTF_8);
writer.write("foo");
writer.close();
assertThatIOException().isThrownBy(() -> writer.write("bar"))
.withMessage("Stream closed");
}
catch (IOException ex) {
fail(ex.getMessage(), ex);
}
finally {
latch.countDown();
}
}, super.bufferFactory, Executors.newSingleThreadExecutor());
StepVerifier.create(publisher)
.consumeNextWith(stringConsumer("foo"))
.verifyComplete();
latch.await();
}
@ParameterizedDataBufferAllocatingTest
void readAndWriteByteChannel(DataBufferFactory bufferFactory) throws Exception {
super.bufferFactory = bufferFactory;

View File

@ -32,7 +32,10 @@ import org.springframework.util.Assert;
/**
* Bridges between {@link OutputStream} and
* {@link Flow.Publisher Flow.Publisher&lt;T&gt;}.
*
* <p>Note that this class has a near duplicate in
* {@link org.springframework.core.io.buffer.OutputStreamPublisher}.
*
* @author Oleh Dokuka
* @author Arjen Poutsma
* @since 6.1

View File

@ -16,7 +16,10 @@
package org.springframework.web.reactive.function;
import java.io.OutputStream;
import java.util.List;
import java.util.concurrent.Executor;
import java.util.function.Consumer;
import org.reactivestreams.Publisher;
import reactor.core.publisher.Mono;
@ -27,6 +30,8 @@ import org.springframework.core.ReactiveAdapterRegistry;
import org.springframework.core.ResolvableType;
import org.springframework.core.io.Resource;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferFactory;
import org.springframework.core.io.buffer.DataBufferUtils;
import org.springframework.http.HttpEntity;
import org.springframework.http.MediaType;
import org.springframework.http.ReactiveHttpOutputMessage;
@ -358,6 +363,49 @@ public abstract class BodyInserters {
return (outputMessage, context) -> outputMessage.writeWith(publisher);
}
/**
* Inserter based on bytes written to a {@code OutputStream}.
* @param outputStreamConsumer invoked with an {@link OutputStream} that
* writes to the output message
* @param executor used to invoke the {@code outputStreamHandler} on a
* separate thread
* @return an inserter that writes what is written to the output stream
* @since 6.1
* @see DataBufferUtils#outputStreamPublisher(Consumer, DataBufferFactory, Executor)
*/
public static <T extends Publisher<DataBuffer>> BodyInserter<T, ReactiveHttpOutputMessage> fromOutputStream(
Consumer<OutputStream> outputStreamConsumer, Executor executor) {
Assert.notNull(outputStreamConsumer, "OutputStreamConsumer must not be null");
Assert.notNull(executor, "Executor must not be null");
return (outputMessage, context) -> outputMessage.writeWith(
DataBufferUtils.outputStreamPublisher(outputStreamConsumer, outputMessage.bufferFactory(), executor));
}
/**
* Inserter based on bytes written to a {@code OutputStream}.
* @param outputStreamConsumer invoked with an {@link OutputStream} that
* writes to the output message
* @param executor used to invoke the {@code outputStreamHandler} on a
* separate thread
* @param chunkSize minimum size of the buffer produced by the publisher
* @return an inserter that writes what is written to the output stream
* @since 6.1
* @see DataBufferUtils#outputStreamPublisher(Consumer, DataBufferFactory, Executor, int)
*/
public static <T extends Publisher<DataBuffer>> BodyInserter<T, ReactiveHttpOutputMessage> fromOutputStream(
Consumer<OutputStream> outputStreamConsumer, Executor executor, int chunkSize) {
Assert.notNull(outputStreamConsumer, "OutputStreamConsumer must not be null");
Assert.notNull(executor, "Executor must not be null");
Assert.isTrue(chunkSize > 0, "Chunk size must be > 0");
return (outputMessage, context) -> outputMessage.writeWith(
DataBufferUtils.outputStreamPublisher(outputStreamConsumer, outputMessage.bufferFactory(), executor,
chunkSize));
}
private static <M extends ReactiveHttpOutputMessage> Mono<Void> writeWithMessageWriters(
M outputMessage, BodyInserter.Context context, Object body, ResolvableType bodyType, @Nullable ReactiveAdapter adapter) {