From fc7bededd02410ce3ed0c828e82b44c170e765ce Mon Sep 17 00:00:00 2001 From: Rossen Stoyanchev Date: Wed, 3 May 2017 18:46:00 -0400 Subject: [PATCH] Support data binding for multipart requests in WebFlux Issue: SPR-14546 --- .../http/codec/multipart/FilePart.java | 45 ++++++ .../http/codec/multipart/FormFieldPart.java | 32 ++++ .../http/codec/multipart/Part.java | 32 +--- .../SynchronossPartHttpMessageReader.java | 145 +++++++++++------- .../bind/support/WebExchangeDataBinder.java | 5 + .../MultipartHttpMessageWriterTests.java | 33 ++-- ...SynchronossPartHttpMessageReaderTests.java | 11 +- .../reactive/MultipartIntegrationTests.java | 14 +- .../support/WebExchangeDataBinderTests.java | 143 +++++++++++++---- .../ModelAttributeMethodArgumentResolver.java | 2 +- .../function/MultipartIntegrationTests.java | 11 +- .../annotation/MultipartIntegrationTests.java | 60 +++++++- 12 files changed, 378 insertions(+), 155 deletions(-) create mode 100644 spring-web/src/main/java/org/springframework/http/codec/multipart/FilePart.java create mode 100644 spring-web/src/main/java/org/springframework/http/codec/multipart/FormFieldPart.java diff --git a/spring-web/src/main/java/org/springframework/http/codec/multipart/FilePart.java b/spring-web/src/main/java/org/springframework/http/codec/multipart/FilePart.java new file mode 100644 index 0000000000..8f39bfbddd --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/codec/multipart/FilePart.java @@ -0,0 +1,45 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.codec.multipart; + +import java.io.File; + +import reactor.core.publisher.Mono; + +/** + * Specialization of {@link Part} for a file upload. + * + * @author Rossen Stoyanchev + * @since 5.0 + */ +public interface FilePart extends Part { + + /** + * Return the name of the file selected by the user in a browser form. + */ + String getFilename(); + + + /** + * Transfer the file in this part to the given file destination. + * @param dest the target file + * @return completion {@code Mono} with the result of the file transfer, + * possibly {@link IllegalStateException} if the part isn't a file + */ + Mono transferTo(File dest); + +} diff --git a/spring-web/src/main/java/org/springframework/http/codec/multipart/FormFieldPart.java b/spring-web/src/main/java/org/springframework/http/codec/multipart/FormFieldPart.java new file mode 100644 index 0000000000..562ea6b5bf --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/codec/multipart/FormFieldPart.java @@ -0,0 +1,32 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.codec.multipart; + +/** + * Specialization of {@link Part} for a form field. + * + * @author Rossen Stoyanchev + * @since 5.0 + */ +public interface FormFieldPart extends Part { + + /** + * Return the form field value. + */ + String getValue(); + +} diff --git a/spring-web/src/main/java/org/springframework/http/codec/multipart/Part.java b/spring-web/src/main/java/org/springframework/http/codec/multipart/Part.java index b47bc02206..a9a80fa476 100644 --- a/spring-web/src/main/java/org/springframework/http/codec/multipart/Part.java +++ b/spring-web/src/main/java/org/springframework/http/codec/multipart/Part.java @@ -16,11 +16,7 @@ package org.springframework.http.codec.multipart; -import java.io.File; -import java.util.Optional; - import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; import org.springframework.core.io.buffer.DataBuffer; import org.springframework.http.HttpHeaders; @@ -29,9 +25,10 @@ import org.springframework.http.HttpHeaders; * Representation for a part in a "multipart/form-data" request. * *

The origin of a multipart request may a browser form in which case each - * part represents a text-based form field or a file upload. Multipart requests - * may also be used outside of browsers to transfer data with any content type - * such as JSON, PDF, etc. + * part is either a {@link FormFieldPart} or a {@link FilePart}. + * + *

