Refactor PartGenerator to use isLast

This commit refactors the PartGenerator to use the newly introduced
Token::isLast property.

See gh-28006
This commit is contained in:
Arjen Poutsma 2022-04-11 16:03:42 +02:00
parent d44ba0a42b
commit e29bc3db7c
3 changed files with 50 additions and 80 deletions

View File

@ -23,6 +23,7 @@ import java.nio.file.Path;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import reactor.core.publisher.Flux; import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
@ -222,12 +223,30 @@ public class DefaultPartHttpMessageReader extends LoggingCodecSupport implements
return Flux.error(new DecodingException("No multipart boundary found in Content-Type: \"" + return Flux.error(new DecodingException("No multipart boundary found in Content-Type: \"" +
message.getHeaders().getContentType() + "\"")); message.getHeaders().getContentType() + "\""));
} }
Flux<MultipartParser.Token> tokens = MultipartParser.parse(message.getBody(), boundary, Flux<MultipartParser.Token> allPartsTokens = MultipartParser.parse(message.getBody(), boundary,
this.maxHeadersSize, this.headersCharset); this.maxHeadersSize, this.headersCharset);
return PartGenerator.createParts(tokens, this.maxParts, this.maxInMemorySize, this.maxDiskUsagePerPart, AtomicInteger partCount = new AtomicInteger();
this.streaming, this.fileStorage.directory(), this.blockingOperationScheduler); return allPartsTokens
.windowUntil(MultipartParser.Token::isLast)
.concatMap(partsTokens -> {
if (tooManyParts(partCount)) {
return Mono.error(new DecodingException("Too many parts (" + partCount.get() + "/" +
this.maxParts + " allowed)"));
}
else {
return PartGenerator.createPart(partsTokens,
this.maxInMemorySize, this.maxDiskUsagePerPart, this.streaming,
this.fileStorage.directory(), this.blockingOperationScheduler);
}
});
}); });
} }
private boolean tooManyParts(AtomicInteger partCount) {
int count = partCount.incrementAndGet();
return this.maxParts > 0 && count > this.maxParts;
}
} }

View File

