diff --git a/spring-web/src/main/java/org/springframework/http/codec/multipart/MultipartParser.java b/spring-web/src/main/java/org/springframework/http/codec/multipart/MultipartParser.java index d797b99f4b1..ff1344424aa 100644 --- a/spring-web/src/main/java/org/springframework/http/codec/multipart/MultipartParser.java +++ b/spring-web/src/main/java/org/springframework/http/codec/multipart/MultipartParser.java @@ -17,8 +17,12 @@ package org.springframework.http.codec.multipart; import java.nio.charset.Charset; +import java.util.ArrayDeque; import java.util.ArrayList; +import java.util.Deque; +import java.util.Iterator; import java.util.List; +import java.util.concurrent.ConcurrentLinkedDeque; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; @@ -477,11 +481,14 @@ final class MultipartParser extends BaseSubscriber { private final DataBufferUtils.Matcher boundary; - private final AtomicReference previous = new AtomicReference<>(); + private final int boundaryLength; + + private final Deque queue = new ConcurrentLinkedDeque<>(); public BodyState() { - this.boundary = DataBufferUtils.matcher( - MultipartUtils.concat(CR_LF, TWO_HYPHENS, MultipartParser.this.boundary)); + byte[] delimiter = MultipartUtils.concat(CR_LF, TWO_HYPHENS, MultipartParser.this.boundary); + this.boundary = DataBufferUtils.matcher(delimiter); + this.boundaryLength = delimiter.length; } /** @@ -499,31 +506,38 @@ final class MultipartParser extends BaseSubscriber { if (logger.isTraceEnabled()) { logger.trace("Boundary found @" + endIdx + " in " + buffer); } - int len = endIdx - buffer.readPosition() - this.boundary.delimiter().length + 1; + int len = endIdx - buffer.readPosition() - this.boundaryLength + 1; if (len > 0) { - // buffer contains complete delimiter, let's slice it and flush it + // whole boundary in buffer. + // slice off the body part, and flush DataBuffer body = buffer.retainedSlice(buffer.readPosition(), len); enqueue(body); - enqueue(null); + flush(); } else if (len < 0) { - // buffer starts with the end of the delimiter, let's slice the previous buffer and flush it - DataBuffer previous = this.previous.get(); - int prevLen = previous.readableByteCount() + len; - if (prevLen > 0) { - DataBuffer body = previous.retainedSlice(previous.readPosition(), prevLen); - DataBufferUtils.release(previous); - this.previous.set(body); - enqueue(null); - } - else { - DataBufferUtils.release(previous); - this.previous.set(null); + // boundary spans multiple buffers, and we've just found the end + // iterate over buffers in reverse order + DataBuffer prev; + while ((prev = this.queue.pollLast()) != null) { + int prevLen = prev.readableByteCount() + len; + if (prevLen > 0) { + // slice body part of previous buffer, and flush it + DataBuffer body = prev.retainedSlice(prev.readPosition(), prevLen); + DataBufferUtils.release(prev); + enqueue(body); + flush(); + break; + } + else { + // previous buffer only contains boundary bytes + DataBufferUtils.release(prev); + len += prev.readableByteCount(); + } } } - else /* if (sliceLength == 0) */ { - // buffer starts with complete delimiter, flush out the previous buffer - enqueue(null); + else /* if (len == 0) */ { + // buffer starts with complete delimiter, flush out the previous buffers + flush(); } DataBuffer remainder = MultipartUtils.sliceFrom(buffer, endIdx); @@ -538,13 +552,32 @@ final class MultipartParser extends BaseSubscriber { } /** - * Stores the given buffer and sends out the previous buffer. + * Store the given buffer. Emit buffers that cannot contain boundary bytes, + * by iterating over the queue in reverse order, and summing buffer sizes. + * The first buffer that passes the boundary length and subsequent buffers + * are emitted (in the correct, non-reverse order). */ - private void enqueue(@Nullable DataBuffer buf) { - DataBuffer previous = this.previous.getAndSet(buf); - if (previous != null) { - emitBody(previous); + private void enqueue(DataBuffer buf) { + this.queue.add(buf); + + int len = 0; + Deque emit = new ArrayDeque<>(); + for (Iterator iterator = this.queue.descendingIterator(); iterator.hasNext(); ) { + DataBuffer previous = iterator.next(); + if (len > this.boundaryLength) { + // addFirst to negate iterating in reverse order + emit.addFirst(previous); + iterator.remove(); + } + len += previous.readableByteCount(); } + + emit.forEach(MultipartParser.this::emitBody); + } + + private void flush() { + this.queue.forEach(MultipartParser.this::emitBody); + this.queue.clear(); } @Override @@ -556,10 +589,8 @@ final class MultipartParser extends BaseSubscriber { @Override public void dispose() { - DataBuffer previous = this.previous.getAndSet(null); - if (previous != null) { - DataBufferUtils.release(previous); - } + this.queue.forEach(DataBufferUtils::release); + this.queue.clear(); } @Override diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/MultipartIntegrationTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/MultipartIntegrationTests.java index 98da50cd66f..b60587452ac 100644 --- a/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/MultipartIntegrationTests.java +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/MultipartIntegrationTests.java @@ -211,10 +211,9 @@ class MultipartIntegrationTests extends AbstractHttpHandlerIntegrationTests { private static void verifyContents(Path tempFile, Resource resource) { try { - byte[] tempBytes = Files.readAllBytes(tempFile); // Use FileCopyUtils since the resource might reside in a JAR instead of in the file system. byte[] resourceBytes = FileCopyUtils.copyToByteArray(resource.getInputStream()); - assertThat(tempBytes).isEqualTo(resourceBytes); + assertThat(tempFile).hasBinaryContent(resourceBytes); } catch (IOException ex) { throw new AssertionError(ex);