Introduce PartEvent

This commit introduces the PartEvent API. PartEvents are either
- FormPartEvents, representing a form field, or
- FilePartEvents, representing a file upload.

The PartEventHttpMessageReader is a HttpMessageReader that splits
multipart data into a stream of PartEvents. Form fields generate one
FormPartEvent; file uploads produce at least one FilePartEvent. The last
element that makes up a particular part will have isLast set to true.

The PartEventHttpMessageWriter is a HttpMessageWriter that writes a
Publisher<PartEvent> to a outgoing HTTP message. This writer is
particularly useful for relaying a multipart request on the server.

Closes gh-28006
This commit is contained in:
Arjen Poutsma 2022-02-10 11:04:30 +01:00
parent 081c6463e9
commit be7fa3aaa8
22 changed files with 1436 additions and 56 deletions

View File

@ -0,0 +1,176 @@
/*
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.http.codec.multipart;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DefaultDataBufferFactory;
import org.springframework.http.HttpHeaders;
import org.springframework.util.Assert;
/**
* Default implementations of {@link PartEvent} and subtypes.
*
* @author Arjen Poutsma
* @since 6.0
*/
abstract class DefaultPartEvents {
public static FormPartEvent form(HttpHeaders headers) {
Assert.notNull(headers, "Headers must not be null");
return new DefaultFormFieldPartEvent(headers);
}
public static FormPartEvent form(HttpHeaders headers, String value) {
Assert.notNull(headers, "Headers must not be null");
Assert.notNull(value, "Value must not be null");
return new DefaultFormFieldPartEvent(headers, value);
}
public static FilePartEvent file(HttpHeaders headers, DataBuffer dataBuffer, boolean isLast) {
Assert.notNull(headers, "Headers must not be null");
Assert.notNull(dataBuffer, "DataBuffer must not be null");
return new DefaultFilePartEvent(headers, dataBuffer, isLast);
}
public static FilePartEvent file(HttpHeaders headers) {
Assert.notNull(headers, "Headers must not be null");
return new DefaultFilePartEvent(headers);
}
public static PartEvent create(HttpHeaders headers, DataBuffer dataBuffer, boolean isLast) {
Assert.notNull(headers, "Headers must not be null");
Assert.notNull(dataBuffer, "DataBuffer must not be null");
if (headers.getContentDisposition().getFilename() != null) {
return file(headers, dataBuffer, isLast);
}
else {
return new DefaultPartEvent(headers, dataBuffer, isLast);
}
}
public static PartEvent create(HttpHeaders headers) {
Assert.notNull(headers, "Headers must not be null");
if (headers.getContentDisposition().getFilename() != null) {
return file(headers);
}
else {
return new DefaultPartEvent(headers);
}
}
private static abstract class AbstractPartEvent implements PartEvent {
private final HttpHeaders headers;
protected AbstractPartEvent(HttpHeaders headers) {
this.headers = HttpHeaders.readOnlyHttpHeaders(headers);
}
@Override
public HttpHeaders headers() {
return this.headers;
}
}
/**
* Default implementation of {@link PartEvent}.
*/
private static class DefaultPartEvent extends AbstractPartEvent {
private static final DataBuffer EMPTY = DefaultDataBufferFactory.sharedInstance.allocateBuffer(0);
private final DataBuffer content;
private final boolean last;
public DefaultPartEvent(HttpHeaders headers) {
this(headers, EMPTY, true);
}
public DefaultPartEvent(HttpHeaders headers, DataBuffer content, boolean last) {
super(headers);
this.content = content;
this.last = last;
}
@Override
public DataBuffer content() {
return this.content;
}
@Override
public boolean isLast() {
return this.last;
}
}
/**
* Default implementation of {@link FormPartEvent}.
*/
private static final class DefaultFormFieldPartEvent extends AbstractPartEvent implements FormPartEvent {
private static final String EMPTY = "";
private final String value;
public DefaultFormFieldPartEvent(HttpHeaders headers) {
this(headers, EMPTY);
}
public DefaultFormFieldPartEvent(HttpHeaders headers, String value) {
super(headers);
this.value = value;
}
@Override
public String value() {
return this.value;
}
@Override
public DataBuffer content() {
byte[] bytes = this.value.getBytes(MultipartUtils.charset(headers()));
return DefaultDataBufferFactory.sharedInstance.wrap(bytes);
}
@Override
public boolean isLast() {
return true;
}
}
/**
* Default implementation of {@link FilePartEvent}.
*/
private static class DefaultFilePartEvent extends DefaultPartEvent implements FilePartEvent {
public DefaultFilePartEvent(HttpHeaders headers) {
super(headers);
}
public DefaultFilePartEvent(HttpHeaders headers, DataBuffer content, boolean last) {
super(headers, content, last);
}
}
}

View File

@ -32,7 +32,6 @@ import reactor.core.scheduler.Schedulers;
import org.springframework.core.ResolvableType;
import org.springframework.core.codec.DecodingException;
import org.springframework.core.io.buffer.DataBufferLimitException;
import org.springframework.http.HttpMessage;
import org.springframework.http.MediaType;
import org.springframework.http.ReactiveHttpInputMessage;
import org.springframework.http.codec.HttpMessageReader;
@ -218,7 +217,7 @@ public class DefaultPartHttpMessageReader extends LoggingCodecSupport implements
@Override
public Flux<Part> read(ResolvableType elementType, ReactiveHttpInputMessage message, Map<String, Object> hints) {
return Flux.defer(() -> {
byte[] boundary = boundary(message);
byte[] boundary = MultipartUtils.boundary(message, this.headersCharset);
if (boundary == null) {
return Flux.error(new DecodingException("No multipart boundary found in Content-Type: \"" +
message.getHeaders().getContentType() + "\""));
@ -231,20 +230,4 @@ public class DefaultPartHttpMessageReader extends LoggingCodecSupport implements
});
}
@Nullable
private byte[] boundary(HttpMessage message) {
MediaType contentType = message.getHeaders().getContentType();
if (contentType != null) {
String boundary = contentType.getParameter("boundary");
if (boundary != null) {
int len = boundary.length();
if (len > 2 && boundary.charAt(0) == '"' && boundary.charAt(len - 1) == '"') {
boundary = boundary.substring(1, len - 1);
}
return boundary.getBytes(this.headersCharset);
}
}
return null;
}
}

View File