Multipart requests may also be used outside of a browser for data of any + * content type (e.g. JSON, PDF, etc). * * @author Sebastien Deleuze * @author Rossen Stoyanchev @@ -53,30 +50,9 @@ public interface Part { */ HttpHeaders getHeaders(); - /** - * - * Return the name of the file selected by the user in a browser form. - * @return the filename if defined and available - */ - Optional getFilename(); - - /** - * Return the part content converted to a String with the charset from the - * {@code Content-Type} header or {@code UTF-8} by default. - */ - Mono getContentAsString(); - /** * Return the part raw content as a stream of DataBuffer's. */ Flux getContent(); - /** - * Transfer the file in this part to the given file destination. - * @param destination the target file - * @return completion {@code Mono} with the result of the file transfer, - * possibly {@link IllegalStateException} if the part isn't a file - */ - Mono transferTo(File destination); - } diff --git a/spring-web/src/main/java/org/springframework/http/codec/multipart/SynchronossPartHttpMessageReader.java b/spring-web/src/main/java/org/springframework/http/codec/multipart/SynchronossPartHttpMessageReader.java index 07a144d25c..180404ceb3 100644 --- a/spring-web/src/main/java/org/springframework/http/codec/multipart/SynchronossPartHttpMessageReader.java +++ b/spring-web/src/main/java/org/springframework/http/codec/multipart/SynchronossPartHttpMessageReader.java @@ -52,8 +52,6 @@ import org.springframework.http.MediaType; import org.springframework.http.ReactiveHttpInputMessage; import org.springframework.http.codec.HttpMessageReader; import org.springframework.util.Assert; -import org.springframework.util.MimeType; -import org.springframework.util.StreamUtils; /** * {@code HttpMessageReader} for parsing {@code "multipart/form-data"} requests @@ -71,6 +69,8 @@ import org.springframework.util.StreamUtils; */ public class SynchronossPartHttpMessageReader implements HttpMessageReader { + private final DataBufferFactory bufferFactory = new DefaultDataBufferFactory(); + @Override public List getReadableMediaTypes() { @@ -88,7 +88,7 @@ public class SynchronossPartHttpMessageReader implements HttpMessageReader public Flux read(ResolvableType elementType, ReactiveHttpInputMessage message, Map hints) { - return Flux.create(new SynchronossPartGenerator(message)); + return Flux.create(new SynchronossPartGenerator(message, this.bufferFactory)); } @@ -109,9 +109,12 @@ public class SynchronossPartHttpMessageReader implements HttpMessageReader private final ReactiveHttpInputMessage inputMessage; + private final DataBufferFactory bufferFactory; - SynchronossPartGenerator(ReactiveHttpInputMessage inputMessage) { + + SynchronossPartGenerator(ReactiveHttpInputMessage inputMessage, DataBufferFactory factory) { this.inputMessage = inputMessage; + this.bufferFactory = factory; } @@ -119,7 +122,7 @@ public class SynchronossPartHttpMessageReader implements HttpMessageReader public void accept(FluxSink emitter) { MultipartContext context = createMultipartContext(); - NioMultipartParserListener listener = new FluxSinkAdapterListener(emitter); + NioMultipartParserListener listener = new FluxSinkAdapterListener(emitter, this.bufferFactory); NioMultipartParser parser = Multipart.multipart(context).forNIO(listener); this.inputMessage.getBody().subscribe(buffer -> { @@ -167,11 +170,14 @@ public class SynchronossPartHttpMessageReader implements HttpMessageReader private final FluxSink sink; + private final DataBufferFactory bufferFactory; + private final AtomicInteger terminated = new AtomicInteger(0); - FluxSinkAdapterListener(FluxSink sink) { + FluxSinkAdapterListener(FluxSink sink, DataBufferFactory bufferFactory) { this.sink = sink; + this.bufferFactory = bufferFactory; } @@ -179,14 +185,17 @@ public class SynchronossPartHttpMessageReader implements HttpMessageReader public void onPartFinished(StreamStorage storage, Map> headers) { HttpHeaders httpHeaders = new HttpHeaders(); httpHeaders.putAll(headers); - this.sink.next(new SynchronossPart(httpHeaders, storage)); + Part part = MultipartUtils.getFileName(httpHeaders) != null ? + new SynchronossFilePart(httpHeaders, storage, this.bufferFactory) : + new DefaultSynchronossPart(httpHeaders, storage, this.bufferFactory); + this.sink.next(part); } @Override public void onFormFieldPartFinished(String name, String value, Map> headers) { HttpHeaders httpHeaders = new HttpHeaders(); httpHeaders.putAll(headers); - this.sink.next(new SynchronossPart(httpHeaders, value)); + this.sink.next(new SynchronossFormFieldPart(httpHeaders, this.bufferFactory, value)); } @Override @@ -213,31 +222,18 @@ public class SynchronossPartHttpMessageReader implements HttpMessageReader } - private static class SynchronossPart implements Part { + private static abstract class AbstractSynchronossPart implements Part { private final HttpHeaders headers; - private final StreamStorage storage; - - private final String content; - - private final DataBufferFactory bufferFactory = new DefaultDataBufferFactory(); + private final DataBufferFactory bufferFactory; - SynchronossPart(HttpHeaders headers, StreamStorage storage) { + AbstractSynchronossPart(HttpHeaders headers, DataBufferFactory bufferFactory) { Assert.notNull(headers, "HttpHeaders is required"); - Assert.notNull(storage, "'storage' is required"); + Assert.notNull(bufferFactory, "'bufferFactory' is required"); this.headers = headers; - this.storage = storage; - this.content = null; - } - - SynchronossPart(HttpHeaders headers, String content) { - Assert.notNull(headers, "HttpHeaders is required"); - Assert.notNull(content, "'content' is required"); - this.headers = headers; - this.storage = null; - this.content = content; + this.bufferFactory = bufferFactory; } @@ -251,52 +247,53 @@ public class SynchronossPartHttpMessageReader implements HttpMessageReader return this.headers; } - @Override - public Optional getFilename() { - return Optional.ofNullable(MultipartUtils.getFileName(this.headers)); + protected DataBufferFactory getBufferFactory() { + return this.bufferFactory; + } + } + + private static class DefaultSynchronossPart extends AbstractSynchronossPart { + + private final StreamStorage storage; + + + DefaultSynchronossPart(HttpHeaders headers, StreamStorage storage, DataBufferFactory factory) { + super(headers, factory); + Assert.notNull(storage, "'storage' is required"); + this.storage = storage; } - @Override - public Mono getContentAsString() { - if (this.content != null) { - return Mono.just(this.content); - } - try { - InputStream inputStream = this.storage.getInputStream(); - Charset charset = getCharset(); - return Mono.just(StreamUtils.copyToString(inputStream, charset)); - } - catch (IOException e) { - return Mono.error(new IllegalStateException( - "Error while reading part content as a string", e)); - } - } - - private Charset getCharset() { - return Optional.ofNullable(this.headers.getContentType()) - .map(MimeType::getCharset).orElse(StandardCharsets.UTF_8); - } @Override public Flux getContent() { - if (this.content != null) { - DataBuffer buffer = this.bufferFactory.allocateBuffer(this.content.length()); - buffer.write(this.content.getBytes()); - return Flux.just(buffer); - } InputStream inputStream = this.storage.getInputStream(); - return DataBufferUtils.read(inputStream, this.bufferFactory, 4096); + return DataBufferUtils.read(inputStream, getBufferFactory(), 4096); + } + + protected StreamStorage getStorage() { + return this.storage; + } + } + + private static class SynchronossFilePart extends DefaultSynchronossPart implements FilePart { + + + public SynchronossFilePart(HttpHeaders headers, StreamStorage storage, DataBufferFactory factory) { + super(headers, storage, factory); + } + + + @Override + public String getFilename() { + return MultipartUtils.getFileName(getHeaders()); } @Override public Mono transferTo(File destination) { - if (this.storage == null || !getFilename().isPresent()) { - return Mono.error(new IllegalStateException("The part does not represent a file.")); - } ReadableByteChannel input = null; FileChannel output = null; try { - input = Channels.newChannel(this.storage.getInputStream()); + input = Channels.newChannel(getStorage().getInputStream()); output = new FileOutputStream(destination).getChannel(); long size = (input instanceof FileChannel ? ((FileChannel) input).size() : Long.MAX_VALUE); @@ -332,4 +329,34 @@ public class SynchronossPartHttpMessageReader implements HttpMessageReader } } + private static class SynchronossFormFieldPart extends AbstractSynchronossPart implements FormFieldPart { + + private final String content; + + + SynchronossFormFieldPart(HttpHeaders headers, DataBufferFactory bufferFactory, String content) { + super(headers, bufferFactory); + this.content = content; + } + + + @Override + public String getValue() { + return this.content; + } + + @Override + public Flux getContent() { + byte[] bytes = this.content.getBytes(getCharset()); + DataBuffer buffer = getBufferFactory().allocateBuffer(bytes.length); + buffer.write(bytes); + return Flux.just(buffer); + } + + private Charset getCharset() { + return Optional.ofNullable(MultipartUtils.getCharEncoding(getHeaders())) + .map(Charset::forName).orElse(StandardCharsets.UTF_8); + } + } + } diff --git a/spring-web/src/main/java/org/springframework/web/bind/support/WebExchangeDataBinder.java b/spring-web/src/main/java/org/springframework/web/bind/support/WebExchangeDataBinder.java index ebe13fe832..4ba2e0f72b 100644 --- a/spring-web/src/main/java/org/springframework/web/bind/support/WebExchangeDataBinder.java +++ b/spring-web/src/main/java/org/springframework/web/bind/support/WebExchangeDataBinder.java @@ -19,10 +19,12 @@ package org.springframework.web.bind.support; import java.util.List; import java.util.Map; import java.util.TreeMap; +import java.util.stream.Collectors; import reactor.core.publisher.Mono; import org.springframework.beans.MutablePropertyValues; +import org.springframework.http.codec.multipart.FormFieldPart; import org.springframework.http.codec.multipart.Part; import org.springframework.util.CollectionUtils; import org.springframework.util.MultiValueMap; @@ -105,6 +107,9 @@ public class WebExchangeDataBinder extends WebDataBinder { private static void addBindValue(Map params, String key, List values) { if (!CollectionUtils.isEmpty(values)) { + values = values.stream() + .map(value -> value instanceof FormFieldPart ? ((FormFieldPart) value).getValue() : value) + .collect(Collectors.toList()); params.put(key, values.size() == 1 ? values.get(0) : values); } } diff --git a/spring-web/src/test/java/org/springframework/http/codec/multipart/MultipartHttpMessageWriterTests.java b/spring-web/src/test/java/org/springframework/http/codec/multipart/MultipartHttpMessageWriterTests.java index ab6aa211f4..56b9338662 100644 --- a/spring-web/src/test/java/org/springframework/http/codec/multipart/MultipartHttpMessageWriterTests.java +++ b/spring-web/src/test/java/org/springframework/http/codec/multipart/MultipartHttpMessageWriterTests.java @@ -39,7 +39,10 @@ import org.springframework.mock.http.server.reactive.test.MockServerHttpResponse import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; -import static org.junit.Assert.*; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; /** * @author Sebastien Deleuze @@ -114,37 +117,39 @@ public class MultipartHttpMessageWriterTests { assertEquals(5, requestParts.size()); Part part = requestParts.getFirst("name 1"); + assertTrue(part instanceof FormFieldPart); assertEquals("name 1", part.getName()); - assertEquals("value 1", part.getContentAsString().block()); - assertFalse(part.getFilename().isPresent()); + assertEquals("value 1", ((FormFieldPart) part).getValue()); - List part2 = requestParts.get("name 2"); - assertEquals(2, part2.size()); - part = part2.get(0); + List parts2 = requestParts.get("name 2"); + assertEquals(2, parts2.size()); + part = parts2.get(0); + assertTrue(part instanceof FormFieldPart); assertEquals("name 2", part.getName()); - assertEquals("value 2+1", part.getContentAsString().block()); - part = part2.get(1); + assertEquals("value 2+1", ((FormFieldPart) part).getValue()); + part = parts2.get(1); + assertTrue(part instanceof FormFieldPart); assertEquals("name 2", part.getName()); - assertEquals("value 2+2", part.getContentAsString().block()); + assertEquals("value 2+2", ((FormFieldPart) part).getValue()); part = requestParts.getFirst("logo"); + assertTrue(part instanceof FilePart); assertEquals("logo", part.getName()); - assertTrue(part.getFilename().isPresent()); - assertEquals("logo.jpg", part.getFilename().get()); + assertEquals("logo.jpg", ((FilePart) part).getFilename()); assertEquals(MediaType.IMAGE_JPEG, part.getHeaders().getContentType()); assertEquals(logo.getFile().length(), part.getHeaders().getContentLength()); part = requestParts.getFirst("utf8"); + assertTrue(part instanceof FilePart); assertEquals("utf8", part.getName()); - assertTrue(part.getFilename().isPresent()); - assertEquals("Hall\u00F6le.jpg", part.getFilename().get()); + assertEquals("Hall\u00F6le.jpg", ((FilePart) part).getFilename()); assertEquals(MediaType.IMAGE_JPEG, part.getHeaders().getContentType()); assertEquals(utf8.getFile().length(), part.getHeaders().getContentLength()); part = requestParts.getFirst("json"); assertEquals("json", part.getName()); assertEquals(MediaType.APPLICATION_JSON_UTF8, part.getHeaders().getContentType()); - assertEquals("{\"bar\":\"bar\"}", part.getContentAsString().block()); + assertEquals("{\"bar\":\"bar\"}", ((FormFieldPart) part).getValue()); } diff --git a/spring-web/src/test/java/org/springframework/http/codec/multipart/SynchronossPartHttpMessageReaderTests.java b/spring-web/src/test/java/org/springframework/http/codec/multipart/SynchronossPartHttpMessageReaderTests.java index 7b5e713748..12c6fd891d 100644 --- a/spring-web/src/test/java/org/springframework/http/codec/multipart/SynchronossPartHttpMessageReaderTests.java +++ b/spring-web/src/test/java/org/springframework/http/codec/multipart/SynchronossPartHttpMessageReaderTests.java @@ -18,7 +18,6 @@ package org.springframework.http.codec.multipart; import java.io.IOException; import java.util.Map; -import java.util.Optional; import org.junit.Test; import reactor.core.publisher.Flux; @@ -88,10 +87,9 @@ public class SynchronossPartHttpMessageReaderTests { assertTrue(parts.containsKey("fooPart")); Part part = parts.getFirst("fooPart"); + assertTrue(part instanceof FilePart); assertEquals("fooPart", part.getName()); - Optional filename = part.getFilename(); - assertTrue(filename.isPresent()); - assertEquals("foo.txt", filename.get()); + assertEquals("foo.txt", ((FilePart) part).getFilename()); DataBuffer buffer = part.getContent().reduce(DataBuffer::write).block(); assertEquals(12, buffer.readableByteCount()); byte[] byteContent = new byte[12]; @@ -100,10 +98,9 @@ public class SynchronossPartHttpMessageReaderTests { assertTrue(parts.containsKey("barPart")); part = parts.getFirst("barPart"); + assertTrue(part instanceof FormFieldPart); assertEquals("barPart", part.getName()); - filename = part.getFilename(); - assertFalse(filename.isPresent()); - assertEquals("bar", part.getContentAsString().block()); + assertEquals("bar", ((FormFieldPart) part).getValue()); } @Test diff --git a/spring-web/src/test/java/org/springframework/http/server/reactive/MultipartIntegrationTests.java b/spring-web/src/test/java/org/springframework/http/server/reactive/MultipartIntegrationTests.java index ed06e67c63..d0a2ff518f 100644 --- a/spring-web/src/test/java/org/springframework/http/server/reactive/MultipartIntegrationTests.java +++ b/spring-web/src/test/java/org/springframework/http/server/reactive/MultipartIntegrationTests.java @@ -30,6 +30,8 @@ import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; import org.springframework.http.RequestEntity; import org.springframework.http.ResponseEntity; +import org.springframework.http.codec.multipart.FilePart; +import org.springframework.http.codec.multipart.FormFieldPart; import org.springframework.http.codec.multipart.Part; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; @@ -99,12 +101,11 @@ public class MultipartIntegrationTests extends AbstractHttpHandlerIntegrationTes private void assertFooPart(Part part) { assertEquals("fooPart", part.getName()); - Optional filename = part.getFilename(); - assertTrue(filename.isPresent()); - assertEquals("foo.txt", filename.get()); + assertTrue(part instanceof FilePart); + assertEquals("foo.txt", ((FilePart) part).getFilename()); DataBuffer buffer = part .getContent() - .reduce((s1, s2) -> s1.write(s2)) + .reduce(DataBuffer::write) .block(); assertEquals(12, buffer.readableByteCount()); byte[] byteContent = new byte[12]; @@ -114,9 +115,8 @@ public class MultipartIntegrationTests extends AbstractHttpHandlerIntegrationTes private void assertBarPart(Part part) { assertEquals("barPart", part.getName()); - Optional filename = part.getFilename(); - assertFalse(filename.isPresent()); - assertEquals("bar", part.getContentAsString().block()); + assertTrue(part instanceof FormFieldPart); + assertEquals("bar", ((FormFieldPart) part).getValue()); } } diff --git a/spring-web/src/test/java/org/springframework/web/bind/support/WebExchangeDataBinderTests.java b/spring-web/src/test/java/org/springframework/web/bind/support/WebExchangeDataBinderTests.java index 45fd82a5ed..cc85cf430e 100644 --- a/spring-web/src/test/java/org/springframework/web/bind/support/WebExchangeDataBinderTests.java +++ b/spring-web/src/test/java/org/springframework/web/bind/support/WebExchangeDataBinderTests.java @@ -17,15 +17,22 @@ package org.springframework.web.bind.support; import java.beans.PropertyEditorSupport; -import java.io.UnsupportedEncodingException; -import java.net.URLEncoder; import java.time.Duration; -import java.util.Iterator; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; import org.junit.Before; import org.junit.Test; +import reactor.core.publisher.Mono; +import org.springframework.core.io.ClassPathResource; +import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; +import org.springframework.http.codec.FormHttpMessageWriter; +import org.springframework.http.codec.multipart.FilePart; +import org.springframework.http.codec.multipart.MultipartHttpMessageWriter; +import org.springframework.mock.http.client.reactive.test.MockClientHttpRequest; import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest; import org.springframework.tests.sample.beans.ITestBean; import org.springframework.tests.sample.beans.TestBean; @@ -34,9 +41,12 @@ import org.springframework.util.MultiValueMap; import org.springframework.web.server.ServerWebExchange; import static junit.framework.TestCase.assertFalse; +import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; +import static org.springframework.core.ResolvableType.forClass; +import static org.springframework.core.ResolvableType.forClassWithGenerics; /** * Unit tests for {@link WebExchangeDataBinder}. @@ -177,39 +187,60 @@ public class WebExchangeDataBinderTests { assertEquals("test", this.testBean.getSpouse().getName()); } + @Test + public void testMultipart() throws Exception { - private String generateForm(MultiValueMap form) { - StringBuilder builder = new StringBuilder(); - try { - for (Iterator names = form.keySet().iterator(); names.hasNext();) { - String name = names.next(); - for (Iterator values = form.get(name).iterator(); values.hasNext();) { - String value = values.next(); - builder.append(URLEncoder.encode(name, "UTF-8")); - if (value != null) { - builder.append('='); - builder.append(URLEncoder.encode(value, "UTF-8")); - if (values.hasNext()) { - builder.append('&'); - } - } - } - if (names.hasNext()) { - builder.append('&'); - } - } - } - catch (UnsupportedEncodingException ex) { - throw new IllegalStateException(ex); - } - return builder.toString(); + MultipartBean bean = new MultipartBean(); + WebExchangeDataBinder binder = new WebExchangeDataBinder(bean); + + MultiValueMap data = new LinkedMultiValueMap<>(); + data.add("name", "bar"); + data.add("someList", "123"); + data.add("someList", "abc"); + data.add("someArray", "dec"); + data.add("someArray", "456"); + data.add("part", new ClassPathResource("org/springframework/http/codec/multipart/foo.txt")); + data.add("somePartList", new ClassPathResource("org/springframework/http/codec/multipart/foo.txt")); + data.add("somePartList", new ClassPathResource("org/springframework/http/server/reactive/spring.png")); + binder.bind(exchangeMultipart(data)).block(Duration.ofMillis(5000)); + + assertEquals("bar", bean.getName()); + assertEquals(Arrays.asList("123", "abc"), bean.getSomeList()); + assertArrayEquals(new String[] {"dec", "456"}, bean.getSomeArray()); + assertEquals("foo.txt", bean.getPart().getFilename()); + assertEquals(2, bean.getSomePartList().size()); + assertEquals("foo.txt", bean.getSomePartList().get(0).getFilename()); + assertEquals("spring.png", bean.getSomePartList().get(1).getFilename()); } + + private ServerWebExchange exchange(MultiValueMap formData) { + + MockClientHttpRequest request = new MockClientHttpRequest(HttpMethod.POST, "/"); + + new FormHttpMessageWriter().write(Mono.just(formData), + forClassWithGenerics(MultiValueMap.class, String.class, String.class), + MediaType.APPLICATION_FORM_URLENCODED, request, Collections.emptyMap()).block(); + return MockServerHttpRequest .post("/") .contentType(MediaType.APPLICATION_FORM_URLENCODED) - .body(generateForm(formData)) + .body(request.getBody()) + .toExchange(); + } + + private ServerWebExchange exchangeMultipart(MultiValueMap multipartData) { + + MockClientHttpRequest request = new MockClientHttpRequest(HttpMethod.POST, "/"); + + new MultipartHttpMessageWriter().write(Mono.just(multipartData), forClass(MultiValueMap.class), + MediaType.MULTIPART_FORM_DATA, request, Collections.emptyMap()).block(); + + return MockServerHttpRequest + .post("/") + .contentType(request.getHeaders().getContentType()) + .body(request.getBody()) .toExchange(); } @@ -222,4 +253,58 @@ public class WebExchangeDataBinderTests { } } + private static class MultipartBean { + + private String name; + + private List someList; + + private String[] someArray; + + private FilePart part; + + private List somePartList; + + + public String getName() { + return this.name; + } + + public void setName(String name) { + this.name = name; + } + + public List getSomeList() { + return this.someList; + } + + public void setSomeList(List someList) { + this.someList = someList; + } + + public String[] getSomeArray() { + return this.someArray; + } + + public void setSomeArray(String[] someArray) { + this.someArray = someArray; + } + + public FilePart getPart() { + return this.part; + } + + public void setPart(FilePart part) { + this.part = part; + } + + public List getSomePartList() { + return this.somePartList; + } + + public void setSomePartList(List somePartList) { + this.somePartList = somePartList; + } + } + } diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/ModelAttributeMethodArgumentResolver.java b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/ModelAttributeMethodArgumentResolver.java index f880e27780..17696c60c5 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/ModelAttributeMethodArgumentResolver.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/ModelAttributeMethodArgumentResolver.java @@ -237,7 +237,7 @@ public class ModelAttributeMethodArgumentResolver extends HandlerMethodArgumentR private boolean hasErrorsArgument(MethodParameter parameter) { int i = parameter.getParameterIndex(); Class[] paramTypes = parameter.getMethod().getParameterTypes(); - return (paramTypes.length > i && Errors.class.isAssignableFrom(paramTypes[i + 1])); + return (paramTypes.length > i + 1 && Errors.class.isAssignableFrom(paramTypes[i + 1])); } private void validateIfApplicable(WebExchangeDataBinder binder, MethodParameter parameter) { diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/function/MultipartIntegrationTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/function/MultipartIntegrationTests.java index a42acfc052..624239f931 100644 --- a/spring-webflux/src/test/java/org/springframework/web/reactive/function/MultipartIntegrationTests.java +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/function/MultipartIntegrationTests.java @@ -27,6 +27,8 @@ import org.springframework.http.HttpEntity; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; +import org.springframework.http.codec.multipart.FilePart; +import org.springframework.http.codec.multipart.FormFieldPart; import org.springframework.http.codec.multipart.Part; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; @@ -38,7 +40,6 @@ import org.springframework.web.reactive.function.server.ServerRequest; import org.springframework.web.reactive.function.server.ServerResponse; import static org.junit.Assert.assertEquals; - import static org.springframework.web.reactive.function.server.RequestPredicates.POST; import static org.springframework.web.reactive.function.server.RouterFunctions.route; @@ -57,9 +58,7 @@ public class MultipartIntegrationTests extends AbstractRouterFunctionIntegration StepVerifier .create(result) - .consumeNextWith(response -> { - assertEquals(HttpStatus.OK, response.statusCode()); - }) + .consumeNextWith(response -> assertEquals(HttpStatus.OK, response.statusCode())) .verifyComplete(); } @@ -90,8 +89,8 @@ public class MultipartIntegrationTests extends AbstractRouterFunctionIntegration Map parts = map.toSingleValueMap(); try { assertEquals(2, parts.size()); - assertEquals("foo.txt", parts.get("fooPart").getFilename().get()); - assertEquals("bar", parts.get("barPart").getContentAsString().block()); + assertEquals("foo.txt", ((FilePart) parts.get("fooPart")).getFilename()); + assertEquals("bar", ((FormFieldPart) parts.get("barPart")).getValue()); } catch(Exception e) { return Mono.error(e); diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/MultipartIntegrationTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/MultipartIntegrationTests.java index 07924c3a00..c2d026c510 100644 --- a/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/MultipartIntegrationTests.java +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/MultipartIntegrationTests.java @@ -33,11 +33,13 @@ import org.springframework.http.HttpEntity; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; +import org.springframework.http.codec.multipart.FilePart; import org.springframework.http.codec.multipart.Part; import org.springframework.http.server.reactive.AbstractHttpHandlerIntegrationTests; import org.springframework.http.server.reactive.HttpHandler; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; +import org.springframework.web.bind.annotation.ModelAttribute; import org.springframework.web.bind.annotation.PostMapping; import org.springframework.web.bind.annotation.RequestBody; import org.springframework.web.bind.annotation.RequestPart; @@ -117,6 +119,21 @@ public class MultipartIntegrationTests extends AbstractHttpHandlerIntegrationTes .verifyComplete(); } + @Test + public void modelAttribute() { + Mono result = webClient + .post() + .uri("/modelAttribute") + .contentType(MediaType.MULTIPART_FORM_DATA) + .body(BodyInserters.fromMultipartData(generateBody())) + .retrieve() + .bodyToMono(String.class); + + StepVerifier.create(result) + .consumeNextWith(body -> assertEquals("TestBean[barPart=bar,fooPart=foo.txt]", body)) + .verifyComplete(); + } + private MultiValueMap generateBody() { HttpHeaders fooHeaders = new HttpHeaders(); @@ -135,23 +152,58 @@ public class MultipartIntegrationTests extends AbstractHttpHandlerIntegrationTes static class MultipartController { @PostMapping("/requestPart") - void part(@RequestPart Part fooPart) { - assertEquals("foo.txt", fooPart.getFilename().get()); + void requestPart(@RequestPart Part fooPart) { + assertEquals("foo.txt", ((FilePart) fooPart).getFilename()); } @PostMapping("/requestBodyMap") - Mono part(@RequestBody Mono> parts) { + Mono requestBodyMap(@RequestBody Mono> parts) { return parts.map(map -> map.toSingleValueMap().entrySet().stream() .map(Map.Entry::getKey).sorted().collect(Collectors.joining(",", "Map[", "]"))); } @PostMapping("/requestBodyFlux") - Mono part(@RequestBody Flux parts) { + Mono requestBodyFlux(@RequestBody Flux parts) { return parts.map(Part::getName).collectList() .map(names -> names.stream().sorted().collect(Collectors.joining(",", "Flux[", "]"))); } + + @PostMapping("/modelAttribute") + String modelAttribute(@ModelAttribute TestBean testBean) { + return testBean.toString(); + } } + static class TestBean { + + private String barPart; + + private FilePart fooPart; + + + public String getBarPart() { + return this.barPart; + } + + public void setBarPart(String barPart) { + this.barPart = barPart; + } + + public FilePart getFooPart() { + return this.fooPart; + } + + public void setFooPart(FilePart fooPart) { + this.fooPart = fooPart; + } + + @Override + public String toString() { + return "TestBean[barPart=" + getBarPart() + ",fooPart=" + getFooPart().getFilename() + "]"; + } + } + + @Configuration @EnableWebFlux @SuppressWarnings("unused")