Refactor PartGenerator to use isLast

This commit refactors the PartGenerator to use the newly introduced
Token::isLast property.

See gh-28006
This commit is contained in:
Arjen Poutsma 2022-04-11 16:03:42 +02:00
parent d44ba0a42b
commit e29bc3db7c
3 changed files with 50 additions and 80 deletions

View File

@ -23,6 +23,7 @@ import java.nio.file.Path;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
@ -222,12 +223,30 @@ public class DefaultPartHttpMessageReader extends LoggingCodecSupport implements
return Flux.error(new DecodingException("No multipart boundary found in Content-Type: \"" +
message.getHeaders().getContentType() + "\""));
}
Flux<MultipartParser.Token> tokens = MultipartParser.parse(message.getBody(), boundary,
Flux<MultipartParser.Token> allPartsTokens = MultipartParser.parse(message.getBody(), boundary,
this.maxHeadersSize, this.headersCharset);
return PartGenerator.createParts(tokens, this.maxParts, this.maxInMemorySize, this.maxDiskUsagePerPart,
this.streaming, this.fileStorage.directory(), this.blockingOperationScheduler);
AtomicInteger partCount = new AtomicInteger();
return allPartsTokens
.windowUntil(MultipartParser.Token::isLast)
.concatMap(partsTokens -> {
if (tooManyParts(partCount)) {
return Mono.error(new DecodingException("Too many parts (" + partCount.get() + "/" +
this.maxParts + " allowed)"));
}
else {
return PartGenerator.createPart(partsTokens,
this.maxInMemorySize, this.maxDiskUsagePerPart, this.streaming,
this.fileStorage.directory(), this.blockingOperationScheduler);
}
});
});
}
private boolean tooManyParts(AtomicInteger partCount) {
int count = partCount.incrementAndGet();
return this.maxParts > 0 && count > this.maxParts;
}
}

View File

