diff --git a/build.gradle b/build.gradle index d31ee6d629..078e0ff53d 100644 --- a/build.gradle +++ b/build.gradle @@ -89,6 +89,7 @@ configure(allprojects) { project -> ext.servletVersion = "3.1.0" ext.slf4jVersion = "1.7.25" ext.snakeyamlVersion = "1.18" + ext.nioMultipartVersion = "1.0.2" ext.testngVersion = "6.11" ext.tiles3Version = "3.0.7" ext.tomcatVersion = "8.5.14" @@ -747,7 +748,7 @@ project("spring-web") { optional("javax.xml.bind:jaxb-api:${jaxbVersion}") optional("javax.xml.ws:jaxws-api:${jaxwsVersion}") optional("javax.mail:javax.mail-api:${javamailVersion}") - optional("org.synchronoss.cloud:nio-multipart-parser:1.0.2") + optional("org.synchronoss.cloud:nio-multipart-parser:${nioMultipartVersion}") optional("org.jetbrains.kotlin:kotlin-stdlib-jre8:${kotlinVersion}") testCompile(project(":spring-context-support")) // for JafMediaTypeFactory testCompile("io.projectreactor.addons:reactor-test") @@ -839,6 +840,10 @@ project("spring-webflux") { testRuntime("org.jetbrains.kotlin:kotlin-compiler:${kotlinVersion}") testCompile("org.jetbrains.kotlin:kotlin-script-runtime:${kotlinVersion}") testRuntime("org.jetbrains.kotlin:kotlin-script-util:${kotlinVersion}") + testRuntime("org.synchronoss.cloud:nio-multipart-parser:${nioMultipartVersion}") + testRuntime("com.sun.mail:javax.mail:${javamailVersion}") + testRuntime("com.sun.xml.bind:jaxb-core:${jaxbVersion}") + testRuntime("com.sun.xml.bind:jaxb-impl:${jaxbVersion}") } if (JavaVersion.current().java9Compatible) { diff --git a/spring-web/src/main/java/org/springframework/http/codec/DefaultClientCodecConfigurer.java b/spring-web/src/main/java/org/springframework/http/codec/DefaultClientCodecConfigurer.java index a97c8fa3eb..a6c055154e 100644 --- a/spring-web/src/main/java/org/springframework/http/codec/DefaultClientCodecConfigurer.java +++ b/spring-web/src/main/java/org/springframework/http/codec/DefaultClientCodecConfigurer.java @@ -21,6 +21,7 @@ import java.util.List; import org.springframework.core.codec.Decoder; import org.springframework.core.codec.StringDecoder; import org.springframework.http.codec.json.Jackson2JsonDecoder; +import org.springframework.http.codec.multipart.MultipartHttpMessageWriter; /** * Default implementation of {@link ClientCodecConfigurer}. @@ -57,6 +58,7 @@ class DefaultClientCodecConfigurer extends DefaultCodecConfigurer implements Cli protected void addTypedWritersTo(List> result) { super.addTypedWritersTo(result); addWriterTo(result, FormHttpMessageWriter::new); + addWriterTo(result, MultipartHttpMessageWriter::new); } @Override diff --git a/spring-web/src/main/java/org/springframework/http/codec/DefaultServerCodecConfigurer.java b/spring-web/src/main/java/org/springframework/http/codec/DefaultServerCodecConfigurer.java index 6568333da2..aae8408af0 100644 --- a/spring-web/src/main/java/org/springframework/http/codec/DefaultServerCodecConfigurer.java +++ b/spring-web/src/main/java/org/springframework/http/codec/DefaultServerCodecConfigurer.java @@ -21,6 +21,8 @@ import java.util.List; import org.springframework.core.codec.Encoder; import org.springframework.core.codec.StringDecoder; import org.springframework.http.codec.json.Jackson2JsonEncoder; +import org.springframework.http.codec.multipart.SynchronossMultipartHttpMessageReader; +import org.springframework.util.ClassUtils; /** * Default implementation of {@link ServerCodecConfigurer}. @@ -30,6 +32,11 @@ import org.springframework.http.codec.json.Jackson2JsonEncoder; */ class DefaultServerCodecConfigurer extends DefaultCodecConfigurer implements ServerCodecConfigurer { + static final boolean synchronossMultipartPresent = + ClassUtils.isPresent("org.synchronoss.cloud.nio.multipart.NioMultipartParser", + org.springframework.http.codec.DefaultCodecConfigurer.class.getClassLoader()); + + public DefaultServerCodecConfigurer() { super(new DefaultServerDefaultCodecsConfigurer()); } @@ -57,6 +64,9 @@ class DefaultServerCodecConfigurer extends DefaultCodecConfigurer implements Ser public void addTypedReadersTo(List> result) { super.addTypedReadersTo(result); addReaderTo(result, FormHttpMessageReader::new); + if (synchronossMultipartPresent) { + addReaderTo(result, SynchronossMultipartHttpMessageReader::new); + } } @Override diff --git a/spring-web/src/test/java/org/springframework/http/codec/ClientCodecConfigurerTests.java b/spring-web/src/test/java/org/springframework/http/codec/ClientCodecConfigurerTests.java index dae0beb971..3bee998273 100644 --- a/spring-web/src/test/java/org/springframework/http/codec/ClientCodecConfigurerTests.java +++ b/spring-web/src/test/java/org/springframework/http/codec/ClientCodecConfigurerTests.java @@ -40,6 +40,7 @@ import org.springframework.core.io.buffer.DefaultDataBufferFactory; import org.springframework.http.MediaType; import org.springframework.http.codec.json.Jackson2JsonDecoder; import org.springframework.http.codec.json.Jackson2JsonEncoder; +import org.springframework.http.codec.multipart.MultipartHttpMessageWriter; import org.springframework.http.codec.xml.Jaxb2XmlDecoder; import org.springframework.http.codec.xml.Jaxb2XmlEncoder; import org.springframework.util.MimeTypeUtils; @@ -76,13 +77,14 @@ public class ClientCodecConfigurerTests { @Test public void defaultWriters() throws Exception { List> writers = this.configurer.getWriters(); - assertEquals(9, writers.size()); + assertEquals(10, writers.size()); assertEquals(ByteArrayEncoder.class, getNextEncoder(writers).getClass()); assertEquals(ByteBufferEncoder.class, getNextEncoder(writers).getClass()); assertEquals(DataBufferEncoder.class, getNextEncoder(writers).getClass()); assertEquals(ResourceHttpMessageWriter.class, writers.get(index.getAndIncrement()).getClass()); assertStringEncoder(getNextEncoder(writers), true); assertEquals(FormHttpMessageWriter.class, writers.get(this.index.getAndIncrement()).getClass()); + assertEquals(MultipartHttpMessageWriter.class, writers.get(this.index.getAndIncrement()).getClass()); assertEquals(Jaxb2XmlEncoder.class, getNextEncoder(writers).getClass()); assertEquals(Jackson2JsonEncoder.class, getNextEncoder(writers).getClass()); assertStringEncoder(getNextEncoder(writers), false); diff --git a/spring-web/src/test/java/org/springframework/http/codec/ServerCodecConfigurerTests.java b/spring-web/src/test/java/org/springframework/http/codec/ServerCodecConfigurerTests.java index b7ff23807c..7e4f212ceb 100644 --- a/spring-web/src/test/java/org/springframework/http/codec/ServerCodecConfigurerTests.java +++ b/spring-web/src/test/java/org/springframework/http/codec/ServerCodecConfigurerTests.java @@ -41,6 +41,7 @@ import org.springframework.core.io.buffer.DefaultDataBufferFactory; import org.springframework.http.MediaType; import org.springframework.http.codec.json.Jackson2JsonDecoder; import org.springframework.http.codec.json.Jackson2JsonEncoder; +import org.springframework.http.codec.multipart.SynchronossMultipartHttpMessageReader; import org.springframework.http.codec.xml.Jaxb2XmlDecoder; import org.springframework.http.codec.xml.Jaxb2XmlEncoder; import org.springframework.util.MimeTypeUtils; @@ -62,13 +63,14 @@ public class ServerCodecConfigurerTests { @Test public void defaultReaders() throws Exception { List> readers = this.configurer.getReaders(); - assertEquals(9, readers.size()); + assertEquals(10, readers.size()); assertEquals(ByteArrayDecoder.class, getNextDecoder(readers).getClass()); assertEquals(ByteBufferDecoder.class, getNextDecoder(readers).getClass()); assertEquals(DataBufferDecoder.class, getNextDecoder(readers).getClass()); assertEquals(ResourceDecoder.class, getNextDecoder(readers).getClass()); assertStringDecoder(getNextDecoder(readers), true); assertEquals(FormHttpMessageReader.class, readers.get(this.index.getAndIncrement()).getClass()); + assertEquals(SynchronossMultipartHttpMessageReader.class, readers.get(this.index.getAndIncrement()).getClass()); assertEquals(Jaxb2XmlDecoder.class, getNextDecoder(readers).getClass()); assertEquals(Jackson2JsonDecoder.class, getNextDecoder(readers).getClass()); assertStringDecoder(getNextDecoder(readers), false); diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/BodyExtractors.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/BodyExtractors.java index 9b1add05c8..e4bbc30781 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/function/BodyExtractors.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/BodyExtractors.java @@ -33,15 +33,19 @@ import org.springframework.http.HttpMessage; import org.springframework.http.MediaType; import org.springframework.http.ReactiveHttpInputMessage; import org.springframework.http.codec.HttpMessageReader; +import org.springframework.http.codec.multipart.Part; import org.springframework.http.server.reactive.ServerHttpRequest; import org.springframework.http.server.reactive.ServerHttpResponse; import org.springframework.util.Assert; import org.springframework.util.MultiValueMap; +import static org.springframework.http.codec.multipart.MultipartHttpMessageReader.*; + /** * Implementations of {@link BodyExtractor} that read various bodies, such a reactive streams. * * @author Arjen Poutsma + * @author Sebastien Deleuze * @since 5.0 */ public abstract class BodyExtractors { @@ -135,6 +139,23 @@ public abstract class BodyExtractors { }; } + /** + * Return a {@code BodyExtractor} that reads form data into a {@link MultiValueMap}. + * @return a {@code BodyExtractor} that reads multipart data + */ + // Note that the returned BodyExtractor is parameterized to ServerHttpRequest, not + // ReactiveHttpInputMessage like other methods, since reading form data only typically happens on + // the server-side + public static BodyExtractor>, ServerHttpRequest> toMultipartData() { + return (serverRequest, context) -> { + HttpMessageReader> messageReader = + multipartMessageReader(context); + return context.serverResponse() + .map(serverResponse -> messageReader.readMono(MULTIPART_VALUE_TYPE, MULTIPART_VALUE_TYPE, serverRequest, serverResponse, context.hints())) + .orElseGet(() -> messageReader.readMono(MULTIPART_VALUE_TYPE, serverRequest, context.hints())); + }; + } + /** * Return a {@code BodyExtractor} that returns the body of the message as a {@link Flux} of * {@link DataBuffer}s. @@ -180,6 +201,17 @@ public abstract class BodyExtractors { MediaType.APPLICATION_FORM_URLENCODED_VALUE)); } + private static HttpMessageReader> multipartMessageReader(BodyExtractor.Context context) { + return context.messageReaders().get() + .filter(messageReader -> messageReader + .canRead(MULTIPART_VALUE_TYPE, MediaType.MULTIPART_FORM_DATA)) + .findFirst() + .map(BodyExtractors::>cast) + .orElseThrow(() -> new IllegalStateException( + "Could not find HttpMessageReader that supports " + + MediaType.MULTIPART_FORM_DATA)); + } + private static MediaType contentType(HttpMessage message) { MediaType result = message.getHeaders().getContentType(); return result != null ? result : MediaType.APPLICATION_OCTET_STREAM; diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/BodyInserters.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/BodyInserters.java index cd672bb998..f637b4c681 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/function/BodyInserters.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/BodyInserters.java @@ -33,11 +33,15 @@ import org.springframework.http.ReactiveHttpOutputMessage; import org.springframework.http.client.reactive.ClientHttpRequest; import org.springframework.http.codec.HttpMessageWriter; import org.springframework.http.codec.ServerSentEvent; +import org.springframework.http.codec.multipart.MultipartHttpMessageReader; +import org.springframework.http.codec.multipart.Part; import org.springframework.http.server.reactive.ServerHttpRequest; import org.springframework.http.server.reactive.ServerHttpResponse; import org.springframework.util.Assert; import org.springframework.util.MultiValueMap; +import static org.springframework.http.codec.multipart.MultipartHttpMessageReader.*; + /** * Implementations of {@link BodyInserter} that write various bodies, such a reactive streams, * server-sent events, resources, etc. @@ -243,6 +247,27 @@ public abstract class BodyInserters { }; } + /** + * Return a {@code BodyInserter} that writes the given {@code MultiValueMap} as Multipart + * data. + * @param multipartData the form data to write to the output message + * @return a {@code BodyInserter} that writes form data + */ + // Note that the returned BodyInserter is parameterized to ClientHttpRequest, not + // ReactiveHttpOutputMessage like other methods, since sending form data only typically happens + // on the server-side + public static BodyInserter, ClientHttpRequest> fromMultipartData( + MultiValueMap multipartData) { + + Assert.notNull(multipartData, "'multipartData' must not be null"); + return (outputMessage, context) -> { + HttpMessageWriter> messageWriter = + findMessageWriter(context, MULTIPART_VALUE_TYPE, MediaType.MULTIPART_FORM_DATA); + return messageWriter.write(Mono.just(multipartData), FORM_TYPE, + MediaType.MULTIPART_FORM_DATA, outputMessage, context.hints()); + }; + } + /** * Return a {@code BodyInserter} that writes the given {@code Publisher} to the body. * @param publisher the data buffer publisher to write diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/config/DelegatingWebFluxConfigurationTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/config/DelegatingWebFluxConfigurationTests.java index 8b2ad176e8..eff558f6b7 100644 --- a/spring-webflux/src/test/java/org/springframework/web/reactive/config/DelegatingWebFluxConfigurationTests.java +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/config/DelegatingWebFluxConfigurationTests.java @@ -103,7 +103,7 @@ public class DelegatingWebFluxConfigurationTests { verify(webFluxConfigurer).configureArgumentResolvers(any()); assertSame(formatterRegistry.getValue(), initializerConversionService); - assertEquals(9, codecsConfigurer.getValue().getReaders().size()); + assertEquals(10, codecsConfigurer.getValue().getReaders().size()); } @Test diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/config/WebFluxConfigurationSupportTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/config/WebFluxConfigurationSupportTests.java index eeed6f7918..ae87f4f25e 100644 --- a/spring-webflux/src/test/java/org/springframework/web/reactive/config/WebFluxConfigurationSupportTests.java +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/config/WebFluxConfigurationSupportTests.java @@ -127,7 +127,7 @@ public class WebFluxConfigurationSupportTests { assertNotNull(adapter); List> readers = adapter.getMessageCodecConfigurer().getReaders(); - assertEquals(9, readers.size()); + assertEquals(10, readers.size()); assertHasMessageReader(readers, forClass(byte[].class), APPLICATION_OCTET_STREAM); assertHasMessageReader(readers, forClass(ByteBuffer.class), APPLICATION_OCTET_STREAM); diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/function/MultipartIntegrationTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/function/MultipartIntegrationTests.java new file mode 100644 index 0000000000..a42acfc052 --- /dev/null +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/function/MultipartIntegrationTests.java @@ -0,0 +1,104 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.function; + +import java.util.Map; + +import org.junit.Test; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import org.springframework.core.io.ClassPathResource; +import org.springframework.http.HttpEntity; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.codec.multipart.Part; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.web.reactive.function.client.ClientResponse; +import org.springframework.web.reactive.function.client.WebClient; +import org.springframework.web.reactive.function.server.AbstractRouterFunctionIntegrationTests; +import org.springframework.web.reactive.function.server.RouterFunction; +import org.springframework.web.reactive.function.server.ServerRequest; +import org.springframework.web.reactive.function.server.ServerResponse; + +import static org.junit.Assert.assertEquals; + +import static org.springframework.web.reactive.function.server.RequestPredicates.POST; +import static org.springframework.web.reactive.function.server.RouterFunctions.route; + +public class MultipartIntegrationTests extends AbstractRouterFunctionIntegrationTests { + + private final WebClient webClient = WebClient.create(); + + @Test + public void multipart() { + Mono result = webClient + .post() + .uri("http://localhost:" + this.port + "/") + .contentType(MediaType.MULTIPART_FORM_DATA) + .body(BodyInserters.fromMultipartData(generateBody())) + .exchange(); + + StepVerifier + .create(result) + .consumeNextWith(response -> { + assertEquals(HttpStatus.OK, response.statusCode()); + }) + .verifyComplete(); + } + + private MultiValueMap generateBody() { + HttpHeaders fooHeaders = new HttpHeaders(); + fooHeaders.setContentType(MediaType.TEXT_PLAIN); + ClassPathResource fooResource = new ClassPathResource("org/springframework/http/codec/multipart/foo.txt"); + HttpEntity fooPart = new HttpEntity<>(fooResource, fooHeaders); + HttpEntity barPart = new HttpEntity<>("bar"); + MultiValueMap parts = new LinkedMultiValueMap<>(); + parts.add("fooPart", fooPart); + parts.add("barPart", barPart); + return parts; + } + + @Override + protected RouterFunction routerFunction() { + MultipartHandler multipartHandler = new MultipartHandler(); + return route(POST("/"), multipartHandler::handle); + } + + private static class MultipartHandler { + + public Mono handle(ServerRequest request) { + return request + .body(BodyExtractors.toMultipartData()) + .flatMap(map -> { + Map parts = map.toSingleValueMap(); + try { + assertEquals(2, parts.size()); + assertEquals("foo.txt", parts.get("fooPart").getFilename().get()); + assertEquals("bar", parts.get("barPart").getContentAsString().block()); + } + catch(Exception e) { + return Mono.error(e); + } + return ServerResponse.ok().build(); + }); + } + } + +}