Add support for Flux<Part> in BodyExtractors

This commit adds a `toParts` method in `BodyExtractors`, returning a
BodyExtractor<Part>.
This commit is contained in:
Arjen Poutsma 2017-05-04 12:21:48 +02:00
parent 1f5eaf20b0
commit 4525c6a537
3 changed files with 144 additions and 30 deletions

View File

@ -48,12 +48,14 @@ import org.springframework.util.MultiValueMap;
*/ */
public abstract class BodyExtractors { public abstract class BodyExtractors {
private static final ResolvableType FORM_TYPE = private static final ResolvableType FORM_MAP_TYPE =
ResolvableType.forClassWithGenerics(MultiValueMap.class, String.class, String.class); ResolvableType.forClassWithGenerics(MultiValueMap.class, String.class, String.class);
private static final ResolvableType MULTIPART_TYPE = ResolvableType.forClassWithGenerics( private static final ResolvableType MULTIPART_MAP_TYPE = ResolvableType.forClassWithGenerics(
MultiValueMap.class, String.class, Part.class); MultiValueMap.class, String.class, Part.class);
private static final ResolvableType PART_TYPE = ResolvableType.forClass(Part.class);
/** /**
* Return a {@code BodyExtractor} that reads into a Reactor {@link Mono}. * Return a {@code BodyExtractor} that reads into a Reactor {@link Mono}.
@ -133,15 +135,16 @@ public abstract class BodyExtractors {
public static BodyExtractor<Mono<MultiValueMap<String, String>>, ServerHttpRequest> toFormData() { public static BodyExtractor<Mono<MultiValueMap<String, String>>, ServerHttpRequest> toFormData() {
return (serverRequest, context) -> { return (serverRequest, context) -> {
HttpMessageReader<MultiValueMap<String, String>> messageReader = HttpMessageReader<MultiValueMap<String, String>> messageReader =
formMessageReader(context); messageReader(FORM_MAP_TYPE, MediaType.APPLICATION_FORM_URLENCODED, context);
return context.serverResponse() return context.serverResponse()
.map(serverResponse -> messageReader.readMono(FORM_TYPE, FORM_TYPE, serverRequest, serverResponse, context.hints())) .map(serverResponse -> messageReader.readMono(FORM_MAP_TYPE, FORM_MAP_TYPE, serverRequest, serverResponse, context.hints()))
.orElseGet(() -> messageReader.readMono(FORM_TYPE, serverRequest, context.hints())); .orElseGet(() -> messageReader.readMono(FORM_MAP_TYPE, serverRequest, context.hints()));
}; };
} }
/** /**
* Return a {@code BodyExtractor} that reads form data into a {@link MultiValueMap}. * Return a {@code BodyExtractor} that reads multipart (i.e. file upload) form data into a
* {@link MultiValueMap}.
* @return a {@code BodyExtractor} that reads multipart data * @return a {@code BodyExtractor} that reads multipart data
*/ */
// Note that the returned BodyExtractor is parameterized to ServerHttpRequest, not // Note that the returned BodyExtractor is parameterized to ServerHttpRequest, not
@ -150,10 +153,29 @@ public abstract class BodyExtractors {
public static BodyExtractor<Mono<MultiValueMap<String, Part>>, ServerHttpRequest> toMultipartData() { public static BodyExtractor<Mono<MultiValueMap<String, Part>>, ServerHttpRequest> toMultipartData() {
return (serverRequest, context) -> { return (serverRequest, context) -> {
HttpMessageReader<MultiValueMap<String, Part>> messageReader = HttpMessageReader<MultiValueMap<String, Part>> messageReader =
multipartMessageReader(context); messageReader(MULTIPART_MAP_TYPE, MediaType.MULTIPART_FORM_DATA, context);
return context.serverResponse() return context.serverResponse()
.map(serverResponse -> messageReader.readMono(MULTIPART_TYPE, MULTIPART_TYPE, serverRequest, serverResponse, context.hints())) .map(serverResponse -> messageReader.readMono(MULTIPART_MAP_TYPE,
.orElseGet(() -> messageReader.readMono(MULTIPART_TYPE, serverRequest, context.hints())); MULTIPART_MAP_TYPE, serverRequest, serverResponse, context.hints()))
.orElseGet(() -> messageReader.readMono(MULTIPART_MAP_TYPE, serverRequest, context.hints()));
};
}
/**
* Return a {@code BodyExtractor} that reads multipart (i.e. file upload) form data into a
* {@link MultiValueMap}.
* @return a {@code BodyExtractor} that reads multipart data
*/
// Note that the returned BodyExtractor is parameterized to ServerHttpRequest, not
// ReactiveHttpInputMessage like other methods, since reading form data only typically happens on
// the server-side
public static BodyExtractor<Flux<Part>, ServerHttpRequest> toParts() {
return (serverRequest, context) -> {
HttpMessageReader<Part> messageReader =
messageReader(PART_TYPE, MediaType.MULTIPART_FORM_DATA, context);
return context.serverResponse()
.map(serverResponse -> messageReader.read(PART_TYPE, PART_TYPE, serverRequest, serverResponse, context.hints()))
.orElseGet(() -> messageReader.read(PART_TYPE, serverRequest, context.hints()));
}; };
} }
@ -191,26 +213,15 @@ public abstract class BodyExtractors {
}); });
} }
private static HttpMessageReader<MultiValueMap<String, String>> formMessageReader(BodyExtractor.Context context) { private static <T> HttpMessageReader<T> messageReader(ResolvableType elementType,
MediaType mediaType, BodyExtractor.Context context) {
return context.messageReaders().get() return context.messageReaders().get()
.filter(messageReader -> messageReader .filter(messageReader -> messageReader.canRead(elementType, mediaType))
.canRead(FORM_TYPE, MediaType.APPLICATION_FORM_URLENCODED))
.findFirst() .findFirst()
.map(BodyExtractors::<MultiValueMap<String, String>>cast) .map(BodyExtractors::<T>cast)
.orElseThrow(() -> new IllegalStateException( .orElseThrow(() -> new IllegalStateException(
"Could not find HttpMessageReader that supports " + "Could not find HttpMessageReader that supports \"" + mediaType +
MediaType.APPLICATION_FORM_URLENCODED_VALUE)); "\" and \"" + elementType + "\""));
}
private static HttpMessageReader<MultiValueMap<String, Part>> multipartMessageReader(BodyExtractor.Context context) {
return context.messageReaders().get()
.filter(messageReader -> messageReader
.canRead(MULTIPART_TYPE, MediaType.MULTIPART_FORM_DATA))
.findFirst()
.map(BodyExtractors::<MultiValueMap<String, Part>>cast)
.orElseThrow(() -> new IllegalStateException(
"Could not find HttpMessageReader that supports " +
MediaType.MULTIPART_FORM_DATA));
} }
private static MediaType contentType(HttpMessage message) { private static MediaType contentType(HttpMessage message) {

View File

@ -36,6 +36,8 @@ import reactor.test.StepVerifier;
import org.springframework.core.codec.ByteBufferDecoder; import org.springframework.core.codec.ByteBufferDecoder;
import org.springframework.core.codec.StringDecoder; import org.springframework.core.codec.StringDecoder;
import org.springframework.core.io.ClassPathResource;
import org.springframework.core.io.Resource;
import org.springframework.core.io.buffer.DataBuffer; import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DefaultDataBuffer; import org.springframework.core.io.buffer.DefaultDataBuffer;
import org.springframework.core.io.buffer.DefaultDataBufferFactory; import org.springframework.core.io.buffer.DefaultDataBufferFactory;
@ -45,10 +47,16 @@ import org.springframework.http.codec.DecoderHttpMessageReader;
import org.springframework.http.codec.FormHttpMessageReader; import org.springframework.http.codec.FormHttpMessageReader;
import org.springframework.http.codec.HttpMessageReader; import org.springframework.http.codec.HttpMessageReader;
import org.springframework.http.codec.json.Jackson2JsonDecoder; import org.springframework.http.codec.json.Jackson2JsonDecoder;
import org.springframework.http.codec.multipart.FilePart;
import org.springframework.http.codec.multipart.FormFieldPart;
import org.springframework.http.codec.multipart.MultipartHttpMessageReader;
import org.springframework.http.codec.multipart.Part;
import org.springframework.http.codec.multipart.SynchronossPartHttpMessageReader;
import org.springframework.http.codec.xml.Jaxb2XmlDecoder; import org.springframework.http.codec.xml.Jaxb2XmlDecoder;
import org.springframework.http.server.reactive.ServerHttpRequest; import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpResponse; import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest; import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest;
import org.springframework.util.FileCopyUtils;
import org.springframework.util.MultiValueMap; import org.springframework.util.MultiValueMap;
import static org.junit.Assert.*; import static org.junit.Assert.*;
@ -72,6 +80,11 @@ public class BodyExtractorsTests {
messageReaders.add(new DecoderHttpMessageReader<>(StringDecoder.allMimeTypes(true))); messageReaders.add(new DecoderHttpMessageReader<>(StringDecoder.allMimeTypes(true)));
messageReaders.add(new DecoderHttpMessageReader<>(new Jaxb2XmlDecoder())); messageReaders.add(new DecoderHttpMessageReader<>(new Jaxb2XmlDecoder()));
messageReaders.add(new DecoderHttpMessageReader<>(new Jackson2JsonDecoder())); messageReaders.add(new DecoderHttpMessageReader<>(new Jackson2JsonDecoder()));
messageReaders.add(new FormHttpMessageReader());
SynchronossPartHttpMessageReader partReader = new SynchronossPartHttpMessageReader();
messageReaders.add(partReader);
messageReaders.add(new MultipartHttpMessageReader(partReader));
messageReaders.add(new FormHttpMessageReader()); messageReaders.add(new FormHttpMessageReader());
this.context = new BodyExtractor.Context() { this.context = new BodyExtractor.Context() {
@ -249,6 +262,64 @@ public class BodyExtractorsTests {
.verify(); .verify();
} }
@Test
public void toParts() throws Exception {
BodyExtractor<Flux<Part>, ServerHttpRequest> extractor = BodyExtractors.toParts();
String bodyContents = "-----------------------------9051914041544843365972754266\r\n" +
"Content-Disposition: form-data; name=\"text\"\r\n" +
"\r\n" +
"text default\r\n" +
"-----------------------------9051914041544843365972754266\r\n" +
"Content-Disposition: form-data; name=\"file1\"; filename=\"a.txt\"\r\n" +
"Content-Type: text/plain\r\n" +
"\r\n" +
"Content of a.txt.\r\n" +
"\r\n" +
"-----------------------------9051914041544843365972754266\r\n" +
"Content-Disposition: form-data; name=\"file2\"; filename=\"a.html\"\r\n" +
"Content-Type: text/html\r\n" +
"\r\n" +
"<!DOCTYPE html><title>Content of a.html.</title>\r\n" +
"\r\n" +
"-----------------------------9051914041544843365972754266--\r\n";
DefaultDataBufferFactory factory = new DefaultDataBufferFactory();
DefaultDataBuffer dataBuffer =
factory.wrap(ByteBuffer.wrap(bodyContents.getBytes(StandardCharsets.UTF_8)));
Flux<DataBuffer> body = Flux.just(dataBuffer);
MockServerHttpRequest request = MockServerHttpRequest.post("/")
.header("Content-Type", "multipart/form-data; boundary=---------------------------9051914041544843365972754266")
.body(body);
Flux<Part> result = extractor.extract(request, this.context);
StepVerifier.create(result)
.consumeNextWith(part -> {
assertEquals("text", part.getName());
assertTrue(part instanceof FormFieldPart);
FormFieldPart formFieldPart = (FormFieldPart) part;
assertEquals("text default", formFieldPart.getValue());
})
.consumeNextWith(part -> {
assertEquals("file1", part.getName());
assertTrue(part instanceof FilePart);
FilePart filePart = (FilePart) part;
assertEquals("a.txt", filePart.getFilename());
assertEquals(MediaType.TEXT_PLAIN, filePart.getHeaders().getContentType());
})
.consumeNextWith(part -> {
assertEquals("file2", part.getName());
assertTrue(part instanceof FilePart);
FilePart filePart = (FilePart) part;
assertEquals("a.html", filePart.getFilename());
assertEquals(MediaType.TEXT_HTML, filePart.getHeaders().getContentType());
})
.expectComplete()
.verify();
}
@Test @Test
public void toDataBuffers() throws Exception { public void toDataBuffers() throws Exception {
BodyExtractor<Flux<DataBuffer>, ReactiveHttpInputMessage> extractor = BodyExtractors.toDataBuffers(); BodyExtractor<Flux<DataBuffer>, ReactiveHttpInputMessage> extractor = BodyExtractors.toDataBuffers();

View File

@ -16,6 +16,7 @@
package org.springframework.web.reactive.function; package org.springframework.web.reactive.function;
import java.util.List;
import java.util.Map; import java.util.Map;
import org.junit.Test; import org.junit.Test;
@ -48,10 +49,25 @@ public class MultipartIntegrationTests extends AbstractRouterFunctionIntegration
private final WebClient webClient = WebClient.create(); private final WebClient webClient = WebClient.create();
@Test @Test
public void multipart() { public void multipartData() {
Mono<ClientResponse> result = webClient Mono<ClientResponse> result = webClient
.post() .post()
.uri("http://localhost:" + this.port + "/") .uri("http://localhost:" + this.port + "/multipartData")
.contentType(MediaType.MULTIPART_FORM_DATA)
.body(BodyInserters.fromMultipartData(generateBody()))
.exchange();
StepVerifier
.create(result)
.consumeNextWith(response -> assertEquals(HttpStatus.OK, response.statusCode()))
.verifyComplete();
}
@Test
public void parts() {
Mono<ClientResponse> result = webClient
.post()
.uri("http://localhost:" + this.port + "/parts")
.contentType(MediaType.MULTIPART_FORM_DATA) .contentType(MediaType.MULTIPART_FORM_DATA)
.body(BodyInserters.fromMultipartData(generateBody())) .body(BodyInserters.fromMultipartData(generateBody()))
.exchange(); .exchange();
@ -77,12 +93,13 @@ public class MultipartIntegrationTests extends AbstractRouterFunctionIntegration
@Override @Override
protected RouterFunction<ServerResponse> routerFunction() { protected RouterFunction<ServerResponse> routerFunction() {
MultipartHandler multipartHandler = new MultipartHandler(); MultipartHandler multipartHandler = new MultipartHandler();
return route(POST("/"), multipartHandler::handle); return route(POST("/multipartData"), multipartHandler::multipartData)
.andRoute(POST("/parts"), multipartHandler::parts);
} }
private static class MultipartHandler { private static class MultipartHandler {
public Mono<ServerResponse> handle(ServerRequest request) { public Mono<ServerResponse> multipartData(ServerRequest request) {
return request return request
.body(BodyExtractors.toMultipartData()) .body(BodyExtractors.toMultipartData())
.flatMap(map -> { .flatMap(map -> {
@ -98,6 +115,21 @@ public class MultipartIntegrationTests extends AbstractRouterFunctionIntegration
return ServerResponse.ok().build(); return ServerResponse.ok().build();
}); });
} }
public Mono<ServerResponse> parts(ServerRequest request) {
return request.body(BodyExtractors.toParts()).collectList()
.flatMap(parts -> {
try {
assertEquals(2, parts.size());
assertEquals("foo.txt", ((FilePart) parts.get(0)).getFilename());
assertEquals("bar", ((FormFieldPart) parts.get(1)).getValue());
}
catch(Exception e) {
return Mono.error(e);
}
return ServerResponse.ok().build();
});
}
} }
} }