Support repeatable multipart write

Closes gh-34859
This commit is contained in:
rstoyanchev 2025-06-11 16:49:21 +01:00
parent d8ac3ff31f
commit 00cc48dad4
2 changed files with 88 additions and 28 deletions

View File

@ -28,6 +28,7 @@ import java.util.Collections;
import java.util.LinkedHashMap; import java.util.LinkedHashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects;
import org.jspecify.annotations.Nullable; import org.jspecify.annotations.Nullable;
@ -485,9 +486,18 @@ public class FormHttpMessageConverter implements HttpMessageConverter<MultiValue
outputMessage.getHeaders().setContentType(contentType); outputMessage.getHeaders().setContentType(contentType);
if (outputMessage instanceof StreamingHttpOutputMessage streamingOutputMessage) { if (outputMessage instanceof StreamingHttpOutputMessage streamingOutputMessage) {
streamingOutputMessage.setBody(outputStream -> { boolean repeatable = checkPartsRepeatable(parts);
writeParts(outputStream, parts, boundary); streamingOutputMessage.setBody(new StreamingHttpOutputMessage.Body() {
@Override
public void writeTo(OutputStream outputStream) throws IOException {
FormHttpMessageConverter.this.writeParts(outputStream, parts, boundary);
writeEnd(outputStream, boundary); writeEnd(outputStream, boundary);
}
@Override
public boolean repeatable() {
return repeatable;
}
}); });
} }
else { else {
@ -496,6 +506,35 @@ public class FormHttpMessageConverter implements HttpMessageConverter<MultiValue
} }
} }
@SuppressWarnings({"unchecked", "ConstantValue"})
private <T> boolean checkPartsRepeatable(MultiValueMap<String, Object> map) {
return map.entrySet().stream().allMatch(e -> e.getValue().stream().filter(Objects::nonNull).allMatch(part -> {
HttpHeaders headers = null;
Object body = part;
if (part instanceof HttpEntity<?> entity) {
headers = entity.getHeaders();
body = entity.getBody();
Assert.state(body != null, "Empty body for part '" + e.getKey() + "': " + part);
}
HttpMessageConverter<?> converter = findConverterFor(e.getKey(), headers, body);
return (converter instanceof AbstractHttpMessageConverter<?> ahmc &&
((AbstractHttpMessageConverter<T>) ahmc).supportsRepeatableWrites((T) body));
}));
}
private @Nullable HttpMessageConverter<?> findConverterFor(
String name, @Nullable HttpHeaders headers, Object body) {
Class<?> partType = body.getClass();
MediaType contentType = (headers != null ? headers.getContentType() : null);
for (HttpMessageConverter<?> converter : this.partConverters) {
if (converter.canWrite(partType, contentType)) {
return converter;
}
}
return null;
}
/** /**
* When {@link #setMultipartCharset(Charset)} is configured (i.e. RFC 2047, * When {@link #setMultipartCharset(Charset)} is configured (i.e. RFC 2047,
* {@code encoded-word} syntax) we need to use ASCII for part headers, or * {@code encoded-word} syntax) we need to use ASCII for part headers, or
@ -521,19 +560,15 @@ public class FormHttpMessageConverter implements HttpMessageConverter<MultiValue
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
private void writePart(String name, HttpEntity<?> partEntity, OutputStream os) throws IOException { private void writePart(String name, HttpEntity<?> partEntity, OutputStream os) throws IOException {
Object partBody = partEntity.getBody(); Object partBody = partEntity.getBody();
if (partBody == null) { Assert.state(partBody != null, "Empty body for part '" + name + "': " + partEntity);
throw new IllegalStateException("Empty body for part '" + name + "': " + partEntity);
}
Class<?> partType = partBody.getClass();
HttpHeaders partHeaders = partEntity.getHeaders(); HttpHeaders partHeaders = partEntity.getHeaders();
MediaType partContentType = partHeaders.getContentType(); MediaType partContentType = partHeaders.getContentType();
for (HttpMessageConverter<?> messageConverter : this.partConverters) { HttpMessageConverter<?> converter = findConverterFor(name, partHeaders, partBody);
if (messageConverter.canWrite(partType, partContentType)) { if (converter != null) {
Charset charset = isFilenameCharsetSet() ? StandardCharsets.US_ASCII : this.charset; Charset charset = isFilenameCharsetSet() ? StandardCharsets.US_ASCII : this.charset;
HttpOutputMessage multipartMessage = new MultipartHttpOutputMessage(os, charset); HttpOutputMessage multipartMessage = new MultipartHttpOutputMessage(os, charset);
String filename = getFilename(partBody); String filename = getFilename(partBody);
ContentDisposition.Builder cd = ContentDisposition.formData() ContentDisposition.Builder cd = ContentDisposition.formData().name(name);
.name(name);
if (filename != null) { if (filename != null) {
cd.filename(filename, this.multipartCharset); cd.filename(filename, this.multipartCharset);
} }
@ -541,12 +576,11 @@ public class FormHttpMessageConverter implements HttpMessageConverter<MultiValue
if (!partHeaders.isEmpty()) { if (!partHeaders.isEmpty()) {
multipartMessage.getHeaders().putAll(partHeaders); multipartMessage.getHeaders().putAll(partHeaders);
} }
((HttpMessageConverter<Object>) messageConverter).write(partBody, partContentType, multipartMessage); ((HttpMessageConverter<Object>) converter).write(partBody, partContentType, multipartMessage);
return; return;
} }
} throw new HttpMessageNotWritableException("Could not write request: " +
throw new HttpMessageNotWritableException("Could not write request: no suitable HttpMessageConverter " + "no suitable HttpMessageConverter found for request type [" + partBody.getClass().getName() + "]");
"found for request type [" + partType.getName() + "]");
} }
/** /**

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2024 the original author or authors. * Copyright 2002-2025 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -40,6 +40,7 @@ import org.springframework.core.io.Resource;
import org.springframework.http.HttpEntity; import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders; import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType; import org.springframework.http.MediaType;
import org.springframework.http.StreamingHttpOutputMessage;
import org.springframework.http.converter.support.AllEncompassingFormHttpMessageConverter; import org.springframework.http.converter.support.AllEncompassingFormHttpMessageConverter;
import org.springframework.http.converter.xml.SourceHttpMessageConverter; import org.springframework.http.converter.xml.SourceHttpMessageConverter;
import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.LinkedMultiValueMap;
@ -204,7 +205,7 @@ class FormHttpMessageConverterTests {
parameters.put("charset", UTF_8.name()); parameters.put("charset", UTF_8.name());
parameters.put("foo", "bar"); parameters.put("foo", "bar");
MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); StreamingMockHttpOutputMessage outputMessage = new StreamingMockHttpOutputMessage();
this.converter.write(parts, new MediaType("multipart", "form-data", parameters), outputMessage); this.converter.write(parts, new MediaType("multipart", "form-data", parameters), outputMessage);
final MediaType contentType = outputMessage.getHeaders().getContentType(); final MediaType contentType = outputMessage.getHeaders().getContentType();
@ -248,6 +249,8 @@ class FormHttpMessageConverterTests {
item = items.get(5); item = items.get(5);
assertThat(item.getFieldName()).isEqualTo("json"); assertThat(item.getFieldName()).isEqualTo("json");
assertThat(item.getContentType()).isEqualTo("application/json"); assertThat(item.getContentType()).isEqualTo("application/json");
assertThat(outputMessage.wasRepeatable()).isTrue();
} }
@Test @Test
@ -286,7 +289,7 @@ class FormHttpMessageConverterTests {
parameters.put("charset", UTF_8.name()); parameters.put("charset", UTF_8.name());
parameters.put("foo", "bar"); parameters.put("foo", "bar");
MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); StreamingMockHttpOutputMessage outputMessage = new StreamingMockHttpOutputMessage();
this.converter.write(parts, new MediaType("multipart", "form-data", parameters), outputMessage); this.converter.write(parts, new MediaType("multipart", "form-data", parameters), outputMessage);
final MediaType contentType = outputMessage.getHeaders().getContentType(); final MediaType contentType = outputMessage.getHeaders().getContentType();
@ -330,6 +333,8 @@ class FormHttpMessageConverterTests {
item = items.get(5); item = items.get(5);
assertThat(item.getFieldName()).isEqualTo("xml"); assertThat(item.getFieldName()).isEqualTo("xml");
assertThat(item.getContentType()).isEqualTo("text/xml"); assertThat(item.getContentType()).isEqualTo("text/xml");
assertThat(outputMessage.wasRepeatable()).isFalse();
} }
@Test // SPR-13309 @Test // SPR-13309
@ -444,6 +449,27 @@ class FormHttpMessageConverterTests {
} }
private static class StreamingMockHttpOutputMessage extends MockHttpOutputMessage implements StreamingHttpOutputMessage {
private boolean repeatable;
public boolean wasRepeatable() {
return this.repeatable;
}
@Override
public void setBody(Body body) {
try {
this.repeatable = body.repeatable();
body.writeTo(getBody());
}
catch (IOException ex) {
throw new RuntimeException(ex);
}
}
}
private static class MockHttpOutputMessageRequestContext implements UploadContext { private static class MockHttpOutputMessageRequestContext implements UploadContext {
private final MockHttpOutputMessage outputMessage; private final MockHttpOutputMessage outputMessage;