diff --git a/spring-web/src/main/java/org/springframework/http/client/MultipartBodyBuilder.java b/spring-web/src/main/java/org/springframework/http/client/MultipartBodyBuilder.java index 8c8f67432ea..97011e72d21 100644 --- a/spring-web/src/main/java/org/springframework/http/client/MultipartBodyBuilder.java +++ b/spring-web/src/main/java/org/springframework/http/client/MultipartBodyBuilder.java @@ -22,7 +22,6 @@ import java.util.Map; import java.util.function.Consumer; import org.reactivestreams.Publisher; -import reactor.core.publisher.Mono; import org.springframework.core.ParameterizedTypeReference; import org.springframework.core.ResolvableType; @@ -31,6 +30,7 @@ import org.springframework.core.io.buffer.DataBuffer; import org.springframework.http.HttpEntity; import org.springframework.http.HttpHeaders; import org.springframework.http.MediaType; +import org.springframework.http.codec.multipart.FilePart; import org.springframework.http.codec.multipart.Part; import org.springframework.lang.NonNull; import org.springframework.lang.Nullable; @@ -95,8 +95,9 @@ public final class MultipartBodyBuilder { *
  • String -- form field *
  • {@link org.springframework.core.io.Resource Resource} -- file part *
  • Object -- content to be encoded (e.g. to JSON) - *
  • HttpEntity -- part content and headers although generally it's - * easier to add headers through the returned builder
  • + *
  • {@link HttpEntity} -- part content and headers although generally it's + * easier to add headers through the returned builder + *
  • {@link Part} -- a part from a server request * * @param name the name of the part to add * @param part the part data @@ -117,10 +118,21 @@ public final class MultipartBodyBuilder { Assert.hasLength(name, "'name' must not be empty"); Assert.notNull(part, "'part' must not be null"); - if (part instanceof PublisherEntity) { - PublisherPartBuilder builder = new PublisherPartBuilder<>((PublisherEntity) part); + if (part instanceof Part) { + PartBuilder builder = asyncPart(name, ((Part) part).content(), DataBuffer.class); if (contentType != null) { - builder.header(HttpHeaders.CONTENT_TYPE, contentType.toString()); + builder.contentType(contentType); + } + if (part instanceof FilePart) { + builder.filename(((FilePart) part).filename()); + } + return builder; + } + + if (part instanceof PublisherEntity) { + PublisherPartBuilder builder = new PublisherPartBuilder<>(name, (PublisherEntity) part); + if (contentType != null) { + builder.contentType(contentType); } this.parts.add(name, builder); return builder; @@ -144,9 +156,9 @@ public final class MultipartBodyBuilder { " or MultipartBodyBuilder.PublisherEntity"); } - DefaultPartBuilder builder = new DefaultPartBuilder(partHeaders, partBody); + DefaultPartBuilder builder = new DefaultPartBuilder(name, partHeaders, partBody); if (contentType != null) { - builder.header(HttpHeaders.CONTENT_TYPE, contentType.toString()); + builder.contentType(contentType); } this.parts.add(name, builder); return builder; @@ -165,15 +177,9 @@ public final class MultipartBodyBuilder { Assert.notNull(publisher, "'publisher' must not be null"); Assert.notNull(elementClass, "'elementClass' must not be null"); - if (Part.class.isAssignableFrom(elementClass)) { - publisher = (P) Mono.from(publisher).flatMapMany(p -> ((Part) p).content()); - elementClass = (Class) DataBuffer.class; - } - - PublisherPartBuilder builder = new PublisherPartBuilder<>(null, publisher, elementClass); + PublisherPartBuilder builder = new PublisherPartBuilder<>(name, null, publisher, elementClass); this.parts.add(name, builder); return builder; - } /** @@ -191,7 +197,7 @@ public final class MultipartBodyBuilder { Assert.notNull(publisher, "'publisher' must not be null"); Assert.notNull(typeReference, "'typeReference' must not be null"); - PublisherPartBuilder builder = new PublisherPartBuilder<>(null, publisher, typeReference); + PublisherPartBuilder builder = new PublisherPartBuilder<>(name, null, publisher, typeReference); this.parts.add(name, builder); return builder; } @@ -216,6 +222,24 @@ public final class MultipartBodyBuilder { */ public interface PartBuilder { + /** + * Set the {@linkplain MediaType media type} of the part. + * @param contentType the content type + * @see HttpHeaders#setContentType(MediaType) + * @since 5.2 + */ + PartBuilder contentType(MediaType contentType); + + /** + * Set the filename parameter for a file part. This should not be + * necessary with {@link org.springframework.core.io.Resource Resource} + * based parts that expose a filename but may be useful for + * {@link Publisher} parts. + * @param filename the filename to set on the Content-Disposition + * @since 5.2 + */ + PartBuilder filename(String filename); + /** * Add part header values. * @param headerName the part header name @@ -236,17 +260,32 @@ public final class MultipartBodyBuilder { private static class DefaultPartBuilder implements PartBuilder { + private final String name; + @Nullable protected HttpHeaders headers; @Nullable protected final Object body; - public DefaultPartBuilder(@Nullable HttpHeaders headers, @Nullable Object body) { + public DefaultPartBuilder(String name, @Nullable HttpHeaders headers, @Nullable Object body) { + this.name = name; this.headers = headers; this.body = body; } + @Override + public PartBuilder contentType(MediaType contentType) { + initHeadersIfNecessary().setContentType(contentType); + return this; + } + + @Override + public PartBuilder filename(String filename) { + initHeadersIfNecessary().setContentDispositionFormData(this.name, filename); + return this; + } + @Override public PartBuilder header(String headerName, String... headerValues) { initHeadersIfNecessary().addAll(headerName, Arrays.asList(headerValues)); @@ -276,18 +315,20 @@ public final class MultipartBodyBuilder { private final ResolvableType resolvableType; - public PublisherPartBuilder(@Nullable HttpHeaders headers, P body, Class elementClass) { - super(headers, body); + public PublisherPartBuilder(String name, @Nullable HttpHeaders headers, P body, Class elementClass) { + super(name, headers, body); this.resolvableType = ResolvableType.forClass(elementClass); } - public PublisherPartBuilder(@Nullable HttpHeaders headers, P body, ParameterizedTypeReference typeRef) { - super(headers, body); + public PublisherPartBuilder(String name, @Nullable HttpHeaders headers, P body, + ParameterizedTypeReference typeRef) { + + super(name, headers, body); this.resolvableType = ResolvableType.forType(typeRef); } - public PublisherPartBuilder(PublisherEntity other) { - super(other.getHeaders(), other.getBody()); + public PublisherPartBuilder(String name, PublisherEntity other) { + super(name, other.getHeaders(), other.getBody()); this.resolvableType = other.getResolvableType(); } diff --git a/spring-web/src/test/java/org/springframework/http/codec/multipart/MultipartHttpMessageWriterTests.java b/spring-web/src/test/java/org/springframework/http/codec/multipart/MultipartHttpMessageWriterTests.java index 0886d2c83fb..af07aa25126 100644 --- a/spring-web/src/test/java/org/springframework/http/codec/multipart/MultipartHttpMessageWriterTests.java +++ b/spring-web/src/test/java/org/springframework/http/codec/multipart/MultipartHttpMessageWriterTests.java @@ -93,8 +93,9 @@ public class MultipartHttpMessageWriterTests extends AbstractLeakCheckingTestCas this.bufferFactory.wrap("Bb".getBytes(StandardCharsets.UTF_8)), this.bufferFactory.wrap("Cc".getBytes(StandardCharsets.UTF_8)) ); - Part mockPart = mock(Part.class); + FilePart mockPart = mock(FilePart.class); given(mockPart.content()).willReturn(bufferPublisher); + given(mockPart.filename()).willReturn("file.txt"); MultipartBodyBuilder bodyBuilder = new MultipartBodyBuilder(); bodyBuilder.part("name 1", "value 1"); @@ -104,7 +105,7 @@ public class MultipartHttpMessageWriterTests extends AbstractLeakCheckingTestCas bodyBuilder.part("utf8", utf8); bodyBuilder.part("json", new Foo("bar"), MediaType.APPLICATION_JSON); bodyBuilder.asyncPart("publisher", Flux.just("foo", "bar", "baz"), String.class); - bodyBuilder.asyncPart("partPublisher", Mono.just(mockPart), Part.class); + bodyBuilder.part("filePublisher", mockPart); Mono>> result = Mono.just(bodyBuilder.build()); Map hints = Collections.emptyMap(); @@ -159,8 +160,9 @@ public class MultipartHttpMessageWriterTests extends AbstractLeakCheckingTestCas value = decodeToString(part); assertThat(value).isEqualTo("foobarbaz"); - part = requestParts.getFirst("partPublisher"); - assertThat(part.name()).isEqualTo("partPublisher"); + part = requestParts.getFirst("filePublisher"); + assertThat(part.name()).isEqualTo("filePublisher"); + assertThat(((FilePart) part).filename()).isEqualTo("file.txt"); value = decodeToString(part); assertThat(value).isEqualTo("AaBbCc"); } diff --git a/src/docs/asciidoc/web/webflux-webclient.adoc b/src/docs/asciidoc/web/webflux-webclient.adoc index b3519ba4ef1..7050c1cacfe 100644 --- a/src/docs/asciidoc/web/webflux-webclient.adoc +++ b/src/docs/asciidoc/web/webflux-webclient.adoc @@ -414,8 +414,9 @@ multipart request. The following example shows how to create a `MultiValueMap> parts = builder.build(); ----