@ -0,0 +1,196 @@
/*
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.http.codec.multipart;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import java.util.function.Consumer;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import org.springframework.core.io.Resource;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferUtils;
import org.springframework.core.io.buffer.DefaultDataBufferFactory;
import org.springframework.http.ContentDisposition;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.MediaTypeFactory;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
/**
* Represents an event triggered for a file upload. Contains the
* {@linkplain #filename() filename}, besides the {@linkplain #headers() headers}
* and {@linkplain #content() content} exposed through {@link PartEvent}.
*
* <p>On the client side, instances of this interface can be created via one
* of the overloaded {@linkplain #create(String, Path) create} methods.
*
* <p>On the server side, multipart file uploads trigger one or more
* {@code FilePartEvent}, as {@linkplain PartEvent described here}.
*
* @author Arjen Poutsma
* @since 6.0
* @see PartEvent
*/
public interface FilePartEvent extends PartEvent {
/**
* Return the original filename in the client's filesystem.
* <p><strong>Note:</strong> Please keep in mind this filename is supplied
* by the client and should not be used blindly. In addition to not using
* the directory portion, the file name could also contain characters such
* as ".." and others that can be used maliciously. It is recommended to not
* use this filename directly. Preferably generate a unique one and save
* this one somewhere for reference, if necessary.
* @return the original filename, or the empty String if no file has been chosen
* in the multipart form, or {@code null} if not defined or not available
* @see <a href="https://tools.ietf.org/html/rfc7578#section-4.2">RFC 7578, Section 4.2</a>
* @see <a href="https://owasp.org/www-community/vulnerabilities/Unrestricted_File_Upload">Unrestricted File Upload</a>
*/
default String filename() {
String filename = this.headers().getContentDisposition().getFilename();
Assert.state(filename != null, "No filename found");
return filename;
}
/**
* Creates a stream of {@code FilePartEvent} objects based on the given
* {@linkplain PartEvent#name() name} and resource.
* @param name the name of the part
* @param resource the resource
* @return a stream of events
*/
static Flux<FilePartEvent> create(String name, Resource resource) {
return create(name, resource, null);
}
/**
* Creates a stream of {@code FilePartEvent} objects based on the given
* {@linkplain PartEvent#name() name} and resource.
* @param name the name of the part
* @param resource the resource
* @param headersConsumer used to change default headers. Can be {@code null}.
* @return a stream of events
*/
static Flux<FilePartEvent> create(String name, Resource resource, @Nullable Consumer<HttpHeaders> headersConsumer) {
try {
return create(name, resource.getFile().toPath(), headersConsumer);
}
catch (IOException ex) {
return Flux.error(ex);
}
}
/**
* Creates a stream of {@code FilePartEvent} objects based on the given
* {@linkplain PartEvent#name() name} and file path.
* @param name the name of the part
* @param path the file path
* @return a stream of events
*/
static Flux<FilePartEvent> create(String name, Path path) {
return create(name, path, null);
}
/**
* Creates a stream of {@code FilePartEvent} objects based on the given
* {@linkplain PartEvent#name() name} and file path.
* @param name the name of the part
* @param path the file path
* @param headersConsumer used to change default headers. Can be {@code null}.
* @return a stream of events
*/
static Flux<FilePartEvent> create(String name, Path path, @Nullable Consumer<HttpHeaders> headersConsumer) {
Assert.hasLength(name, "Name must not be empty");
Assert.notNull(path, "Path must not be null");
return Flux.defer(() -> {
String pathName = path.toString();
MediaType contentType = MediaTypeFactory.getMediaType(pathName)
.orElse(MediaType.APPLICATION_OCTET_STREAM);
String filename = StringUtils.getFilename(pathName);
if (filename == null) {
return Flux.error(new IllegalArgumentException("Invalid file: " + pathName));
}
Flux<DataBuffer> contents = DataBufferUtils.read(path, DefaultDataBufferFactory.sharedInstance, 8192);
return create(name, filename, contentType, contents, headersConsumer);
});
}
/**
* Creates a stream of {@code FilePartEvent} objects based on the given
* {@linkplain PartEvent#name() name}, {@linkplain #filename()},
* content-type, and contents.
* @param partName the name of the part
* @param filename the filename
* @param contentType the content-type for the contents
* @param contents the contents
* @return a stream of events
*/
static Flux<FilePartEvent> create(String partName, String filename, MediaType contentType,
Flux<DataBuffer> contents) {
return create(partName, filename, contentType, contents, null);
}
/**
* Creates a stream of {@code FilePartEvent} objects based on the given
* {@linkplain PartEvent#name() name}, {@linkplain #filename()},
* content-type, and contents.
* @param partName the name of the part
* @param filename the filename
* @param contentType the content-type for the contents
* @param contents the contents
* @param headersConsumer used to change default headers. Can be {@code null}.
* @return a stream of events
*/
static Flux<FilePartEvent> create(String partName, String filename, MediaType contentType,
Flux<DataBuffer> contents, @Nullable Consumer<HttpHeaders> headersConsumer) {
Assert.hasLength(partName, "PartName must not be empty");
Assert.hasLength(filename, "Filename must not be empty");
Assert.notNull(contentType, "ContentType must not be null");
Assert.notNull(contents, "Contents must not be null");
return Flux.defer(() -> {
HttpHeaders headers = new HttpHeaders();
headers.setContentType(contentType);
headers.setContentDisposition(ContentDisposition.formData()
.name(partName)
.filename(filename, StandardCharsets.UTF_8)
.build());
if (headersConsumer != null) {
headersConsumer.accept(headers);
}
return contents.map(content -> DefaultPartEvents.file(headers, content, false))
.concatWith(Mono.just(DefaultPartEvents.file(headers)));
});
}
}

View File

@ -0,0 +1,84 @@
/*
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.http.codec.multipart;
import java.nio.charset.StandardCharsets;
import java.util.function.Consumer;
import reactor.core.publisher.Mono;
import org.springframework.http.ContentDisposition;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
/**
* Represents an event triggered for a form field. Contains the
* {@linkplain #value() value}, besides the {@linkplain #headers() headers}
* exposed through {@link PartEvent}.
*
* <p>Multipart form fields trigger one {@code FormPartEvent}, as
* {@linkplain PartEvent described here}.
*
* @author Arjen Poutsma
* @since 6.0
*/
public interface FormPartEvent extends PartEvent {
/**
* Return the form field value.
*/
String value();
/**
* Creates a stream with a single {@code FormPartEven} based on the given
* {@linkplain PartEvent#name() name} and {@linkplain #value() value}.
* @param name the name of the part
* @param value the form field value
* @return a single event stream
*/
static Mono<FormPartEvent> create(String name, String value) {
return create(name, value, null);
}
/**
* Creates a stream with a single {@code FormPartEven} based on the given
* {@linkplain PartEvent#name() name} and {@linkplain #value() value}.
* @param name the name of the part
* @param value the form field value
* @param headersConsumer used to change default headers. Can be {@code null}.
* @return a single event stream
*/
static Mono<FormPartEvent> create(String name, String value, @Nullable Consumer<HttpHeaders> headersConsumer) {
Assert.hasLength(name, "Name must not be empty");
Assert.notNull(value, "Value must not be null");
return Mono.fromCallable(() -> {
HttpHeaders headers = new HttpHeaders();
headers.setContentType(new MediaType(MediaType.TEXT_PLAIN, StandardCharsets.UTF_8));
headers.setContentDisposition(ContentDisposition.formData().
name(name)
.build());
if (headersConsumer != null) {
headersConsumer.accept(headers);
}
return DefaultPartEvents.form(headers, value);
});
}
}

View File

