Support data binding for multipart requests in WebFlux

Issue: SPR-14546
This commit is contained in:
Rossen Stoyanchev 2017-05-03 18:46:00 -04:00
parent b5089ac092
commit fc7bededd0
12 changed files with 378 additions and 155 deletions

View File

@ -0,0 +1,45 @@
/*
* Copyright 2002-2017 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.http.codec.multipart;
import java.io.File;
import reactor.core.publisher.Mono;
/**
* Specialization of {@link Part} for a file upload.
*
* @author Rossen Stoyanchev
* @since 5.0
*/
public interface FilePart extends Part {
/**
* Return the name of the file selected by the user in a browser form.
*/
String getFilename();
/**
* Transfer the file in this part to the given file destination.
* @param dest the target file
* @return completion {@code Mono} with the result of the file transfer,
* possibly {@link IllegalStateException} if the part isn't a file
*/
Mono<Void> transferTo(File dest);
}

View File

@ -0,0 +1,32 @@
/*
* Copyright 2002-2017 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.http.codec.multipart;
/**
* Specialization of {@link Part} for a form field.
*
* @author Rossen Stoyanchev
* @since 5.0
*/
public interface FormFieldPart extends Part {
/**
* Return the form field value.
*/
String getValue();
}

View File

@ -16,11 +16,7 @@
package org.springframework.http.codec.multipart;
import java.io.File;
import java.util.Optional;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.http.HttpHeaders;
@ -29,9 +25,10 @@ import org.springframework.http.HttpHeaders;
* Representation for a part in a "multipart/form-data" request.
*
* <p>The origin of a multipart request may a browser form in which case each
* part represents a text-based form field or a file upload. Multipart requests
* may also be used outside of browsers to transfer data with any content type
* such as JSON, PDF, etc.
* part is either a {@link FormFieldPart} or a {@link FilePart}.
*
* <p>Multipart requests may also be used outside of a browser for data of any
* content type (e.g. JSON, PDF, etc).
*
* @author Sebastien Deleuze
* @author Rossen Stoyanchev
@ -53,30 +50,9 @@ public interface Part {
*/
HttpHeaders getHeaders();
/**
*
* Return the name of the file selected by the user in a browser form.
* @return the filename if defined and available
*/
Optional<String> getFilename();
/**
* Return the part content converted to a String with the charset from the
* {@code Content-Type} header or {@code UTF-8} by default.
*/
Mono<String> getContentAsString();
/**
* Return the part raw content as a stream of DataBuffer's.
*/
Flux<DataBuffer> getContent();
/**
* Transfer the file in this part to the given file destination.
* @param destination the target file
* @return completion {@code Mono} with the result of the file transfer,
* possibly {@link IllegalStateException} if the part isn't a file
*/
Mono<Void> transferTo(File destination);
}

View File

