diff --git a/build.gradle b/build.gradle index 207e2d37023..a997ee1705b 100644 --- a/build.gradle +++ b/build.gradle @@ -76,7 +76,7 @@ configure(allprojects) { project -> ext.junitPlatformVersion = '1.0.0-M4' ext.log4jVersion = '2.8.2' ext.nettyVersion = "4.1.11.Final" - ext.niomultipartVersion = "1.0.2" + ext.niomultipartVersion = "1.1.0" ext.okhttp3Version = "3.8.0" ext.poiVersion = "3.16" ext.protobufVersion = "3.3.1" diff --git a/spring-web/src/main/java/org/springframework/http/codec/multipart/SynchronossPartHttpMessageReader.java b/spring-web/src/main/java/org/springframework/http/codec/multipart/SynchronossPartHttpMessageReader.java index a2713254ab5..0d072a18c20 100644 --- a/spring-web/src/main/java/org/springframework/http/codec/multipart/SynchronossPartHttpMessageReader.java +++ b/spring-web/src/main/java/org/springframework/http/codec/multipart/SynchronossPartHttpMessageReader.java @@ -127,7 +127,7 @@ public class SynchronossPartHttpMessageReader implements HttpMessageReader Charset charset = Optional.ofNullable(mediaType.getCharset()).orElse(StandardCharsets.UTF_8); MultipartContext context = new MultipartContext(mediaType.toString(), length, charset.name()); - NioMultipartParserListener listener = new FluxSinkAdapterListener(emitter, this.bufferFactory); + NioMultipartParserListener listener = new FluxSinkAdapterListener(emitter, this.bufferFactory, context); NioMultipartParser parser = Multipart.multipart(context).forNIO(listener); this.inputMessage.getBody().subscribe(buffer -> { @@ -167,12 +167,15 @@ public class SynchronossPartHttpMessageReader implements HttpMessageReader private final DataBufferFactory bufferFactory; + private final MultipartContext context; + private final AtomicInteger terminated = new AtomicInteger(0); - FluxSinkAdapterListener(FluxSink sink, DataBufferFactory bufferFactory) { + FluxSinkAdapterListener(FluxSink sink, DataBufferFactory bufferFactory, MultipartContext context) { this.sink = sink; this.bufferFactory = bufferFactory; + this.context = context; } @@ -180,7 +183,21 @@ public class SynchronossPartHttpMessageReader implements HttpMessageReader public void onPartFinished(StreamStorage storage, Map> headers) { HttpHeaders httpHeaders = new HttpHeaders(); httpHeaders.putAll(headers); - this.sink.next(createPart(httpHeaders, storage)); + this.sink.next(createPart(storage, httpHeaders)); + } + + private Part createPart(StreamStorage storage, HttpHeaders httpHeaders) { + String fileName = MultipartUtils.getFileName(httpHeaders); + if (fileName != null) { + return new SynchronossFilePart(httpHeaders, storage, fileName, this.bufferFactory); + } + else if (MultipartUtils.isFormField(httpHeaders, this.context)) { + String value = MultipartUtils.readFormParameterValue(storage, httpHeaders); + return new SynchronossFormFieldPart(httpHeaders, this.bufferFactory, value); + } + else { + return new DefaultSynchronossPart(httpHeaders, storage, this.bufferFactory); + } } private Part createPart(HttpHeaders httpHeaders, StreamStorage storage) { @@ -190,13 +207,6 @@ public class SynchronossPartHttpMessageReader implements HttpMessageReader new DefaultSynchronossPart(httpHeaders, storage, this.bufferFactory); } - @Override - public void onFormFieldPartFinished(String name, String value, Map> headers) { - HttpHeaders httpHeaders = new HttpHeaders(); - httpHeaders.putAll(headers); - this.sink.next(new SynchronossFormFieldPart(httpHeaders, this.bufferFactory, value)); - } - @Override public void onError(String message, Throwable cause) { if (this.terminated.getAndIncrement() == 0) { diff --git a/spring-web/src/test/java/org/springframework/http/codec/multipart/MultipartHttpMessageWriterTests.java b/spring-web/src/test/java/org/springframework/http/codec/multipart/MultipartHttpMessageWriterTests.java index a0eb71bc14b..54845da5d00 100644 --- a/spring-web/src/test/java/org/springframework/http/codec/multipart/MultipartHttpMessageWriterTests.java +++ b/spring-web/src/test/java/org/springframework/http/codec/multipart/MultipartHttpMessageWriterTests.java @@ -16,6 +16,7 @@ package org.springframework.http.codec.multipart; +import java.time.Duration; import java.util.Arrays; import java.util.Collections; import java.util.List; @@ -26,6 +27,7 @@ import reactor.core.publisher.Mono; import org.springframework.core.ResolvableType; import org.springframework.core.codec.CharSequenceEncoder; +import org.springframework.core.codec.StringDecoder; import org.springframework.core.io.ClassPathResource; import org.springframework.core.io.Resource; import org.springframework.http.HttpEntity; @@ -149,7 +151,13 @@ public class MultipartHttpMessageWriterTests { part = requestParts.getFirst("json"); assertEquals("json", part.name()); assertEquals(MediaType.APPLICATION_JSON_UTF8, part.headers().getContentType()); - assertEquals("{\"bar\":\"bar\"}", ((FormFieldPart) part).value()); + + String value = StringDecoder.textPlainOnly(false).decodeToMono(part.content(), + ResolvableType.forClass(String.class), MediaType.TEXT_PLAIN, + Collections.emptyMap()).block(Duration.ZERO); + + assertEquals("{\"bar\":\"bar\"}", value); + }