@ -167,11 +167,11 @@ final class MultipartParser extends BaseSubscriber<DataBuffer> {
this.sink.next(new HeadersToken(headers));
}
void emitBody(DataBuffer buffer) {
void emitBody(DataBuffer buffer, boolean last) {
if (logger.isTraceEnabled()) {
logger.trace("Emitting body: " + buffer);
}
this.sink.next(new BodyToken(buffer));
this.sink.next(new BodyToken(buffer, last));
}
void emitError(Throwable t) {
@ -202,6 +202,9 @@ final class MultipartParser extends BaseSubscriber<DataBuffer> {
public abstract HttpHeaders headers();
public abstract DataBuffer buffer();
public abstract boolean isLast();
}
@ -225,6 +228,11 @@ final class MultipartParser extends BaseSubscriber<DataBuffer> {
public DataBuffer buffer() {
throw new IllegalStateException();
}
@Override
public boolean isLast() {
return false;
}
}
@ -235,8 +243,12 @@ final class MultipartParser extends BaseSubscriber<DataBuffer> {
private final DataBuffer buffer;
public BodyToken(DataBuffer buffer) {
private final boolean last;
public BodyToken(DataBuffer buffer, boolean last) {
this.buffer = buffer;
this.last = last;
}
@Override
@ -248,6 +260,11 @@ final class MultipartParser extends BaseSubscriber<DataBuffer> {
public DataBuffer buffer() {
return this.buffer;
}
@Override
public boolean isLast() {
return this.last;
}
}
@ -572,11 +589,15 @@ final class MultipartParser extends BaseSubscriber<DataBuffer> {
len += previous.readableByteCount();
}
emit.forEach(MultipartParser.this::emitBody);
emit.forEach(buffer -> MultipartParser.this.emitBody(buffer, false));
}
private void flush() {
this.queue.forEach(MultipartParser.this::emitBody);
for (Iterator<DataBuffer> iterator = this.queue.iterator(); iterator.hasNext(); ) {
DataBuffer buffer = iterator.next();
boolean last = !iterator.hasNext();
MultipartParser.this.emitBody(buffer, last);
}
this.queue.clear();
}

View File

@ -23,7 +23,9 @@ import java.nio.charset.StandardCharsets;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMessage;
import org.springframework.http.MediaType;
import org.springframework.lang.Nullable;
/**
* Various static utility methods for dealing with multipart parsing.
@ -47,6 +49,23 @@ abstract class MultipartUtils {
return StandardCharsets.UTF_8;
}
@Nullable
public static byte[] boundary(HttpMessage message, Charset headersCharset) {
MediaType contentType = message.getHeaders().getContentType();
if (contentType != null) {
String boundary = contentType.getParameter("boundary");
if (boundary != null) {
int len = boundary.length();
if (len > 2 && boundary.charAt(0) == '"' && boundary.charAt(len - 1) == '"') {
boundary = boundary.substring(1, len - 1);
}
return boundary.getBytes(headersCharset);
}
}
return null;
}
/**
* Concatenates the given array of byte arrays.
*/
@ -91,4 +110,9 @@ abstract class MultipartUtils {
}
}
public static boolean isFormField(HttpHeaders headers) {
MediaType contentType = headers.getContentType();
return (contentType == null || MediaType.TEXT_PLAIN.equalsTypeAndSubtype(contentType))
&& headers.getContentDisposition().getFilename() == null;
}
}

View File

@ -0,0 +1,150 @@
/*
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.http.codec.multipart;
import java.nio.file.Path;
import java.util.function.BiFunction;
import java.util.function.Predicate;
import org.reactivestreams.Publisher;
import reactor.core.publisher.Flux;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.http.HttpHeaders;
import org.springframework.util.Assert;
/**
* Represents an event for a "multipart/form-data" request.
* Can be a {@link FormPartEvent} or a {@link FilePartEvent}.
*
* <h2>Server Side</h2>
*
* Each part in a multipart HTTP message produces at least one
* {@code PartEvent} containing both {@link #headers() headers} and a
* {@linkplain PartEvent#content() buffer} with content of the part.
* <ul>
* <li>Form field will produce a <em>single</em> {@link FormPartEvent},
* containing the {@linkplain FormPartEvent#value() value} of the field.</li>
* <li>File uploads will produce <em>one or more</em> {@link FilePartEvent}s,
* containing the {@linkplain FilePartEvent#filename() filename} used when
* uploading. If the file is large enough to be split across multiple buffers,
* the first {@code FilePartEvent} will be followed by subsequent events.</li>
* </ul>
* The final {@code PartEvent} for a particular part will have
* {@link #isLast()} set to {@code true}, and can be followed by
* additional events belonging to subsequent parts.
* The {@code isLast()} property is suitable as a predicate for the
* {@link Flux#windowUntil(Predicate)} operator, in order to split events from
* all parts into windows that each belong to a single part.
* From that, the {@link Flux#switchOnFirst(BiFunction)} operator allows you to
* see whether you are handling a form field or file upload.
* For example:
*
* <pre class=code>
* Flux&lt;PartEvent&gt; allPartsEvents = ... // obtained via @RequestPayload or request.bodyToFlux(PartEvent.class)
* allPartsEvents.windowUntil(PartEvent::isLast)
* .concatMap(p -> p.switchOnFirst((signal, partEvents) -> {
* if (signal.hasValue()) {
* PartEvent event = signal.get();
* if (event instanceof FormPartEvent formEvent) {
* String value = formEvent.value();
* // handle form field
* }
* else if (event instanceof FilePartEvent fileEvent) {
* String filename filename = fileEvent.filename();
* Flux&lt;DataBuffer&gt; contents = partEvents.map(PartEvent::content);
* // handle file upload
* }
* else {
* return Mono.error("Unexpected event: " + event);
* }
* }
* else {
* return partEvents; // either complete or error signal
* }
* }))
* </pre>
* Received part events can also be relayed to another service by using the
* {@link org.springframework.web.reactive.function.client.WebClient WebClient}.
* See below.
*
* <p><strong>NOTE</strong> that the {@linkplain PartEvent#content() body contents}
* must be completely consumed, relayed, or released to avoid memory leaks.
*
* <h2>Client Side</h2>
* On the client side, {@code PartEvent}s can be created to represent a file upload.
* <ul>
* <li>Form fields can be created via {@link FormPartEvent#create(String, String)}.</li>
* <li>File uploads can be created via {@link FilePartEvent#create(String, Path)}.</li>
* </ul>
* The streams returned by these static methods can be concatenated via
* {@link Flux#concat(Publisher[])} to create a request for the
* {@link org.springframework.web.reactive.function.client.WebClient WebClient}:
* For instance, this sample will POST a multipart form containing a form field
* and a file.
*
* <pre class=code>
* Resource resource = ...
* Mono&lt;String&gt; result = webClient
* .post()
* .uri("https://example.com")
* .body(Flux.concat(
* FormEventPart.create("field", "field value"),
* FilePartEvent.create("file", resource)
* ), PartEvent.class)
* .retrieve()
* .bodyToMono(String.class);
* </pre>
*
* @author Arjen Poutsma
* @since 6.0
* @see FormPartEvent
* @see FilePartEvent
* @see PartEventHttpMessageReader
* @see PartEventHttpMessageWriter
*/
public interface PartEvent {
/**
* Return the name of the event, as provided through the
* {@code Content-Disposition name} parameter.
* @return the name of the part, never {@code null} or empty
*/
default String name() {
String name = headers().getContentDisposition().getName();
Assert.state(name != null, "No name available");
return name;
}
/**
* Return the headers of the part that this event belongs to.
*/
HttpHeaders headers();
/**
* Return the content of this event. The returned buffer must be consumed or
* {@linkplain org.springframework.core.io.buffer.DataBufferUtils#release(DataBuffer) released}.
*/
DataBuffer content();
/**
* Indicates whether this is the last event of a particular
* part.
*/
boolean isLast();
}

View File

@ -0,0 +1,174 @@
/*
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.http.codec.multipart;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import org.reactivestreams.Publisher;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import org.springframework.core.ResolvableType;
import org.springframework.core.codec.DecodingException;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferLimitException;
import org.springframework.core.io.buffer.DataBufferUtils;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.ReactiveHttpInputMessage;
import org.springframework.http.codec.HttpMessageReader;
import org.springframework.http.codec.LoggingCodecSupport;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
/**
* {@code HttpMessageReader} for parsing {@code "multipart/form-data"} requests
* to a stream of {@link PartEvent} elements.
*
* @author Arjen Poutsma
* @since 6.0
* @see PartEvent
*/
public class PartEventHttpMessageReader extends LoggingCodecSupport implements HttpMessageReader<PartEvent> {
private int maxInMemorySize = 256 * 1024;
private int maxHeadersSize = 10 * 1024;
private Charset headersCharset = StandardCharsets.UTF_8;
/**
* Get the {@link #setMaxInMemorySize configured} maximum in-memory size.
*/
public int getMaxInMemorySize() {
return this.maxInMemorySize;
}
/**
* Configure the maximum amount of memory allowed for form fields.
* When the limit is exceeded, form fields parts are rejected with
* {@link DataBufferLimitException}.
* <p>By default this is set to 256K.
* @param maxInMemorySize the in-memory limit in bytes; if set to -1 the entire
* contents will be stored in memory
*/
public void setMaxInMemorySize(int maxInMemorySize) {
this.maxInMemorySize = maxInMemorySize;
}
/**
* Configure the maximum amount of memory that is allowed per headers section of each part.
* Defaults to 10K.
* @param byteCount the maximum amount of memory for headers
*/
public void setMaxHeadersSize(int byteCount) {
this.maxHeadersSize = byteCount;
}
/**
* Set the character set used to decode headers.
* Defaults to UTF-8 as per RFC 7578.
* @param headersCharset the charset to use for decoding headers
* @see <a href="https://tools.ietf.org/html/rfc7578#section-5.1">RFC-7578 Section 5.1</a>
*/
public void setHeadersCharset(Charset headersCharset) {
Assert.notNull(headersCharset, "HeadersCharset must not be null");
this.headersCharset = headersCharset;
}
@Override
public List<MediaType> getReadableMediaTypes() {
return Collections.singletonList(MediaType.MULTIPART_FORM_DATA);
}
@Override
public boolean canRead(ResolvableType elementType, @Nullable MediaType mediaType) {
return PartEvent.class.equals(elementType.toClass()) &&
(mediaType == null || MediaType.MULTIPART_FORM_DATA.isCompatibleWith(mediaType));
}
@Override
public Mono<PartEvent> readMono(ResolvableType elementType, ReactiveHttpInputMessage message,
Map<String, Object> hints) {
return Mono.error(
new UnsupportedOperationException("Cannot read multipart request body into single PartEvent"));
}
@Override
public Flux<PartEvent> read(ResolvableType elementType, ReactiveHttpInputMessage message,
Map<String, Object> hints) {
return Flux.defer(() -> {
byte[] boundary = MultipartUtils.boundary(message, this.headersCharset);
if (boundary == null) {
return Flux.error(new DecodingException("No multipart boundary found in Content-Type: \"" +
message.getHeaders().getContentType() + "\""));
}
return MultipartParser.parse(message.getBody(), boundary, this.maxHeadersSize, this.headersCharset)
.windowUntil(t -> t instanceof MultipartParser.HeadersToken, true)
.concatMap(tokens -> tokens.switchOnFirst((signal, flux) -> {
if (signal.hasValue()) {
MultipartParser.HeadersToken headersToken = (MultipartParser.HeadersToken) signal.get();
Assert.state(headersToken != null, "Signal should be headers token");
HttpHeaders headers = headersToken.headers();
Flux<MultipartParser.BodyToken> bodyTokens =
flux.filter(t -> t instanceof MultipartParser.BodyToken)
.cast(MultipartParser.BodyToken.class);
return createEvents(headers, bodyTokens);
}
else {
// complete or error signal
return flux.cast(PartEvent.class);
}
}));
});
}
private Publisher<? extends PartEvent> createEvents(HttpHeaders headers, Flux<MultipartParser.BodyToken> bodyTokens) {
if (MultipartUtils.isFormField(headers)) {
Flux<DataBuffer> contents = bodyTokens.map(MultipartParser.BodyToken::buffer);
return DataBufferUtils.join(contents, this.maxInMemorySize)
.map(content -> {
String value = content.toString(MultipartUtils.charset(headers));
DataBufferUtils.release(content);
return DefaultPartEvents.form(headers, value);
})
.switchIfEmpty(Mono.fromCallable(() -> DefaultPartEvents.form(headers)));
}
else if (headers.getContentDisposition().getFilename() != null) {
return bodyTokens
.map(body -> DefaultPartEvents.file(headers, body.buffer(), body.isLast()))
.switchIfEmpty(Mono.fromCallable(() -> DefaultPartEvents.file(headers)));
}
else {
return bodyTokens
.map(body -> DefaultPartEvents.create(headers, body.buffer(), body.isLast()))
.switchIfEmpty(Mono.fromCallable(() -> DefaultPartEvents.create(headers))); // empty body
}
}
}

View File

@ -0,0 +1,111 @@
/*
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.http.codec.multipart;
import java.util.Collections;
import java.util.Map;
import org.reactivestreams.Publisher;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import org.springframework.core.ResolvableType;
import org.springframework.core.codec.Hints;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferFactory;
import org.springframework.core.io.buffer.DataBufferUtils;
import org.springframework.core.io.buffer.PooledDataBuffer;
import org.springframework.http.MediaType;
import org.springframework.http.ReactiveHttpOutputMessage;
import org.springframework.http.codec.HttpMessageWriter;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
/**
* {@link HttpMessageWriter} for writing {@link PartEvent} objects. Useful for
* server-side proxies, that relay multipart requests to others services.
*
* @author Arjen Poutsma
* @since 6.0
* @see PartEvent
*/
public class PartEventHttpMessageWriter extends MultipartWriterSupport implements HttpMessageWriter<PartEvent> {
public PartEventHttpMessageWriter() {
super(Collections.singletonList(MediaType.MULTIPART_FORM_DATA));
}
@Override
public boolean canWrite(ResolvableType elementType, @Nullable MediaType mediaType) {
if (PartEvent.class.isAssignableFrom(elementType.toClass())) {
if (mediaType == null) {
return true;
}
for (MediaType supportedMediaType : getWritableMediaTypes()) {
if (supportedMediaType.isCompatibleWith(mediaType)) {
return true;
}
}
}
return false;
}
@Override
public Mono<Void> write(Publisher<? extends PartEvent> partDataStream, ResolvableType elementType,
@Nullable MediaType mediaType, ReactiveHttpOutputMessage outputMessage,
Map<String, Object> hints) {
byte[] boundary = generateMultipartBoundary();
mediaType = getMultipartMediaType(mediaType, boundary);
outputMessage.getHeaders().setContentType(mediaType);
if (logger.isDebugEnabled()) {
logger.debug(Hints.getLogPrefix(hints) + "Encoding Publisher<PartEvent>");
}
Flux<DataBuffer> body = Flux.from(partDataStream)
.windowUntil(PartEvent::isLast)
.concatMap(partData -> partData.switchOnFirst((signal, flux) -> {
if (signal.hasValue()) {
PartEvent value = signal.get();
Assert.state(value != null, "Null value");
return encodePartData(boundary, outputMessage.bufferFactory(), value, flux);
}
else {
return flux.cast(DataBuffer.class);
}
}))
.concatWith(generateLastLine(boundary, outputMessage.bufferFactory()))
.doOnDiscard(PooledDataBuffer.class, DataBufferUtils::release);
if (logger.isDebugEnabled()) {
body = body.doOnNext(buffer -> Hints.touchDataBuffer(buffer, hints, logger));
}
return outputMessage.writeWith(body);
}
private Flux<DataBuffer> encodePartData(byte[] boundary, DataBufferFactory bufferFactory, PartEvent first, Flux<? extends PartEvent> flux) {
return Flux.concat(
generateBoundaryLine(boundary, bufferFactory),
generatePartHeaders(first.headers(), bufferFactory),
flux.map(PartEvent::content),
generateNewLine(bufferFactory));
}
}

View File

@ -50,7 +50,6 @@ import org.springframework.core.io.buffer.DataBufferLimitException;
import org.springframework.core.io.buffer.DataBufferUtils;
import org.springframework.core.io.buffer.DefaultDataBufferFactory;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.util.FastByteArrayOutputStream;
/**
@ -144,7 +143,7 @@ final class PartGenerator extends BaseSubscriber<MultipartParser.Token> {
}
private void newPart(State currentState, HttpHeaders headers) {
if (isFormField(headers)) {
if (MultipartUtils.isFormField(headers)) {
changeStateInternal(new FormFieldState(headers));
requestToken();
}
@ -245,12 +244,6 @@ final class PartGenerator extends BaseSubscriber<MultipartParser.Token> {
}
}
private static boolean isFormField(HttpHeaders headers) {
MediaType contentType = headers.getContentType();
return (contentType == null || MediaType.TEXT_PLAIN.equalsTypeAndSubtype(contentType))
&& headers.getContentDisposition().getFilename() == null;
}
/**
* Represents the internal state of the {@link PartGenerator} for
* creating a single {@link Part}.
@ -259,7 +252,7 @@ final class PartGenerator extends BaseSubscriber<MultipartParser.Token> {
* {@link #newPart(State, HttpHeaders)}.
* The following rules determine which state the creator will have:
* <ol>
* <li>If the part is a {@linkplain #isFormField(HttpHeaders) form field},
* <li>If the part is a {@linkplain MultipartUtils#isFormField(HttpHeaders) form field},
* the creator will be in the {@link FormFieldState}.</li>
* <li>If {@linkplain #streaming} is enabled, the creator will be in the
* {@link StreamingState}.</li>
@ -328,7 +321,7 @@ final class PartGenerator extends BaseSubscriber<MultipartParser.Token> {
/**
* The creator state when a {@linkplain #isFormField(HttpHeaders) form field} is received.
* The creator state when a {@linkplain MultipartUtils#isFormField(HttpHeaders) form field} is received.
* Stores all body buffers in memory (up until {@link #maxInMemorySize}).
*/
private final class FormFieldState implements State {

View File

@ -58,6 +58,7 @@ import org.springframework.http.codec.json.KotlinSerializationJsonEncoder;
import org.springframework.http.codec.multipart.DefaultPartHttpMessageReader;
import org.springframework.http.codec.multipart.MultipartHttpMessageReader;
import org.springframework.http.codec.multipart.MultipartHttpMessageWriter;
import org.springframework.http.codec.multipart.PartEventHttpMessageReader;
import org.springframework.http.codec.protobuf.ProtobufDecoder;
import org.springframework.http.codec.protobuf.ProtobufEncoder;
import org.springframework.http.codec.protobuf.ProtobufHttpMessageWriter;
@ -416,6 +417,9 @@ class BaseDefaultCodecs implements CodecConfigurer.DefaultCodecs, CodecConfigure
if (codec instanceof DefaultPartHttpMessageReader) {
((DefaultPartHttpMessageReader) codec).setMaxInMemorySize(size);
}
if (codec instanceof PartEventHttpMessageReader) {
((PartEventHttpMessageReader) codec).setMaxInMemorySize(size);
}
}
Boolean enable = this.enableLoggingRequestDetails;
@ -429,6 +433,9 @@ class BaseDefaultCodecs implements CodecConfigurer.DefaultCodecs, CodecConfigure
if (codec instanceof DefaultPartHttpMessageReader) {
((DefaultPartHttpMessageReader) codec).setEnableLoggingRequestDetails(enable);
}
if (codec instanceof PartEventHttpMessageReader) {
((PartEventHttpMessageReader) codec).setEnableLoggingRequestDetails(enable);
}
if (codec instanceof FormHttpMessageWriter) {
((FormHttpMessageWriter) codec).setEnableLoggingRequestDetails(enable);
}

View File

@ -30,6 +30,7 @@ import org.springframework.http.codec.HttpMessageReader;
import org.springframework.http.codec.HttpMessageWriter;
import org.springframework.http.codec.ServerSentEventHttpMessageReader;
import org.springframework.http.codec.multipart.MultipartHttpMessageWriter;
import org.springframework.http.codec.multipart.PartEventHttpMessageWriter;
import org.springframework.lang.Nullable;
/**
@ -100,6 +101,7 @@ class ClientDefaultCodecsImpl extends BaseDefaultCodecs implements ClientCodecCo
@Override
protected void extendTypedWriters(List<HttpMessageWriter<?>> typedWriters) {
addCodec(typedWriters, new MultipartHttpMessageWriter(getPartWriters(), new FormHttpMessageWriter()));
addCodec(typedWriters, new PartEventHttpMessageWriter());
}
private List<HttpMessageWriter<?>> getPartWriters() {

View File

@ -25,6 +25,7 @@ import org.springframework.http.codec.ServerCodecConfigurer;
import org.springframework.http.codec.ServerSentEventHttpMessageWriter;
import org.springframework.http.codec.multipart.DefaultPartHttpMessageReader;
import org.springframework.http.codec.multipart.MultipartHttpMessageReader;
import org.springframework.http.codec.multipart.PartEventHttpMessageReader;
import org.springframework.http.codec.multipart.PartHttpMessageWriter;
import org.springframework.lang.Nullable;
@ -69,11 +70,13 @@ class ServerDefaultCodecsImpl extends BaseDefaultCodecs implements ServerCodecCo
protected void extendTypedReaders(List<HttpMessageReader<?>> typedReaders) {
if (this.multipartReader != null) {
addCodec(typedReaders, this.multipartReader);
return;
}
DefaultPartHttpMessageReader partReader = new DefaultPartHttpMessageReader();
addCodec(typedReaders, partReader);
addCodec(typedReaders, new MultipartHttpMessageReader(partReader));
else {
DefaultPartHttpMessageReader partReader = new DefaultPartHttpMessageReader();
addCodec(typedReaders, partReader);
addCodec(typedReaders, new MultipartHttpMessageReader(partReader));
}
addCodec(typedReaders, new PartEventHttpMessageReader());
}
@Override

View File

@ -0,0 +1,317 @@
/*
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.http.codec.multipart;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.function.Consumer;
import io.netty.buffer.PooledByteBufAllocator;
import org.junit.jupiter.api.Test;
import reactor.core.publisher.Flux;
import reactor.test.StepVerifier;
import org.springframework.core.codec.DecodingException;
import org.springframework.core.io.ClassPathResource;
import org.springframework.core.io.Resource;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferFactory;
import org.springframework.core.io.buffer.DataBufferUtils;
import org.springframework.core.io.buffer.NettyDataBufferFactory;
import org.springframework.http.ContentDisposition;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.web.testfixture.http.server.reactive.MockServerHttpRequest;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.Collections.emptyMap;
import static java.util.Collections.singletonMap;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.AssertionsForClassTypes.entry;
import static org.springframework.core.ResolvableType.forClass;
/**
* @author Arjen Poutsma
*/
class PartEventHttpMessageReaderTests {
private static final int BUFFER_SIZE = 64;
private static final DataBufferFactory bufferFactory = new NettyDataBufferFactory(new PooledByteBufAllocator());
private static final MediaType TEXT_PLAIN_ASCII = new MediaType("text", "plain", StandardCharsets.US_ASCII);
private final PartEventHttpMessageReader reader = new PartEventHttpMessageReader();
@Test
public void canRead() {
assertThat(this.reader.canRead(forClass(PartEvent.class), MediaType.MULTIPART_FORM_DATA)).isTrue();
assertThat(this.reader.canRead(forClass(PartEvent.class), null)).isTrue();
}
@Test
public void simple() {
MockServerHttpRequest request = createRequest(
new ClassPathResource("simple.multipart", getClass()), "simple-boundary");
Flux<PartEvent> result = this.reader.read(forClass(PartEvent.class), request, emptyMap());
StepVerifier.create(result)
.assertNext(form(headers -> assertThat(headers).isEmpty(), "This is implicitly typed plain ASCII text.\r\nIt does NOT end with a linebreak."))
.assertNext(form(headers -> assertThat(headers.getContentType()).isEqualTo(TEXT_PLAIN_ASCII),
"This is explicitly typed plain ASCII text.\r\nIt DOES end with a linebreak.\r\n"))
.verifyComplete();
}
@Test
public void noHeaders() {
MockServerHttpRequest request = createRequest(
new ClassPathResource("no-header.multipart", getClass()), "boundary");
Flux<PartEvent> result = this.reader.read(forClass(PartEvent.class), request, emptyMap());
StepVerifier.create(result)
.assertNext(data(headers -> assertThat(headers).isEmpty(), bodyText("a"), true))
.verifyComplete();
}
@Test
public void noEndBoundary() {
MockServerHttpRequest request = createRequest(
new ClassPathResource("no-end-boundary.multipart", getClass()), "boundary");
Flux<PartEvent> result = this.reader.read(forClass(PartEvent.class), request, emptyMap());
StepVerifier.create(result)
.expectError(DecodingException.class)
.verify();
}
@Test
public void garbage() {
MockServerHttpRequest request = createRequest(
new ClassPathResource("garbage-1.multipart", getClass()), "boundary");
Flux<PartEvent> result = this.reader.read(forClass(PartEvent.class), request, emptyMap());
StepVerifier.create(result)
.expectError(DecodingException.class)
.verify();
}
@Test
public void noEndHeader() {
MockServerHttpRequest request = createRequest(
new ClassPathResource("no-end-header.multipart", getClass()), "boundary");
Flux<PartEvent> result = this.reader.read(forClass(PartEvent.class), request, emptyMap());
StepVerifier.create(result)
.expectError(DecodingException.class)
.verify();
}
@Test
public void noEndBody() {
MockServerHttpRequest request = createRequest(
new ClassPathResource("no-end-body.multipart", getClass()), "boundary");
Flux<PartEvent> result = this.reader.read(forClass(PartEvent.class), request, emptyMap());
StepVerifier.create(result)
.expectError(DecodingException.class)
.verify();
}
@Test
public void noBody() {
MockServerHttpRequest request = createRequest(
new ClassPathResource("no-body.multipart", getClass()), "boundary");
Flux<PartEvent> result = this.reader.read(forClass(PartEvent.class), request, emptyMap());
StepVerifier.create(result)
.assertNext(form(headers -> assertThat(headers).contains(entry("Part", List.of("1"))), ""))
.assertNext(data(headers -> assertThat(headers).contains(entry("Part", List.of("2"))), bodyText("a"), true))
.verifyComplete();
}
@Test
public void cancel() {
MockServerHttpRequest request = createRequest(
new ClassPathResource("simple.multipart", getClass()), "simple-boundary");
Flux<PartEvent> result = this.reader.read(forClass(PartEvent.class), request, emptyMap());
StepVerifier.create(result, 3)
.assertNext(form(headers -> assertThat(headers).isEmpty(),
"This is implicitly typed plain ASCII text.\r\nIt does NOT end with a linebreak."))
.thenCancel()
.verify();
}
@Test
public void firefox() {
MockServerHttpRequest request = createRequest(new ClassPathResource("firefox.multipart", getClass()),
"---------------------------18399284482060392383840973206");
Flux<PartEvent> result = this.reader.read(forClass(PartEvent.class), request, emptyMap());
StepVerifier.create(result)
.assertNext(data(headersFormField("text1"), bodyText("a"), true))
.assertNext(data(headersFormField("text2"), bodyText("b"), true))
.assertNext(data(headersFile("file1", "a.txt"), DataBufferUtils::release, false))
.assertNext(data(headersFile("file1", "a.txt"), DataBufferUtils::release, false))
.assertNext(data(headersFile("file1", "a.txt"), DataBufferUtils::release, true))
.assertNext(data(headersFile("file2", "a.txt"), DataBufferUtils::release, false))
.assertNext(data(headersFile("file2", "a.txt"), DataBufferUtils::release, false))
.assertNext(data(headersFile("file2", "a.txt"), DataBufferUtils::release, true))
.assertNext(data(headersFile("file2", "b.txt"), DataBufferUtils::release, false))
.assertNext(data(headersFile("file2", "b.txt"), DataBufferUtils::release, false))
.assertNext(data(headersFile("file2", "b.txt"), DataBufferUtils::release, true))
.verifyComplete();
}
@Test
public void chrome() {
MockServerHttpRequest request = createRequest(new ClassPathResource("chrome.multipart", getClass()),
"----WebKitFormBoundaryEveBLvRT65n21fwU");
Flux<PartEvent> result = this.reader.read(forClass(PartEvent.class), request, emptyMap());
StepVerifier.create(result)
.assertNext(data(headersFormField("text1"), bodyText("a"), true))
.assertNext(data(headersFormField("text2"), bodyText("b"), true))
.assertNext(data(headersFile("file1", "a.txt"), DataBufferUtils::release, false))
.assertNext(data(headersFile("file1", "a.txt"), DataBufferUtils::release, true))
.assertNext(data(headersFile("file2", "a.txt"), DataBufferUtils::release, false))
.assertNext(data(headersFile("file2", "a.txt"), DataBufferUtils::release, false))
.assertNext(data(headersFile("file2", "a.txt"), DataBufferUtils::release, true))
.assertNext(data(headersFile("file2", "b.txt"), DataBufferUtils::release, false))
.assertNext(data(headersFile("file2", "b.txt"), DataBufferUtils::release, false))
.assertNext(data(headersFile("file2", "b.txt"), DataBufferUtils::release, true))
.verifyComplete();
}
@Test
public void safari() {
MockServerHttpRequest request = createRequest(new ClassPathResource("safari.multipart", getClass()),
"----WebKitFormBoundaryG8fJ50opQOML0oGD");
Flux<PartEvent> result = this.reader.read(forClass(PartEvent.class), request, emptyMap());
StepVerifier.create(result)
.assertNext(data(headersFormField("text1"), bodyText("a"), true))
.assertNext(data(headersFormField("text2"), bodyText("b"), true))
.assertNext(data(headersFile("file1", "a.txt"), DataBufferUtils::release, false))
.assertNext(data(headersFile("file1", "a.txt"), DataBufferUtils::release, true))
.assertNext(data(headersFile("file2", "a.txt"), DataBufferUtils::release, false))
.assertNext(data(headersFile("file2", "a.txt"), DataBufferUtils::release, false))
.assertNext(data(headersFile("file2", "a.txt"), DataBufferUtils::release, true))
.assertNext(data(headersFile("file2", "b.txt"), DataBufferUtils::release, false))
.assertNext(data(headersFile("file2", "b.txt"), DataBufferUtils::release, false))
.assertNext(data(headersFile("file2", "b.txt"), DataBufferUtils::release, true))
.verifyComplete();
}
@Test
public void utf8Headers() {
MockServerHttpRequest request = createRequest(
new ClassPathResource("utf8.multipart", getClass()), "\"simple-boundary\"");
Flux<PartEvent> result = this.reader.read(forClass(PartEvent.class), request, emptyMap());
StepVerifier.create(result)
.assertNext(data(headers -> assertThat(headers).containsEntry("Føø", List.of("Bår")),
bodyText("This is plain ASCII text."), true))
.verifyComplete();
}
@Test
public void exceedHeaderLimit() {
Flux<DataBuffer> body = DataBufferUtils
.readByteChannel((new ClassPathResource("files.multipart", getClass()))::readableChannel, bufferFactory,
282);
MediaType contentType = new MediaType("multipart", "form-data",
singletonMap("boundary", "----WebKitFormBoundaryG8fJ50opQOML0oGD"));
MockServerHttpRequest request = MockServerHttpRequest.post("/")
.contentType(contentType)
.body(body);
this.reader.setMaxHeadersSize(230);
Flux<PartEvent> result = this.reader.read(forClass(PartEvent.class), request, emptyMap());
StepVerifier.create(result)
.assertNext(data(headersFile("file2", "a.txt"), DataBufferUtils::release, true))
.assertNext(data(headersFile("file2", "b.txt"), DataBufferUtils::release, true))
.verifyComplete();
}
private MockServerHttpRequest createRequest(Resource resource, String boundary) {
Flux<DataBuffer> body = DataBufferUtils
.readByteChannel(resource::readableChannel, bufferFactory, BUFFER_SIZE);
MediaType contentType = new MediaType("multipart", "form-data", singletonMap("boundary", boundary));
return MockServerHttpRequest.post("/")
.contentType(contentType)
.body(body);
}
private static Consumer<PartEvent> form(Consumer<HttpHeaders> headersConsumer, String value) {
return data -> {
headersConsumer.accept(data.headers());
String actual = data.content().toString(UTF_8);
assertThat(actual).isEqualTo(value);
assertThat(data.isLast()).isTrue();
};
}
private static Consumer<PartEvent> data(Consumer<HttpHeaders> headersConsumer, Consumer<DataBuffer> bufferConsumer, boolean isLast) {
return data -> {
headersConsumer.accept(data.headers());
bufferConsumer.accept(data.content());
assertThat(data.isLast()).isEqualTo(isLast);
};
}
private static Consumer<HttpHeaders> headersFormField(String expectedName) {
return headers -> {
ContentDisposition cd = headers.getContentDisposition();
assertThat(cd.isFormData()).isTrue();
assertThat(cd.getName()).isEqualTo(expectedName);
};
}
private static Consumer<HttpHeaders> headersFile(String expectedName, String expectedFilename) {
return headers -> {
ContentDisposition cd = headers.getContentDisposition();
assertThat(cd.isFormData()).isTrue();
assertThat(cd.getName()).isEqualTo(expectedName);
assertThat(cd.getFilename()).isEqualTo(expectedFilename);
};
}
private static Consumer<DataBuffer> bodyText(String expected) {
return buffer -> {
String s = buffer.toString(UTF_8);
DataBufferUtils.release(buffer);
assertThat(s).isEqualTo(expected);
};
}
}

View File

@ -61,6 +61,7 @@ import org.springframework.http.codec.json.Jackson2SmileEncoder;
import org.springframework.http.codec.json.KotlinSerializationJsonDecoder;
import org.springframework.http.codec.json.KotlinSerializationJsonEncoder;
import org.springframework.http.codec.multipart.MultipartHttpMessageWriter;
import org.springframework.http.codec.multipart.PartEventHttpMessageWriter;
import org.springframework.http.codec.protobuf.ProtobufDecoder;
import org.springframework.http.codec.protobuf.ProtobufHttpMessageWriter;
import org.springframework.http.codec.xml.Jaxb2XmlDecoder;
@ -106,7 +107,7 @@ public class ClientCodecConfigurerTests {
@Test
public void defaultWriters() {
List<HttpMessageWriter<?>> writers = this.configurer.getWriters();
assertThat(writers.size()).isEqualTo(13);
assertThat(writers.size()).isEqualTo(14);
assertThat(getNextEncoder(writers).getClass()).isEqualTo(ByteArrayEncoder.class);
assertThat(getNextEncoder(writers).getClass()).isEqualTo(ByteBufferEncoder.class);
assertThat(getNextEncoder(writers).getClass()).isEqualTo(DataBufferEncoder.class);
@ -115,6 +116,7 @@ public class ClientCodecConfigurerTests {
assertStringEncoder(getNextEncoder(writers), true);
assertThat(writers.get(index.getAndIncrement()).getClass()).isEqualTo(ProtobufHttpMessageWriter.class);
assertThat(writers.get(this.index.getAndIncrement()).getClass()).isEqualTo(MultipartHttpMessageWriter.class);
assertThat(writers.get(this.index.getAndIncrement()).getClass()).isEqualTo(PartEventHttpMessageWriter.class);
assertThat(getNextEncoder(writers).getClass()).isEqualTo(KotlinSerializationJsonEncoder.class);
assertThat(getNextEncoder(writers).getClass()).isEqualTo(Jackson2JsonEncoder.class);
assertThat(getNextEncoder(writers).getClass()).isEqualTo(Jackson2SmileEncoder.class);

View File

@ -60,6 +60,7 @@ import org.springframework.http.codec.json.KotlinSerializationJsonDecoder;
import org.springframework.http.codec.json.KotlinSerializationJsonEncoder;
import org.springframework.http.codec.multipart.DefaultPartHttpMessageReader;
import org.springframework.http.codec.multipart.MultipartHttpMessageReader;
import org.springframework.http.codec.multipart.PartEventHttpMessageReader;
import org.springframework.http.codec.multipart.PartHttpMessageWriter;
import org.springframework.http.codec.protobuf.ProtobufDecoder;
import org.springframework.http.codec.protobuf.ProtobufHttpMessageWriter;
@ -85,7 +86,7 @@ public class ServerCodecConfigurerTests {
@Test
public void defaultReaders() {
List<HttpMessageReader<?>> readers = this.configurer.getReaders();
assertThat(readers.size()).isEqualTo(15);
assertThat(readers.size()).isEqualTo(16);
assertThat(getNextDecoder(readers).getClass()).isEqualTo(ByteArrayDecoder.class);
assertThat(getNextDecoder(readers).getClass()).isEqualTo(ByteBufferDecoder.class);
assertThat(getNextDecoder(readers).getClass()).isEqualTo(DataBufferDecoder.class);
@ -96,6 +97,7 @@ public class ServerCodecConfigurerTests {
assertThat(readers.get(this.index.getAndIncrement()).getClass()).isEqualTo(FormHttpMessageReader.class);
assertThat(readers.get(this.index.getAndIncrement()).getClass()).isEqualTo(DefaultPartHttpMessageReader.class);
assertThat(readers.get(this.index.getAndIncrement()).getClass()).isEqualTo(MultipartHttpMessageReader.class);
assertThat(readers.get(this.index.getAndIncrement()).getClass()).isEqualTo(PartEventHttpMessageReader.class);
assertThat(getNextDecoder(readers).getClass()).isEqualTo(KotlinSerializationJsonDecoder.class);
assertThat(getNextDecoder(readers).getClass()).isEqualTo(Jackson2JsonDecoder.class);
assertThat(getNextDecoder(readers).getClass()).isEqualTo(Jackson2SmileDecoder.class);
@ -159,6 +161,7 @@ public class ServerCodecConfigurerTests {
MultipartHttpMessageReader multipartReader = (MultipartHttpMessageReader) nextReader(readers);
DefaultPartHttpMessageReader reader = (DefaultPartHttpMessageReader) multipartReader.getPartReader();
assertThat((reader).getMaxInMemorySize()).isEqualTo(size);
assertThat(((PartEventHttpMessageReader) nextReader(readers)).getMaxInMemorySize()).isEqualTo(size);
assertThat(((KotlinSerializationJsonDecoder) getNextDecoder(readers)).getMaxInMemorySize()).isEqualTo(size);
assertThat(((Jackson2JsonDecoder) getNextDecoder(readers)).getMaxInMemorySize()).isEqualTo(size);

View File

@ -0,0 +1,9 @@
--boundary
Part: 1
--boundary
Part: 2
a
--boundary--

View File

@ -112,7 +112,7 @@ public class DelegatingWebFluxConfigurationTests {
boolean condition = initializer.getValidator() instanceof LocalValidatorFactoryBean;
assertThat(condition).isTrue();
assertThat(initializer.getConversionService()).isSameAs(formatterRegistry.getValue());
assertThat(codecsConfigurer.getValue().getReaders().size()).isEqualTo(14);
assertThat(codecsConfigurer.getValue().getReaders().size()).isEqualTo(15);
}
@Test

View File

@ -151,7 +151,7 @@ public class WebFluxConfigurationSupportTests {
assertThat(adapter).isNotNull();
List<HttpMessageReader<?>> readers = adapter.getMessageReaders();
assertThat(readers.size()).isEqualTo(14);
assertThat(readers.size()).isEqualTo(15);
ResolvableType multiValueMapType = forClassWithGenerics(MultiValueMap.class, String.class, String.class);

View File

@ -17,10 +17,12 @@
package org.springframework.web.reactive.function;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.time.Duration;
import java.util.List;
import java.util.Map;
import org.junit.jupiter.api.Disabled;
@ -31,15 +33,16 @@ import reactor.core.scheduler.Schedulers;
import reactor.test.StepVerifier;
import org.springframework.core.io.ClassPathResource;
import org.springframework.http.HttpEntity;
import org.springframework.core.io.buffer.DataBufferUtils;
import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;
import org.springframework.http.client.MultipartBodyBuilder;
import org.springframework.http.codec.multipart.FilePart;
import org.springframework.http.codec.multipart.FilePartEvent;
import org.springframework.http.codec.multipart.FormFieldPart;
import org.springframework.http.codec.multipart.FormPartEvent;
import org.springframework.http.codec.multipart.Part;
import org.springframework.http.codec.multipart.PartEvent;
import org.springframework.util.FileCopyUtils;
import org.springframework.util.MultiValueMap;
import org.springframework.web.reactive.function.client.WebClient;
import org.springframework.web.reactive.function.server.AbstractRouterFunctionIntegrationTests;
import org.springframework.web.reactive.function.server.RouterFunction;
@ -60,7 +63,7 @@ class MultipartIntegrationTests extends AbstractRouterFunctionIntegrationTests {
private final WebClient webClient = WebClient.create();
private ClassPathResource resource = new ClassPathResource("org/springframework/http/codec/multipart/foo.txt");
private final ClassPathResource resource = new ClassPathResource("foo.txt", getClass());
@ParameterizedHttpServerTest
@ -70,7 +73,7 @@ class MultipartIntegrationTests extends AbstractRouterFunctionIntegrationTests {
Mono<ResponseEntity<Void>> result = webClient
.post()
.uri("http://localhost:" + this.port + "/multipartData")
.bodyValue(generateBody())
.body(generateBody(), PartEvent.class)
.retrieve()
.toEntity(Void.class);
@ -88,7 +91,7 @@ class MultipartIntegrationTests extends AbstractRouterFunctionIntegrationTests {
Mono<ResponseEntity<Void>> result = webClient
.post()
.uri("http://localhost:" + this.port + "/parts")
.bodyValue(generateBody())
.body(generateBody(), PartEvent.class)
.retrieve()
.toEntity(Void.class);
@ -120,7 +123,7 @@ class MultipartIntegrationTests extends AbstractRouterFunctionIntegrationTests {
Mono<String> result = webClient
.post()
.uri("http://localhost:" + this.port + "/transferTo")
.bodyValue(generateBody())
.body(generateBody(), PartEvent.class)
.retrieve()
.bodyToMono(String.class);
@ -140,11 +143,48 @@ class MultipartIntegrationTests extends AbstractRouterFunctionIntegrationTests {
.verify(Duration.ofSeconds(5));
}
private MultiValueMap<String, HttpEntity<?>> generateBody() {
MultipartBodyBuilder builder = new MultipartBodyBuilder();
builder.part("fooPart", resource);
builder.part("barPart", "bar");
return builder.build();
@ParameterizedHttpServerTest
void partData(HttpServer httpServer) throws Exception {
startServer(httpServer);
Mono<ResponseEntity<Void>> result = webClient
.post()
.uri("http://localhost:" + this.port + "/partData")
.body(generateBody(), PartEvent.class)
.retrieve()
.toEntity(Void.class);
StepVerifier
.create(result)
.consumeNextWith(entity -> assertThat(entity.getStatusCode()).isEqualTo(HttpStatus.OK))
.expectComplete()
.verify(Duration.ofSeconds(5));
}
@ParameterizedHttpServerTest
void proxy(HttpServer httpServer) throws Exception {
startServer(httpServer);
Mono<ResponseEntity<Void>> result = webClient
.post()
.uri("http://localhost:" + this.port + "/proxy")
.body(generateBody(), PartEvent.class)
.retrieve()
.toEntity(Void.class);
StepVerifier
.create(result)
.consumeNextWith(entity -> assertThat(entity.getStatusCode()).isEqualTo(HttpStatus.OK))
.expectComplete()
.verify(Duration.ofSeconds(5));
}
private Flux<PartEvent> generateBody() {
return Flux.concat(
FilePartEvent.create("fooPart", this.resource),
FormPartEvent.create("barPart", "bar")
);
}
@Override
@ -154,6 +194,8 @@ class MultipartIntegrationTests extends AbstractRouterFunctionIntegrationTests {
.POST("/multipartData", multipartHandler::multipartData)
.POST("/parts", multipartHandler::parts)
.POST("/transferTo", multipartHandler::transferTo)
.POST("/partData", multipartHandler::partData)
.POST("/proxy", multipartHandler::proxy)
.build();
}
@ -207,6 +249,44 @@ class MultipartIntegrationTests extends AbstractRouterFunctionIntegrationTests {
.then(ServerResponse.ok().bodyValue(tempFile.toString()))));
}
public Mono<ServerResponse> partData(ServerRequest request) {
return request.bodyToFlux(PartEvent.class)
.bufferUntil(PartEvent::isLast)
.collectList()
.flatMap((List<List<PartEvent>> data) -> {
assertThat(data).hasSize(2);
List<PartEvent> fileData = data.get(0);
assertThat(fileData).hasSize(1);
assertThat(fileData.get(0)).isInstanceOf(FilePartEvent.class);
FilePartEvent filePartEvent = (FilePartEvent) fileData.get(0);
assertThat(filePartEvent.name()).isEqualTo("fooPart");
assertThat(filePartEvent.filename()).isEqualTo("foo.txt");
DataBufferUtils.release(filePartEvent.content());
List<PartEvent> fieldData = data.get(1);
assertThat(fieldData).hasSize(1);
assertThat(fieldData.get(0)).isInstanceOf(FormPartEvent.class);
FormPartEvent formPartEvent = (FormPartEvent) fieldData.get(0);
assertThat(formPartEvent.name()).isEqualTo("barPart");
assertThat(formPartEvent.content().toString(StandardCharsets.UTF_8)).isEqualTo("bar");
DataBufferUtils.release(filePartEvent.content());
return ServerResponse.ok().build();
});
}
public Mono<ServerResponse> proxy(ServerRequest request) {
return Mono.defer(() -> {
WebClient client = WebClient.create("http://localhost:" + request.uri().getPort() + "/multipartData");
return client.post()
.body(request.bodyToFlux(PartEvent.class), PartEvent.class)
.retrieve()
.toEntity(Void.class)
.flatMap(response -> ServerResponse.ok().build());
});
}
private Mono<Path> createTempFile() {
return Mono.defer(() -> {
try {

View File

@ -35,6 +35,8 @@ import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.io.ClassPathResource;
import org.springframework.core.io.Resource;
import org.springframework.core.io.buffer.DataBufferUtils;
import org.springframework.http.ContentDisposition;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;
@ -43,6 +45,7 @@ import org.springframework.http.codec.multipart.FilePart;
import org.springframework.http.codec.multipart.FormFieldPart;
import org.springframework.http.codec.multipart.MultipartHttpMessageReader;
import org.springframework.http.codec.multipart.Part;
import org.springframework.http.codec.multipart.PartEvent;
import org.springframework.http.server.reactive.HttpHandler;
import org.springframework.util.FileCopyUtils;
import org.springframework.util.MultiValueMap;
@ -200,6 +203,22 @@ class MultipartIntegrationTests extends AbstractHttpHandlerIntegrationTests {
.verifyComplete();
}
@ParameterizedHttpServerTest
void partData(HttpServer httpServer) throws Exception {
startServer(httpServer);
Mono<String> result = webClient
.post()
.uri("/partData")
.bodyValue(generateBody())
.retrieve()
.bodyToMono(String.class);
StepVerifier.create(result)
.consumeNextWith(body -> assertThat(body).isEqualTo("fieldPart,foo.txt:fileParts,logo.png:fileParts,jsonPart,"))
.verifyComplete();
}
private MultiValueMap<String, HttpEntity<?>> generateBody() {
MultipartBodyBuilder builder = new MultipartBodyBuilder();
builder.part("fieldPart", "fieldValue");
@ -277,7 +296,7 @@ class MultipartIntegrationTests extends AbstractHttpHandlerIntegrationTests {
}
private Mono<Path> createTempFile(String suffix) {
return Mono.defer(() -> {
return Mono.defer(() -> {
try {
return Mono.just(Files.createTempFile("MultipartIntegrationTests", suffix));
}
@ -285,13 +304,38 @@ class MultipartIntegrationTests extends AbstractHttpHandlerIntegrationTests {
return Mono.error(ex);
}
})
.subscribeOn(Schedulers.boundedElastic());
}
.subscribeOn(Schedulers.boundedElastic());
}
@PostMapping("/modelAttribute")
String modelAttribute(@ModelAttribute FormBean formBean) {
return formBean.toString();
}
@PostMapping("/partData")
Flux<String> tokens(@RequestBody Flux<PartEvent> partData) {
return partData.map(data -> {
if (data.isLast()) {
ContentDisposition cd = data.headers().getContentDisposition();
StringBuilder sb = new StringBuilder();
if (cd.getFilename() != null) {
sb.append(cd.getFilename())
.append(':')
.append(cd.getName());
}
else if (cd.getName() != null) {
sb.append(cd.getName());
}
sb.append(',');
DataBufferUtils.release(data.content());
return sb.toString();
}
else {
return "";
}
});
}
}
private static String partMapDescription(MultiValueMap<String, Part> partsMap) {