Support repeatable multipart write

Closes gh-34859
This commit is contained in:
rstoyanchev 2025-06-11 16:49:21 +01:00
parent d8ac3ff31f
commit 00cc48dad4
2 changed files with 88 additions and 28 deletions

View File

@ -28,6 +28,7 @@ import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.jspecify.annotations.Nullable;
@ -485,9 +486,18 @@ public class FormHttpMessageConverter implements HttpMessageConverter<MultiValue
outputMessage.getHeaders().setContentType(contentType);
if (outputMessage instanceof StreamingHttpOutputMessage streamingOutputMessage) {
streamingOutputMessage.setBody(outputStream -> {
writeParts(outputStream, parts, boundary);
writeEnd(outputStream, boundary);
boolean repeatable = checkPartsRepeatable(parts);
streamingOutputMessage.setBody(new StreamingHttpOutputMessage.Body() {
@Override
public void writeTo(OutputStream outputStream) throws IOException {
FormHttpMessageConverter.this.writeParts(outputStream, parts, boundary);
writeEnd(outputStream, boundary);
}
@Override
public boolean repeatable() {
return repeatable;
}
});
}
else {
@ -496,6 +506,35 @@ public class FormHttpMessageConverter implements HttpMessageConverter<MultiValue
}
}
@SuppressWarnings({"unchecked", "ConstantValue"})
private <T> boolean checkPartsRepeatable(MultiValueMap<String, Object> map) {
return map.entrySet().stream().allMatch(e -> e.getValue().stream().filter(Objects::nonNull).allMatch(part -> {
HttpHeaders headers = null;
Object body = part;
if (part instanceof HttpEntity<?> entity) {
headers = entity.getHeaders();
body = entity.getBody();
Assert.state(body != null, "Empty body for part '" + e.getKey() + "': " + part);
}
HttpMessageConverter<?> converter = findConverterFor(e.getKey(), headers, body);
return (converter instanceof AbstractHttpMessageConverter<?> ahmc &&
((AbstractHttpMessageConverter<T>) ahmc).supportsRepeatableWrites((T) body));
}));
}
private @Nullable HttpMessageConverter<?> findConverterFor(
String name, @Nullable HttpHeaders headers, Object body) {
Class<?> partType = body.getClass();
MediaType contentType = (headers != null ? headers.getContentType() : null);
for (HttpMessageConverter<?> converter : this.partConverters) {
if (converter.canWrite(partType, contentType)) {
return converter;
}
}
return null;
}
/**
* When {@link #setMultipartCharset(Charset)} is configured (i.e. RFC 2047,
* {@code encoded-word} syntax) we need to use ASCII for part headers, or
@ -521,32 +560,27 @@ public class FormHttpMessageConverter implements HttpMessageConverter<MultiValue
@SuppressWarnings("unchecked")
private void writePart(String name, HttpEntity<?> partEntity, OutputStream os) throws IOException {
Object partBody = partEntity.getBody();
if (partBody == null) {
throw new IllegalStateException("Empty body for part '" + name + "': " + partEntity);
}
Class<?> partType = partBody.getClass();
Assert.state(partBody != null, "Empty body for part '" + name + "': " + partEntity);
HttpHeaders partHeaders = partEntity.getHeaders();
MediaType partContentType = partHeaders.getContentType();
for (HttpMessageConverter<?> messageConverter : this.partConverters) {
if (messageConverter.canWrite(partType, partContentType)) {
Charset charset = isFilenameCharsetSet() ? StandardCharsets.US_ASCII : this.charset;
HttpOutputMessage multipartMessage = new MultipartHttpOutputMessage(os, charset);
String filename = getFilename(partBody);
ContentDisposition.Builder cd = ContentDisposition.formData()
.name(name);
if (filename != null) {
cd.filename(filename, this.multipartCharset);
}
multipartMessage.getHeaders().setContentDisposition(cd.build());
if (!partHeaders.isEmpty()) {
multipartMessage.getHeaders().putAll(partHeaders);
}
((HttpMessageConverter<Object>) messageConverter).write(partBody, partContentType, multipartMessage);
return;
HttpMessageConverter<?> converter = findConverterFor(name, partHeaders, partBody);
if (converter != null) {
Charset charset = isFilenameCharsetSet() ? StandardCharsets.US_ASCII : this.charset;
HttpOutputMessage multipartMessage = new MultipartHttpOutputMessage(os, charset);
String filename = getFilename(partBody);
ContentDisposition.Builder cd = ContentDisposition.formData().name(name);
if (filename != null) {
cd.filename(filename, this.multipartCharset);
}
multipartMessage.getHeaders().setContentDisposition(cd.build());
if (!partHeaders.isEmpty()) {
multipartMessage.getHeaders().putAll(partHeaders);
}
((HttpMessageConverter<Object>) converter).write(partBody, partContentType, multipartMessage);
return;
}
throw new HttpMessageNotWritableException("Could not write request: no suitable HttpMessageConverter " +
"found for request type [" + partType.getName() + "]");
throw new HttpMessageNotWritableException("Could not write request: " +
"no suitable HttpMessageConverter found for request type [" + partBody.getClass().getName() + "]");
}
/**

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2024 the original author or authors.
* Copyright 2002-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -40,6 +40,7 @@ import org.springframework.core.io.Resource;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.StreamingHttpOutputMessage;
import org.springframework.http.converter.support.AllEncompassingFormHttpMessageConverter;
import org.springframework.http.converter.xml.SourceHttpMessageConverter;
import org.springframework.util.LinkedMultiValueMap;
@ -204,7 +205,7 @@ class FormHttpMessageConverterTests {
parameters.put("charset", UTF_8.name());
parameters.put("foo", "bar");
MockHttpOutputMessage outputMessage = new MockHttpOutputMessage();
StreamingMockHttpOutputMessage outputMessage = new StreamingMockHttpOutputMessage();
this.converter.write(parts, new MediaType("multipart", "form-data", parameters), outputMessage);
final MediaType contentType = outputMessage.getHeaders().getContentType();
@ -248,6 +249,8 @@ class FormHttpMessageConverterTests {
item = items.get(5);
assertThat(item.getFieldName()).isEqualTo("json");
assertThat(item.getContentType()).isEqualTo("application/json");
assertThat(outputMessage.wasRepeatable()).isTrue();
}
@Test
@ -286,7 +289,7 @@ class FormHttpMessageConverterTests {
parameters.put("charset", UTF_8.name());
parameters.put("foo", "bar");
MockHttpOutputMessage outputMessage = new MockHttpOutputMessage();
StreamingMockHttpOutputMessage outputMessage = new StreamingMockHttpOutputMessage();
this.converter.write(parts, new MediaType("multipart", "form-data", parameters), outputMessage);
final MediaType contentType = outputMessage.getHeaders().getContentType();
@ -330,6 +333,8 @@ class FormHttpMessageConverterTests {
item = items.get(5);
assertThat(item.getFieldName()).isEqualTo("xml");
assertThat(item.getContentType()).isEqualTo("text/xml");
assertThat(outputMessage.wasRepeatable()).isFalse();
}
@Test // SPR-13309
@ -444,6 +449,27 @@ class FormHttpMessageConverterTests {
}
private static class StreamingMockHttpOutputMessage extends MockHttpOutputMessage implements StreamingHttpOutputMessage {
private boolean repeatable;
public boolean wasRepeatable() {
return this.repeatable;
}
@Override
public void setBody(Body body) {
try {
this.repeatable = body.repeatable();
body.writeTo(getBody());
}
catch (IOException ex) {
throw new RuntimeException(ex);
}
}
}
private static class MockHttpOutputMessageRequestContext implements UploadContext {
private final MockHttpOutputMessage outputMessage;