diff --git a/spring-web/src/main/java/org/springframework/http/client/MultipartBodyBuilder.java b/spring-web/src/main/java/org/springframework/http/client/MultipartBodyBuilder.java index 80289e4fd87..22d39766aca 100644 --- a/spring-web/src/main/java/org/springframework/http/client/MultipartBodyBuilder.java +++ b/spring-web/src/main/java/org/springframework/http/client/MultipartBodyBuilder.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2017 the original author or authors. + * Copyright 2002-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,6 +19,7 @@ package org.springframework.http.client; import java.util.Arrays; import java.util.List; import java.util.Map; +import java.util.function.Consumer; import org.reactivestreams.Publisher; @@ -187,6 +188,13 @@ public final class MultipartBodyBuilder { * @see HttpHeaders#add(String, String) */ PartBuilder header(String headerName, String... headerValues); + + /** + * Manipulate the part's headers with the given consumer. + * @param headersConsumer a function that consumes the {@code HttpHeaders} + * @return this builder + */ + PartBuilder headers(Consumer headersConsumer); } @@ -208,6 +216,13 @@ public final class MultipartBodyBuilder { return this; } + @Override + public PartBuilder headers(Consumer headersConsumer) { + Assert.notNull(headersConsumer, "'headersConsumer' must not be null"); + headersConsumer.accept(this.headers); + return this; + } + public HttpEntity build() { return new HttpEntity<>(this.body, this.headers); } diff --git a/spring-web/src/main/java/org/springframework/http/codec/multipart/MultipartHttpMessageWriter.java b/spring-web/src/main/java/org/springframework/http/codec/multipart/MultipartHttpMessageWriter.java index 7cca83a781b..5e7b23b9da0 100644 --- a/spring-web/src/main/java/org/springframework/http/codec/multipart/MultipartHttpMessageWriter.java +++ b/spring-web/src/main/java/org/springframework/http/codec/multipart/MultipartHttpMessageWriter.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2017 the original author or authors. + * Copyright 2002-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -229,12 +229,13 @@ public class MultipartHttpMessageWriter implements HttpMessageWriter Flux encodePart(byte[] boundary, String name, T value) { MultipartHttpOutputMessage outputMessage = new MultipartHttpOutputMessage(this.bufferFactory, getCharset()); + HttpHeaders outputHeaders = outputMessage.getHeaders(); T body; ResolvableType resolvableType = null; if (value instanceof HttpEntity) { HttpEntity httpEntity = (HttpEntity) value; - outputMessage.getHeaders().putAll(httpEntity.getHeaders()); + outputHeaders.putAll(httpEntity.getHeaders()); body = httpEntity.getBody(); Assert.state(body != null, "MultipartHttpMessageWriter only supports HttpEntity with body"); @@ -247,24 +248,24 @@ public class MultipartHttpMessageWriter implements HttpMessageWriter) body).doOnNext(o -> { - outputMessage.getHeaders().setContentDispositionFormData(name, ((Resource) o).getFilename()); - }); - } - else { - outputMessage.getHeaders().setContentDispositionFormData(name, null); + if (!outputHeaders.containsKey(HttpHeaders.CONTENT_DISPOSITION)) { + if (body instanceof Resource) { + outputHeaders.setContentDispositionFormData(name, ((Resource) body).getFilename()); + } + else if (Resource.class.equals(resolvableType.getRawClass())) { + body = (T) Mono.from((Publisher) body).doOnNext(o -> outputHeaders + .setContentDispositionFormData(name, ((Resource) o).getFilename())); + } + else { + outputHeaders.setContentDispositionFormData(name, null); + } } - MediaType contentType = outputMessage.getHeaders().getContentType(); + MediaType contentType = outputHeaders.getContentType(); final ResolvableType finalBodyType = resolvableType; Optional> writer = this.partWriters.stream() diff --git a/spring-web/src/test/java/org/springframework/http/codec/multipart/MultipartHttpMessageWriterTests.java b/spring-web/src/test/java/org/springframework/http/codec/multipart/MultipartHttpMessageWriterTests.java index d6db8d2ffa5..3442e284a0a 100644 --- a/spring-web/src/test/java/org/springframework/http/codec/multipart/MultipartHttpMessageWriterTests.java +++ b/spring-web/src/test/java/org/springframework/http/codec/multipart/MultipartHttpMessageWriterTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2017 the original author or authors. + * Copyright 2002-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -32,6 +32,9 @@ import org.springframework.core.ResolvableType; import org.springframework.core.codec.StringDecoder; import org.springframework.core.io.ClassPathResource; import org.springframework.core.io.Resource; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferUtils; +import org.springframework.core.io.buffer.DefaultDataBufferFactory; import org.springframework.http.HttpEntity; import org.springframework.http.MediaType; import org.springframework.http.client.MultipartBodyBuilder; @@ -191,6 +194,42 @@ public class MultipartHttpMessageWriterTests { this.writer.write(result, null, MediaType.MULTIPART_FORM_DATA, response, hints).block(); } + @Test // SPR-16376 + public void customContentDisposition() throws IOException { + Resource logo = new ClassPathResource("/org/springframework/http/converter/logo.jpg"); + Flux buffers = DataBufferUtils.read(logo, new DefaultDataBufferFactory(), 1024); + long contentLength = logo.contentLength(); + + MultipartBodyBuilder bodyBuilder = new MultipartBodyBuilder(); + bodyBuilder.part("resource", logo) + .headers(h -> h.setContentDispositionFormData("resource", "spring.jpg")); + bodyBuilder.asyncPart("buffers", buffers, DataBuffer.class) + .headers(h -> { + h.setContentDispositionFormData("buffers", "buffers.jpg"); + h.setContentType(MediaType.IMAGE_JPEG); + h.setContentLength(contentLength); + }); + + MultiValueMap> multipartData = bodyBuilder.build(); + + MockServerHttpResponse response = new MockServerHttpResponse(); + Map hints = Collections.emptyMap(); + this.writer.write(Mono.just(multipartData), null, MediaType.MULTIPART_FORM_DATA, response, hints).block(); + + MultiValueMap requestParts = parse(response, hints); + assertEquals(2, requestParts.size()); + + Part part = requestParts.getFirst("resource"); + assertTrue(part instanceof FilePart); + assertEquals("spring.jpg", ((FilePart) part).filename()); + assertEquals(logo.getFile().length(), part.headers().getContentLength()); + + part = requestParts.getFirst("buffers"); + assertTrue(part instanceof FilePart); + assertEquals("buffers.jpg", ((FilePart) part).filename()); + assertEquals(logo.getFile().length(), part.headers().getContentLength()); + } + private MultiValueMap parse(MockServerHttpResponse response, Map hints) { MediaType contentType = response.getHeaders().getContentType(); assertNotNull("No boundary found", contentType.getParameter("boundary"));