@ -30,7 +30,6 @@ import java.util.List;
import java.util.Queue; import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.AtomicReference;
@ -41,10 +40,10 @@ import reactor.core.publisher.BaseSubscriber;
import reactor.core.publisher.Flux; import reactor.core.publisher.Flux;
import reactor.core.publisher.FluxSink; import reactor.core.publisher.FluxSink;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
import reactor.core.publisher.MonoSink;
import reactor.core.scheduler.Scheduler; import reactor.core.scheduler.Scheduler;
import reactor.util.context.Context; import reactor.util.context.Context;
import org.springframework.core.codec.DecodingException;
import org.springframework.core.io.buffer.DataBuffer; import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferLimitException; import org.springframework.core.io.buffer.DataBufferLimitException;
import org.springframework.core.io.buffer.DataBufferUtils; import org.springframework.core.io.buffer.DataBufferUtils;
@ -65,13 +64,9 @@ final class PartGenerator extends BaseSubscriber<MultipartParser.Token> {
private final AtomicReference<State> state = new AtomicReference<>(new InitialState()); private final AtomicReference<State> state = new AtomicReference<>(new InitialState());
private final AtomicInteger partCount = new AtomicInteger();
private final AtomicBoolean requestOutstanding = new AtomicBoolean(); private final AtomicBoolean requestOutstanding = new AtomicBoolean();
private final FluxSink<Part> sink; private final MonoSink<Part> sink;
private final int maxParts;
private final boolean streaming; private final boolean streaming;
@ -84,11 +79,10 @@ final class PartGenerator extends BaseSubscriber<MultipartParser.Token> {
private final Scheduler blockingOperationScheduler; private final Scheduler blockingOperationScheduler;
private PartGenerator(FluxSink<Part> sink, int maxParts, int maxInMemorySize, long maxDiskUsagePerPart, private PartGenerator(MonoSink<Part> sink, int maxInMemorySize, long maxDiskUsagePerPart,
boolean streaming, Mono<Path> fileStorageDirectory, Scheduler blockingOperationScheduler) { boolean streaming, Mono<Path> fileStorageDirectory, Scheduler blockingOperationScheduler) {
this.sink = sink; this.sink = sink;
this.maxParts = maxParts;
this.maxInMemorySize = maxInMemorySize; this.maxInMemorySize = maxInMemorySize;
this.maxDiskUsagePerPart = maxDiskUsagePerPart; this.maxDiskUsagePerPart = maxDiskUsagePerPart;
this.streaming = streaming; this.streaming = streaming;
@ -99,15 +93,15 @@ final class PartGenerator extends BaseSubscriber<MultipartParser.Token> {
/** /**
* Creates parts from a given stream of tokens. * Creates parts from a given stream of tokens.
*/ */
public static Flux<Part> createParts(Flux<MultipartParser.Token> tokens, int maxParts, int maxInMemorySize, public static Mono<Part> createPart(Flux<MultipartParser.Token> tokens, int maxInMemorySize,
long maxDiskUsagePerPart, boolean streaming, Mono<Path> fileStorageDirectory, long maxDiskUsagePerPart, boolean streaming, Mono<Path> fileStorageDirectory,
Scheduler blockingOperationScheduler) { Scheduler blockingOperationScheduler) {
return Flux.create(sink -> { return Mono.create(sink -> {
PartGenerator generator = new PartGenerator(sink, maxParts, maxInMemorySize, maxDiskUsagePerPart, streaming, PartGenerator generator = new PartGenerator(sink, maxInMemorySize, maxDiskUsagePerPart, streaming,
fileStorageDirectory, blockingOperationScheduler); fileStorageDirectory, blockingOperationScheduler);
sink.onCancel(generator::onSinkCancel); sink.onCancel(generator);
sink.onRequest(l -> generator.requestToken()); sink.onRequest(l -> generator.requestToken());
tokens.subscribe(generator); tokens.subscribe(generator);
}); });
@ -128,13 +122,6 @@ final class PartGenerator extends BaseSubscriber<MultipartParser.Token> {
this.requestOutstanding.set(false); this.requestOutstanding.set(false);
State state = this.state.get(); State state = this.state.get();
if (token instanceof MultipartParser.HeadersToken) { if (token instanceof MultipartParser.HeadersToken) {
// finish previous part
state.partComplete(false);
if (tooManyParts()) {
return;
}
newPart(state, token.headers()); newPart(state, token.headers());
} }
else { else {
@ -144,11 +131,11 @@ final class PartGenerator extends BaseSubscriber<MultipartParser.Token> {
private void newPart(State currentState, HttpHeaders headers) { private void newPart(State currentState, HttpHeaders headers) {
if (MultipartUtils.isFormField(headers)) { if (MultipartUtils.isFormField(headers)) {
changeStateInternal(new FormFieldState(headers)); changeState(currentState, new FormFieldState(headers));
requestToken(); requestToken();
} }
else if (!this.streaming) { else if (!this.streaming) {
changeStateInternal(new InMemoryState(headers)); changeState(currentState, new InMemoryState(headers));
requestToken(); requestToken();
} }
else { else {
@ -165,7 +152,7 @@ final class PartGenerator extends BaseSubscriber<MultipartParser.Token> {
@Override @Override
protected void hookOnComplete() { protected void hookOnComplete() {
this.state.get().partComplete(true); this.state.get().onComplete();
} }
@Override @Override
@ -175,7 +162,8 @@ final class PartGenerator extends BaseSubscriber<MultipartParser.Token> {
this.sink.error(throwable); this.sink.error(throwable);
} }
private void onSinkCancel() { @Override
public void dispose() {
changeStateInternal(DisposedState.INSTANCE); changeStateInternal(DisposedState.INSTANCE);
cancel(); cancel();
} }
@ -211,14 +199,9 @@ final class PartGenerator extends BaseSubscriber<MultipartParser.Token> {
if (logger.isTraceEnabled()) { if (logger.isTraceEnabled()) {
logger.trace("Emitting: " + part); logger.trace("Emitting: " + part);
} }
this.sink.next(part); this.sink.success(part);
} }
void emitComplete() {
this.sink.complete();
}
void emitError(Throwable t) { void emitError(Throwable t) {
cancel(); cancel();
this.sink.error(t); this.sink.error(t);
@ -226,24 +209,11 @@ final class PartGenerator extends BaseSubscriber<MultipartParser.Token> {
void requestToken() { void requestToken() {
if (upstream() != null && if (upstream() != null &&
!this.sink.isCancelled() &&
this.sink.requestedFromDownstream() > 0 &&
this.requestOutstanding.compareAndSet(false, true)) { this.requestOutstanding.compareAndSet(false, true)) {
request(1); request(1);
} }
} }
private boolean tooManyParts() {
int count = this.partCount.incrementAndGet();
if (this.maxParts > 0 && count > this.maxParts) {
emitError(new DecodingException("Too many parts (" + count + "/" + this.maxParts + " allowed)"));
return true;
}
else {
return false;
}
}
/** /**
* Represents the internal state of the {@link PartGenerator} for * Represents the internal state of the {@link PartGenerator} for
* creating a single {@link Part}. * creating a single {@link Part}.
@ -273,10 +243,8 @@ final class PartGenerator extends BaseSubscriber<MultipartParser.Token> {
/** /**
* Invoked when all tokens for the part have been received. * Invoked when all tokens for the part have been received.
* @param finalPart {@code true} if this was the last part (and
* {@link #emitComplete()} should be called; {@code false} otherwise
*/ */
void partComplete(boolean finalPart); void onComplete();
/** /**
* Invoked when an error has been received. * Invoked when an error has been received.
@ -307,10 +275,7 @@ final class PartGenerator extends BaseSubscriber<MultipartParser.Token> {
} }
@Override @Override
public void partComplete(boolean finalPart) { public void onComplete() {
if (finalPart) {
emitComplete();
}
} }
@Override @Override
@ -364,13 +329,10 @@ final class PartGenerator extends BaseSubscriber<MultipartParser.Token> {
} }
@Override @Override
public void partComplete(boolean finalPart) { public void onComplete() {
byte[] bytes = this.value.toByteArrayUnsafe(); byte[] bytes = this.value.toByteArrayUnsafe();
String value = new String(bytes, MultipartUtils.charset(this.headers)); String value = new String(bytes, MultipartUtils.charset(this.headers));
emitPart(DefaultParts.formFieldPart(this.headers, value)); emitPart(DefaultParts.formFieldPart(this.headers, value));
if (finalPart) {
emitComplete();
}
} }
@Override @Override
@ -410,13 +372,10 @@ final class PartGenerator extends BaseSubscriber<MultipartParser.Token> {
} }
@Override @Override
public void partComplete(boolean finalPart) { public void onComplete() {
if (!this.bodySink.isCancelled()) { if (!this.bodySink.isCancelled()) {
this.bodySink.complete(); this.bodySink.complete();
} }
if (finalPart) {
emitComplete();
}
} }
@Override @Override
@ -493,11 +452,8 @@ final class PartGenerator extends BaseSubscriber<MultipartParser.Token> {
} }
@Override @Override
public void partComplete(boolean finalPart) { public void onComplete() {
emitMemoryPart(); emitMemoryPart();
if (finalPart) {
emitComplete();
}
} }
private void emitMemoryPart() { private void emitMemoryPart() {
@ -545,8 +501,6 @@ final class PartGenerator extends BaseSubscriber<MultipartParser.Token> {
private volatile boolean completed; private volatile boolean completed;
private volatile boolean finalPart;
private volatile boolean releaseOnDispose = true; private volatile boolean releaseOnDispose = true;
@ -563,9 +517,8 @@ final class PartGenerator extends BaseSubscriber<MultipartParser.Token> {
} }
@Override @Override
public void partComplete(boolean finalPart) { public void onComplete() {
this.completed = true; this.completed = true;
this.finalPart = finalPart;
} }
public void createFile() { public void createFile() {
@ -597,7 +550,7 @@ final class PartGenerator extends BaseSubscriber<MultipartParser.Token> {
newState.writeBuffers(this.content); newState.writeBuffers(this.content);
if (this.completed) { if (this.completed) {
newState.partComplete(this.finalPart); newState.onComplete();
} }
} }
else { else {
@ -665,12 +618,9 @@ final class PartGenerator extends BaseSubscriber<MultipartParser.Token> {
} }
@Override @Override
public void partComplete(boolean finalPart) { public void onComplete() {
MultipartUtils.closeChannel(this.channel); MultipartUtils.closeChannel(this.channel);
emitPart(DefaultParts.part(this.headers, this.file, PartGenerator.this.blockingOperationScheduler)); emitPart(DefaultParts.part(this.headers, this.file, PartGenerator.this.blockingOperationScheduler));
if (finalPart) {
emitComplete();
}
} }
@Override @Override
@ -701,8 +651,6 @@ final class PartGenerator extends BaseSubscriber<MultipartParser.Token> {
private volatile boolean completed; private volatile boolean completed;
private volatile boolean finalPart;
public WritingFileState(CreateFileState state, Path file, WritableByteChannel channel) { public WritingFileState(CreateFileState state, Path file, WritableByteChannel channel) {
this.headers = state.headers; this.headers = state.headers;
@ -725,9 +673,8 @@ final class PartGenerator extends BaseSubscriber<MultipartParser.Token> {
} }
@Override @Override
public void partComplete(boolean finalPart) { public void onComplete() {
this.completed = true; this.completed = true;
this.finalPart = finalPart;
} }
public void writeBuffer(DataBuffer dataBuffer) { public void writeBuffer(DataBuffer dataBuffer) {
@ -752,7 +699,7 @@ final class PartGenerator extends BaseSubscriber<MultipartParser.Token> {
private void writeComplete() { private void writeComplete() {
IdleFileState newState = new IdleFileState(this); IdleFileState newState = new IdleFileState(this);
if (this.completed) { if (this.completed) {
newState.partComplete(this.finalPart); newState.onComplete();
} }
else if (changeState(this, newState)) { else if (changeState(this, newState)) {
requestToken(); requestToken();
@ -799,7 +746,7 @@ final class PartGenerator extends BaseSubscriber<MultipartParser.Token> {
} }
@Override @Override
public void partComplete(boolean finalPart) { public void onComplete() {
} }
@Override @Override

View File

@ -118,6 +118,10 @@ class DefaultPartHttpMessageReaderTests {
Flux<Part> result = reader.read(forClass(Part.class), request, emptyMap()); Flux<Part> result = reader.read(forClass(Part.class), request, emptyMap());
StepVerifier.create(result) StepVerifier.create(result)
.consumeNextWith(part -> {
assertThat(part.headers().getFirst("Header")).isEqualTo("Value");
part.content().subscribe(DataBufferUtils::release);
})
.expectError(DecodingException.class) .expectError(DecodingException.class)
.verify(); .verify();
} }