@ -30,7 +30,6 @@ import java.util.List;
import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
@ -41,10 +40,10 @@ import reactor.core.publisher.BaseSubscriber;
import reactor.core.publisher.Flux;
import reactor.core.publisher.FluxSink;
import reactor.core.publisher.Mono;
import reactor.core.publisher.MonoSink;
import reactor.core.scheduler.Scheduler;
import reactor.util.context.Context;
import org.springframework.core.codec.DecodingException;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferLimitException;
import org.springframework.core.io.buffer.DataBufferUtils;
@ -65,13 +64,9 @@ final class PartGenerator extends BaseSubscriber<MultipartParser.Token> {
private final AtomicReference<State> state = new AtomicReference<>(new InitialState());
private final AtomicInteger partCount = new AtomicInteger();
private final AtomicBoolean requestOutstanding = new AtomicBoolean();
private final FluxSink<Part> sink;
private final int maxParts;
private final MonoSink<Part> sink;
private final boolean streaming;
@ -84,11 +79,10 @@ final class PartGenerator extends BaseSubscriber<MultipartParser.Token> {
private final Scheduler blockingOperationScheduler;
private PartGenerator(FluxSink<Part> sink, int maxParts, int maxInMemorySize, long maxDiskUsagePerPart,
private PartGenerator(MonoSink<Part> sink, int maxInMemorySize, long maxDiskUsagePerPart,
boolean streaming, Mono<Path> fileStorageDirectory, Scheduler blockingOperationScheduler) {
this.sink = sink;
this.maxParts = maxParts;
this.maxInMemorySize = maxInMemorySize;
this.maxDiskUsagePerPart = maxDiskUsagePerPart;
this.streaming = streaming;
@ -99,15 +93,15 @@ final class PartGenerator extends BaseSubscriber<MultipartParser.Token> {
/**
* Creates parts from a given stream of tokens.
*/
public static Flux<Part> createParts(Flux<MultipartParser.Token> tokens, int maxParts, int maxInMemorySize,
public static Mono<Part> createPart(Flux<MultipartParser.Token> tokens, int maxInMemorySize,
long maxDiskUsagePerPart, boolean streaming, Mono<Path> fileStorageDirectory,
Scheduler blockingOperationScheduler) {
return Flux.create(sink -> {
PartGenerator generator = new PartGenerator(sink, maxParts, maxInMemorySize, maxDiskUsagePerPart, streaming,
return Mono.create(sink -> {
PartGenerator generator = new PartGenerator(sink, maxInMemorySize, maxDiskUsagePerPart, streaming,
fileStorageDirectory, blockingOperationScheduler);
sink.onCancel(generator::onSinkCancel);
sink.onCancel(generator);
sink.onRequest(l -> generator.requestToken());
tokens.subscribe(generator);
});
@ -128,13 +122,6 @@ final class PartGenerator extends BaseSubscriber<MultipartParser.Token> {
this.requestOutstanding.set(false);
State state = this.state.get();
if (token instanceof MultipartParser.HeadersToken) {
// finish previous part
state.partComplete(false);
if (tooManyParts()) {
return;
}
newPart(state, token.headers());
}
else {
@ -144,11 +131,11 @@ final class PartGenerator extends BaseSubscriber<MultipartParser.Token> {
private void newPart(State currentState, HttpHeaders headers) {
if (MultipartUtils.isFormField(headers)) {
changeStateInternal(new FormFieldState(headers));
changeState(currentState, new FormFieldState(headers));
requestToken();
}
else if (!this.streaming) {
changeStateInternal(new InMemoryState(headers));
changeState(currentState, new InMemoryState(headers));
requestToken();
}
else {
@ -165,7 +152,7 @@ final class PartGenerator extends BaseSubscriber<MultipartParser.Token> {
@Override
protected void hookOnComplete() {
this.state.get().partComplete(true);
this.state.get().onComplete();
}
@Override
@ -175,7 +162,8 @@ final class PartGenerator extends BaseSubscriber<MultipartParser.Token> {
this.sink.error(throwable);
}
private void onSinkCancel() {
@Override
public void dispose() {
changeStateInternal(DisposedState.INSTANCE);
cancel();
}
@ -211,14 +199,9 @@ final class PartGenerator extends BaseSubscriber<MultipartParser.Token> {
if (logger.isTraceEnabled()) {
logger.trace("Emitting: " + part);
}
this.sink.next(part);
this.sink.success(part);
}
void emitComplete() {
this.sink.complete();
}
void emitError(Throwable t) {
cancel();
this.sink.error(t);
@ -226,24 +209,11 @@ final class PartGenerator extends BaseSubscriber<MultipartParser.Token> {
void requestToken() {
if (upstream() != null &&
!this.sink.isCancelled() &&
this.sink.requestedFromDownstream() > 0 &&
this.requestOutstanding.compareAndSet(false, true)) {
request(1);
}
}
private boolean tooManyParts() {
int count = this.partCount.incrementAndGet();
if (this.maxParts > 0 && count > this.maxParts) {
emitError(new DecodingException("Too many parts (" + count + "/" + this.maxParts + " allowed)"));
return true;
}
else {
return false;
}
}
/**
* Represents the internal state of the {@link PartGenerator} for
* creating a single {@link Part}.
@ -273,10 +243,8 @@ final class PartGenerator extends BaseSubscriber<MultipartParser.Token> {
/**
* Invoked when all tokens for the part have been received.
* @param finalPart {@code true} if this was the last part (and
* {@link #emitComplete()} should be called; {@code false} otherwise
*/
void partComplete(boolean finalPart);
void onComplete();
/**
* Invoked when an error has been received.
@ -307,10 +275,7 @@ final class PartGenerator extends BaseSubscriber<MultipartParser.Token> {
}
@Override
public void partComplete(boolean finalPart) {
if (finalPart) {
emitComplete();
}
public void onComplete() {
}
@Override
@ -364,13 +329,10 @@ final class PartGenerator extends BaseSubscriber<MultipartParser.Token> {
}
@Override
public void partComplete(boolean finalPart) {
public void onComplete() {
byte[] bytes = this.value.toByteArrayUnsafe();
String value = new String(bytes, MultipartUtils.charset(this.headers));
emitPart(DefaultParts.formFieldPart(this.headers, value));
if (finalPart) {
emitComplete();
}
}
@Override
@ -410,13 +372,10 @@ final class PartGenerator extends BaseSubscriber<MultipartParser.Token> {
}
@Override
public void partComplete(boolean finalPart) {
public void onComplete() {
if (!this.bodySink.isCancelled()) {
this.bodySink.complete();
}
if (finalPart) {
emitComplete();
}
}
@Override
@ -493,11 +452,8 @@ final class PartGenerator extends BaseSubscriber<MultipartParser.Token> {
}
@Override
public void partComplete(boolean finalPart) {
public void onComplete() {
emitMemoryPart();
if (finalPart) {
emitComplete();
}
}
private void emitMemoryPart() {
@ -545,8 +501,6 @@ final class PartGenerator extends BaseSubscriber<MultipartParser.Token> {
private volatile boolean completed;
private volatile boolean finalPart;
private volatile boolean releaseOnDispose = true;
@ -563,9 +517,8 @@ final class PartGenerator extends BaseSubscriber<MultipartParser.Token> {
}
@Override
public void partComplete(boolean finalPart) {
public void onComplete() {
this.completed = true;
this.finalPart = finalPart;
}
public void createFile() {
@ -597,7 +550,7 @@ final class PartGenerator extends BaseSubscriber<MultipartParser.Token> {
newState.writeBuffers(this.content);
if (this.completed) {
newState.partComplete(this.finalPart);
newState.onComplete();
}
}
else {
@ -665,12 +618,9 @@ final class PartGenerator extends BaseSubscriber<MultipartParser.Token> {
}
@Override
public void partComplete(boolean finalPart) {
public void onComplete() {
MultipartUtils.closeChannel(this.channel);
emitPart(DefaultParts.part(this.headers, this.file, PartGenerator.this.blockingOperationScheduler));
if (finalPart) {
emitComplete();
}
}
@Override
@ -701,8 +651,6 @@ final class PartGenerator extends BaseSubscriber<MultipartParser.Token> {
private volatile boolean completed;
private volatile boolean finalPart;
public WritingFileState(CreateFileState state, Path file, WritableByteChannel channel) {
this.headers = state.headers;
@ -725,9 +673,8 @@ final class PartGenerator extends BaseSubscriber<MultipartParser.Token> {
}
@Override
public void partComplete(boolean finalPart) {
public void onComplete() {
this.completed = true;
this.finalPart = finalPart;
}
public void writeBuffer(DataBuffer dataBuffer) {
@ -752,7 +699,7 @@ final class PartGenerator extends BaseSubscriber<MultipartParser.Token> {
private void writeComplete() {
IdleFileState newState = new IdleFileState(this);
if (this.completed) {
newState.partComplete(this.finalPart);
newState.onComplete();
}
else if (changeState(this, newState)) {
requestToken();
@ -799,7 +746,7 @@ final class PartGenerator extends BaseSubscriber<MultipartParser.Token> {
}
@Override
public void partComplete(boolean finalPart) {
public void onComplete() {
}
@Override

View File

@ -118,6 +118,10 @@ class DefaultPartHttpMessageReaderTests {
Flux<Part> result = reader.read(forClass(Part.class), request, emptyMap());
StepVerifier.create(result)
.consumeNextWith(part -> {
assertThat(part.headers().getFirst("Header")).isEqualTo("Value");
part.content().subscribe(DataBufferUtils::release);
})
.expectError(DecodingException.class)
.verify();
}