@ -52,8 +52,6 @@ import org.springframework.http.MediaType;
import org.springframework.http.ReactiveHttpInputMessage;
import org.springframework.http.codec.HttpMessageReader;
import org.springframework.util.Assert;
import org.springframework.util.MimeType;
import org.springframework.util.StreamUtils;
/**
* {@code HttpMessageReader} for parsing {@code "multipart/form-data"} requests
@ -71,6 +69,8 @@ import org.springframework.util.StreamUtils;
*/
public class SynchronossPartHttpMessageReader implements HttpMessageReader<Part> {
private final DataBufferFactory bufferFactory = new DefaultDataBufferFactory();
@Override
public List<MediaType> getReadableMediaTypes() {
@ -88,7 +88,7 @@ public class SynchronossPartHttpMessageReader implements HttpMessageReader<Part>
public Flux<Part> read(ResolvableType elementType, ReactiveHttpInputMessage message,
Map<String, Object> hints) {
return Flux.create(new SynchronossPartGenerator(message));
return Flux.create(new SynchronossPartGenerator(message, this.bufferFactory));
}
@ -109,9 +109,12 @@ public class SynchronossPartHttpMessageReader implements HttpMessageReader<Part>
private final ReactiveHttpInputMessage inputMessage;
private final DataBufferFactory bufferFactory;
SynchronossPartGenerator(ReactiveHttpInputMessage inputMessage) {
SynchronossPartGenerator(ReactiveHttpInputMessage inputMessage, DataBufferFactory factory) {
this.inputMessage = inputMessage;
this.bufferFactory = factory;
}
@ -119,7 +122,7 @@ public class SynchronossPartHttpMessageReader implements HttpMessageReader<Part>
public void accept(FluxSink<Part> emitter) {
MultipartContext context = createMultipartContext();
NioMultipartParserListener listener = new FluxSinkAdapterListener(emitter);
NioMultipartParserListener listener = new FluxSinkAdapterListener(emitter, this.bufferFactory);
NioMultipartParser parser = Multipart.multipart(context).forNIO(listener);
this.inputMessage.getBody().subscribe(buffer -> {
@ -167,11 +170,14 @@ public class SynchronossPartHttpMessageReader implements HttpMessageReader<Part>
private final FluxSink<Part> sink;
private final DataBufferFactory bufferFactory;
private final AtomicInteger terminated = new AtomicInteger(0);
FluxSinkAdapterListener(FluxSink<Part> sink) {
FluxSinkAdapterListener(FluxSink<Part> sink, DataBufferFactory bufferFactory) {
this.sink = sink;
this.bufferFactory = bufferFactory;
}
@ -179,14 +185,17 @@ public class SynchronossPartHttpMessageReader implements HttpMessageReader<Part>
public void onPartFinished(StreamStorage storage, Map<String, List<String>> headers) {
HttpHeaders httpHeaders = new HttpHeaders();
httpHeaders.putAll(headers);
this.sink.next(new SynchronossPart(httpHeaders, storage));
Part part = MultipartUtils.getFileName(httpHeaders) != null ?
new SynchronossFilePart(httpHeaders, storage, this.bufferFactory) :
new DefaultSynchronossPart(httpHeaders, storage, this.bufferFactory);
this.sink.next(part);
}
@Override
public void onFormFieldPartFinished(String name, String value, Map<String, List<String>> headers) {
HttpHeaders httpHeaders = new HttpHeaders();
httpHeaders.putAll(headers);
this.sink.next(new SynchronossPart(httpHeaders, value));
this.sink.next(new SynchronossFormFieldPart(httpHeaders, this.bufferFactory, value));
}
@Override
@ -213,31 +222,18 @@ public class SynchronossPartHttpMessageReader implements HttpMessageReader<Part>
}
private static class SynchronossPart implements Part {
private static abstract class AbstractSynchronossPart implements Part {
private final HttpHeaders headers;
private final StreamStorage storage;
private final String content;
private final DataBufferFactory bufferFactory = new DefaultDataBufferFactory();
private final DataBufferFactory bufferFactory;
SynchronossPart(HttpHeaders headers, StreamStorage storage) {
AbstractSynchronossPart(HttpHeaders headers, DataBufferFactory bufferFactory) {
Assert.notNull(headers, "HttpHeaders is required");
Assert.notNull(storage, "'storage' is required");
Assert.notNull(bufferFactory, "'bufferFactory' is required");
this.headers = headers;
this.storage = storage;
this.content = null;
}
SynchronossPart(HttpHeaders headers, String content) {
Assert.notNull(headers, "HttpHeaders is required");
Assert.notNull(content, "'content' is required");
this.headers = headers;
this.storage = null;
this.content = content;
this.bufferFactory = bufferFactory;
}
@ -251,52 +247,53 @@ public class SynchronossPartHttpMessageReader implements HttpMessageReader<Part>
return this.headers;
}
@Override
public Optional<String> getFilename() {
return Optional.ofNullable(MultipartUtils.getFileName(this.headers));
protected DataBufferFactory getBufferFactory() {
return this.bufferFactory;
}
}
private static class DefaultSynchronossPart extends AbstractSynchronossPart {
private final StreamStorage storage;
DefaultSynchronossPart(HttpHeaders headers, StreamStorage storage, DataBufferFactory factory) {
super(headers, factory);
Assert.notNull(storage, "'storage' is required");
this.storage = storage;
}
@Override
public Mono<String> getContentAsString() {
if (this.content != null) {
return Mono.just(this.content);
}
try {
InputStream inputStream = this.storage.getInputStream();
Charset charset = getCharset();
return Mono.just(StreamUtils.copyToString(inputStream, charset));
}
catch (IOException e) {
return Mono.error(new IllegalStateException(
"Error while reading part content as a string", e));
}
}
private Charset getCharset() {
return Optional.ofNullable(this.headers.getContentType())
.map(MimeType::getCharset).orElse(StandardCharsets.UTF_8);
}
@Override
public Flux<DataBuffer> getContent() {
if (this.content != null) {
DataBuffer buffer = this.bufferFactory.allocateBuffer(this.content.length());
buffer.write(this.content.getBytes());
return Flux.just(buffer);
}
InputStream inputStream = this.storage.getInputStream();
return DataBufferUtils.read(inputStream, this.bufferFactory, 4096);
return DataBufferUtils.read(inputStream, getBufferFactory(), 4096);
}
protected StreamStorage getStorage() {
return this.storage;
}
}
private static class SynchronossFilePart extends DefaultSynchronossPart implements FilePart {
public SynchronossFilePart(HttpHeaders headers, StreamStorage storage, DataBufferFactory factory) {
super(headers, storage, factory);
}
@Override
public String getFilename() {
return MultipartUtils.getFileName(getHeaders());
}
@Override
public Mono<Void> transferTo(File destination) {
if (this.storage == null || !getFilename().isPresent()) {
return Mono.error(new IllegalStateException("The part does not represent a file."));
}
ReadableByteChannel input = null;
FileChannel output = null;
try {
input = Channels.newChannel(this.storage.getInputStream());
input = Channels.newChannel(getStorage().getInputStream());
output = new FileOutputStream(destination).getChannel();
long size = (input instanceof FileChannel ? ((FileChannel) input).size() : Long.MAX_VALUE);
@ -332,4 +329,34 @@ public class SynchronossPartHttpMessageReader implements HttpMessageReader<Part>
}
}
private static class SynchronossFormFieldPart extends AbstractSynchronossPart implements FormFieldPart {
private final String content;
SynchronossFormFieldPart(HttpHeaders headers, DataBufferFactory bufferFactory, String content) {
super(headers, bufferFactory);
this.content = content;
}
@Override
public String getValue() {
return this.content;
}
@Override
public Flux<DataBuffer> getContent() {
byte[] bytes = this.content.getBytes(getCharset());
DataBuffer buffer = getBufferFactory().allocateBuffer(bytes.length);
buffer.write(bytes);
return Flux.just(buffer);
}
private Charset getCharset() {
return Optional.ofNullable(MultipartUtils.getCharEncoding(getHeaders()))
.map(Charset::forName).orElse(StandardCharsets.UTF_8);
}
}
}

View File

@ -19,10 +19,12 @@ package org.springframework.web.bind.support;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.stream.Collectors;
import reactor.core.publisher.Mono;
import org.springframework.beans.MutablePropertyValues;
import org.springframework.http.codec.multipart.FormFieldPart;
import org.springframework.http.codec.multipart.Part;
import org.springframework.util.CollectionUtils;
import org.springframework.util.MultiValueMap;
@ -105,6 +107,9 @@ public class WebExchangeDataBinder extends WebDataBinder {
private static void addBindValue(Map<String, Object> params, String key, List<?> values) {
if (!CollectionUtils.isEmpty(values)) {
values = values.stream()
.map(value -> value instanceof FormFieldPart ? ((FormFieldPart) value).getValue() : value)
.collect(Collectors.toList());
params.put(key, values.size() == 1 ? values.get(0) : values);
}
}

View File

@ -39,7 +39,10 @@ import org.springframework.mock.http.server.reactive.test.MockServerHttpResponse
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import static org.junit.Assert.*;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
/**
* @author Sebastien Deleuze
@ -114,37 +117,39 @@ public class MultipartHttpMessageWriterTests {
assertEquals(5, requestParts.size());
Part part = requestParts.getFirst("name 1");
assertTrue(part instanceof FormFieldPart);
assertEquals("name 1", part.getName());
assertEquals("value 1", part.getContentAsString().block());
assertFalse(part.getFilename().isPresent());
assertEquals("value 1", ((FormFieldPart) part).getValue());
List<Part> part2 = requestParts.get("name 2");
assertEquals(2, part2.size());
part = part2.get(0);
List<Part> parts2 = requestParts.get("name 2");
assertEquals(2, parts2.size());
part = parts2.get(0);
assertTrue(part instanceof FormFieldPart);
assertEquals("name 2", part.getName());
assertEquals("value 2+1", part.getContentAsString().block());
part = part2.get(1);
assertEquals("value 2+1", ((FormFieldPart) part).getValue());
part = parts2.get(1);
assertTrue(part instanceof FormFieldPart);
assertEquals("name 2", part.getName());
assertEquals("value 2+2", part.getContentAsString().block());
assertEquals("value 2+2", ((FormFieldPart) part).getValue());
part = requestParts.getFirst("logo");
assertTrue(part instanceof FilePart);
assertEquals("logo", part.getName());
assertTrue(part.getFilename().isPresent());
assertEquals("logo.jpg", part.getFilename().get());
assertEquals("logo.jpg", ((FilePart) part).getFilename());
assertEquals(MediaType.IMAGE_JPEG, part.getHeaders().getContentType());
assertEquals(logo.getFile().length(), part.getHeaders().getContentLength());
part = requestParts.getFirst("utf8");
assertTrue(part instanceof FilePart);
assertEquals("utf8", part.getName());
assertTrue(part.getFilename().isPresent());
assertEquals("Hall\u00F6le.jpg", part.getFilename().get());
assertEquals("Hall\u00F6le.jpg", ((FilePart) part).getFilename());
assertEquals(MediaType.IMAGE_JPEG, part.getHeaders().getContentType());
assertEquals(utf8.getFile().length(), part.getHeaders().getContentLength());
part = requestParts.getFirst("json");
assertEquals("json", part.getName());
assertEquals(MediaType.APPLICATION_JSON_UTF8, part.getHeaders().getContentType());
assertEquals("{\"bar\":\"bar\"}", part.getContentAsString().block());
assertEquals("{\"bar\":\"bar\"}", ((FormFieldPart) part).getValue());
}

View File

@ -18,7 +18,6 @@ package org.springframework.http.codec.multipart;
import java.io.IOException;
import java.util.Map;
import java.util.Optional;
import org.junit.Test;
import reactor.core.publisher.Flux;
@ -88,10 +87,9 @@ public class SynchronossPartHttpMessageReaderTests {
assertTrue(parts.containsKey("fooPart"));
Part part = parts.getFirst("fooPart");
assertTrue(part instanceof FilePart);
assertEquals("fooPart", part.getName());
Optional<String> filename = part.getFilename();
assertTrue(filename.isPresent());
assertEquals("foo.txt", filename.get());
assertEquals("foo.txt", ((FilePart) part).getFilename());
DataBuffer buffer = part.getContent().reduce(DataBuffer::write).block();
assertEquals(12, buffer.readableByteCount());
byte[] byteContent = new byte[12];
@ -100,10 +98,9 @@ public class SynchronossPartHttpMessageReaderTests {
assertTrue(parts.containsKey("barPart"));
part = parts.getFirst("barPart");
assertTrue(part instanceof FormFieldPart);
assertEquals("barPart", part.getName());
filename = part.getFilename();
assertFalse(filename.isPresent());
assertEquals("bar", part.getContentAsString().block());
assertEquals("bar", ((FormFieldPart) part).getValue());
}
@Test

View File

@ -30,6 +30,8 @@ import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.RequestEntity;
import org.springframework.http.ResponseEntity;
import org.springframework.http.codec.multipart.FilePart;
import org.springframework.http.codec.multipart.FormFieldPart;
import org.springframework.http.codec.multipart.Part;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
@ -99,12 +101,11 @@ public class MultipartIntegrationTests extends AbstractHttpHandlerIntegrationTes
private void assertFooPart(Part part) {
assertEquals("fooPart", part.getName());
Optional<String> filename = part.getFilename();
assertTrue(filename.isPresent());
assertEquals("foo.txt", filename.get());
assertTrue(part instanceof FilePart);
assertEquals("foo.txt", ((FilePart) part).getFilename());
DataBuffer buffer = part
.getContent()
.reduce((s1, s2) -> s1.write(s2))
.reduce(DataBuffer::write)
.block();
assertEquals(12, buffer.readableByteCount());
byte[] byteContent = new byte[12];
@ -114,9 +115,8 @@ public class MultipartIntegrationTests extends AbstractHttpHandlerIntegrationTes
private void assertBarPart(Part part) {
assertEquals("barPart", part.getName());
Optional<String> filename = part.getFilename();
assertFalse(filename.isPresent());
assertEquals("bar", part.getContentAsString().block());
assertTrue(part instanceof FormFieldPart);
assertEquals("bar", ((FormFieldPart) part).getValue());
}
}

View File

@ -17,15 +17,22 @@
package org.springframework.web.bind.support;
import java.beans.PropertyEditorSupport;
import java.io.UnsupportedEncodingException;
import java.net.URLEncoder;
import java.time.Duration;
import java.util.Iterator;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import org.junit.Before;
import org.junit.Test;
import reactor.core.publisher.Mono;
import org.springframework.core.io.ClassPathResource;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.http.codec.FormHttpMessageWriter;
import org.springframework.http.codec.multipart.FilePart;
import org.springframework.http.codec.multipart.MultipartHttpMessageWriter;
import org.springframework.mock.http.client.reactive.test.MockClientHttpRequest;
import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest;
import org.springframework.tests.sample.beans.ITestBean;
import org.springframework.tests.sample.beans.TestBean;
@ -34,9 +41,12 @@ import org.springframework.util.MultiValueMap;
import org.springframework.web.server.ServerWebExchange;
import static junit.framework.TestCase.assertFalse;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import static org.springframework.core.ResolvableType.forClass;
import static org.springframework.core.ResolvableType.forClassWithGenerics;
/**
* Unit tests for {@link WebExchangeDataBinder}.
@ -177,39 +187,60 @@ public class WebExchangeDataBinderTests {
assertEquals("test", this.testBean.getSpouse().getName());
}
@Test
public void testMultipart() throws Exception {
private String generateForm(MultiValueMap<String, String> form) {
StringBuilder builder = new StringBuilder();
try {
for (Iterator<String> names = form.keySet().iterator(); names.hasNext();) {
String name = names.next();
for (Iterator<String> values = form.get(name).iterator(); values.hasNext();) {
String value = values.next();
builder.append(URLEncoder.encode(name, "UTF-8"));
if (value != null) {
builder.append('=');
builder.append(URLEncoder.encode(value, "UTF-8"));
if (values.hasNext()) {
builder.append('&');
}
}
}
if (names.hasNext()) {
builder.append('&');
}
}
}
catch (UnsupportedEncodingException ex) {
throw new IllegalStateException(ex);
}
return builder.toString();
MultipartBean bean = new MultipartBean();
WebExchangeDataBinder binder = new WebExchangeDataBinder(bean);
MultiValueMap<String, Object> data = new LinkedMultiValueMap<>();
data.add("name", "bar");
data.add("someList", "123");
data.add("someList", "abc");
data.add("someArray", "dec");
data.add("someArray", "456");
data.add("part", new ClassPathResource("org/springframework/http/codec/multipart/foo.txt"));
data.add("somePartList", new ClassPathResource("org/springframework/http/codec/multipart/foo.txt"));
data.add("somePartList", new ClassPathResource("org/springframework/http/server/reactive/spring.png"));
binder.bind(exchangeMultipart(data)).block(Duration.ofMillis(5000));
assertEquals("bar", bean.getName());
assertEquals(Arrays.asList("123", "abc"), bean.getSomeList());
assertArrayEquals(new String[] {"dec", "456"}, bean.getSomeArray());
assertEquals("foo.txt", bean.getPart().getFilename());
assertEquals(2, bean.getSomePartList().size());
assertEquals("foo.txt", bean.getSomePartList().get(0).getFilename());
assertEquals("spring.png", bean.getSomePartList().get(1).getFilename());
}
private ServerWebExchange exchange(MultiValueMap<String, String> formData) {
MockClientHttpRequest request = new MockClientHttpRequest(HttpMethod.POST, "/");
new FormHttpMessageWriter().write(Mono.just(formData),
forClassWithGenerics(MultiValueMap.class, String.class, String.class),
MediaType.APPLICATION_FORM_URLENCODED, request, Collections.emptyMap()).block();
return MockServerHttpRequest
.post("/")
.contentType(MediaType.APPLICATION_FORM_URLENCODED)
.body(generateForm(formData))
.body(request.getBody())
.toExchange();
}
private ServerWebExchange exchangeMultipart(MultiValueMap<String, ?> multipartData) {
MockClientHttpRequest request = new MockClientHttpRequest(HttpMethod.POST, "/");
new MultipartHttpMessageWriter().write(Mono.just(multipartData), forClass(MultiValueMap.class),
MediaType.MULTIPART_FORM_DATA, request, Collections.emptyMap()).block();
return MockServerHttpRequest
.post("/")
.contentType(request.getHeaders().getContentType())
.body(request.getBody())
.toExchange();
}
@ -222,4 +253,58 @@ public class WebExchangeDataBinderTests {
}
}
private static class MultipartBean {
private String name;
private List<?> someList;
private String[] someArray;
private FilePart part;
private List<FilePart> somePartList;
public String getName() {
return this.name;
}
public void setName(String name) {
this.name = name;
}
public List<?> getSomeList() {
return this.someList;
}
public void setSomeList(List<?> someList) {
this.someList = someList;
}
public String[] getSomeArray() {
return this.someArray;
}
public void setSomeArray(String[] someArray) {
this.someArray = someArray;
}
public FilePart getPart() {
return this.part;
}
public void setPart(FilePart part) {
this.part = part;
}
public List<FilePart> getSomePartList() {
return this.somePartList;
}
public void setSomePartList(List<FilePart> somePartList) {
this.somePartList = somePartList;
}
}
}

View File

@ -237,7 +237,7 @@ public class ModelAttributeMethodArgumentResolver extends HandlerMethodArgumentR
private boolean hasErrorsArgument(MethodParameter parameter) {
int i = parameter.getParameterIndex();
Class<?>[] paramTypes = parameter.getMethod().getParameterTypes();
return (paramTypes.length > i && Errors.class.isAssignableFrom(paramTypes[i + 1]));
return (paramTypes.length > i + 1 && Errors.class.isAssignableFrom(paramTypes[i + 1]));
}
private void validateIfApplicable(WebExchangeDataBinder binder, MethodParameter parameter) {

View File

@ -27,6 +27,8 @@ import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.codec.multipart.FilePart;
import org.springframework.http.codec.multipart.FormFieldPart;
import org.springframework.http.codec.multipart.Part;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
@ -38,7 +40,6 @@ import org.springframework.web.reactive.function.server.ServerRequest;
import org.springframework.web.reactive.function.server.ServerResponse;
import static org.junit.Assert.assertEquals;
import static org.springframework.web.reactive.function.server.RequestPredicates.POST;
import static org.springframework.web.reactive.function.server.RouterFunctions.route;
@ -57,9 +58,7 @@ public class MultipartIntegrationTests extends AbstractRouterFunctionIntegration
StepVerifier
.create(result)
.consumeNextWith(response -> {
assertEquals(HttpStatus.OK, response.statusCode());
})
.consumeNextWith(response -> assertEquals(HttpStatus.OK, response.statusCode()))
.verifyComplete();
}
@ -90,8 +89,8 @@ public class MultipartIntegrationTests extends AbstractRouterFunctionIntegration
Map<String, Part> parts = map.toSingleValueMap();
try {
assertEquals(2, parts.size());
assertEquals("foo.txt", parts.get("fooPart").getFilename().get());
assertEquals("bar", parts.get("barPart").getContentAsString().block());
assertEquals("foo.txt", ((FilePart) parts.get("fooPart")).getFilename());
assertEquals("bar", ((FormFieldPart) parts.get("barPart")).getValue());
}
catch(Exception e) {
return Mono.error(e);

View File

@ -33,11 +33,13 @@ import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.codec.multipart.FilePart;
import org.springframework.http.codec.multipart.Part;
import org.springframework.http.server.reactive.AbstractHttpHandlerIntegrationTests;
import org.springframework.http.server.reactive.HttpHandler;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.bind.annotation.ModelAttribute;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestPart;
@ -117,6 +119,21 @@ public class MultipartIntegrationTests extends AbstractHttpHandlerIntegrationTes
.verifyComplete();
}
@Test
public void modelAttribute() {
Mono<String> result = webClient
.post()
.uri("/modelAttribute")
.contentType(MediaType.MULTIPART_FORM_DATA)
.body(BodyInserters.fromMultipartData(generateBody()))
.retrieve()
.bodyToMono(String.class);
StepVerifier.create(result)
.consumeNextWith(body -> assertEquals("TestBean[barPart=bar,fooPart=foo.txt]", body))
.verifyComplete();
}
private MultiValueMap<String, Object> generateBody() {
HttpHeaders fooHeaders = new HttpHeaders();
@ -135,23 +152,58 @@ public class MultipartIntegrationTests extends AbstractHttpHandlerIntegrationTes
static class MultipartController {
@PostMapping("/requestPart")
void part(@RequestPart Part fooPart) {
assertEquals("foo.txt", fooPart.getFilename().get());
void requestPart(@RequestPart Part fooPart) {
assertEquals("foo.txt", ((FilePart) fooPart).getFilename());
}
@PostMapping("/requestBodyMap")
Mono<String> part(@RequestBody Mono<MultiValueMap<String, Part>> parts) {
Mono<String> requestBodyMap(@RequestBody Mono<MultiValueMap<String, Part>> parts) {
return parts.map(map -> map.toSingleValueMap().entrySet().stream()
.map(Map.Entry::getKey).sorted().collect(Collectors.joining(",", "Map[", "]")));
}
@PostMapping("/requestBodyFlux")
Mono<String> part(@RequestBody Flux<Part> parts) {
Mono<String> requestBodyFlux(@RequestBody Flux<Part> parts) {
return parts.map(Part::getName).collectList()
.map(names -> names.stream().sorted().collect(Collectors.joining(",", "Flux[", "]")));
}
@PostMapping("/modelAttribute")
String modelAttribute(@ModelAttribute TestBean testBean) {
return testBean.toString();
}
}
static class TestBean {
private String barPart;
private FilePart fooPart;
public String getBarPart() {
return this.barPart;
}
public void setBarPart(String barPart) {
this.barPart = barPart;
}
public FilePart getFooPart() {
return this.fooPart;
}
public void setFooPart(FilePart fooPart) {
this.fooPart = fooPart;
}
@Override
public String toString() {
return "TestBean[barPart=" + getBarPart() + ",fooPart=" + getFooPart().getFilename() + "]";
}
}
@Configuration
@EnableWebFlux
@SuppressWarnings("unused")