diff --git a/spring-web/src/main/java/org/springframework/web/server/ServerWebExchange.java b/spring-web/src/main/java/org/springframework/web/server/ServerWebExchange.java index ebd06b5a71e..9b026b30a65 100644 --- a/spring-web/src/main/java/org/springframework/web/server/ServerWebExchange.java +++ b/spring-web/src/main/java/org/springframework/web/server/ServerWebExchange.java @@ -24,6 +24,7 @@ import java.util.function.Consumer; import reactor.core.publisher.Mono; +import org.springframework.http.codec.multipart.Part; import org.springframework.http.server.reactive.ServerHttpRequest; import org.springframework.http.server.reactive.ServerHttpResponse; import org.springframework.util.MultiValueMap; @@ -82,6 +83,12 @@ public interface ServerWebExchange { */ Mono> getFormData(); + /** + * Return the form parts from the body of the request or an empty {@code Mono} + * if the Content-Type is not "multipart/form-data". + */ + Mono> getMultipartData(); + /** * Return a combined map that represents both * {@link ServerHttpRequest#getQueryParams()} and {@link #getFormData()} diff --git a/spring-web/src/main/java/org/springframework/web/server/ServerWebExchangeDecorator.java b/spring-web/src/main/java/org/springframework/web/server/ServerWebExchangeDecorator.java index 5e2e32322a5..78a774c272c 100644 --- a/spring-web/src/main/java/org/springframework/web/server/ServerWebExchangeDecorator.java +++ b/spring-web/src/main/java/org/springframework/web/server/ServerWebExchangeDecorator.java @@ -22,6 +22,7 @@ import java.util.Optional; import reactor.core.publisher.Mono; +import org.springframework.http.codec.multipart.Part; import org.springframework.http.server.reactive.ServerHttpRequest; import org.springframework.http.server.reactive.ServerHttpResponse; import org.springframework.util.Assert; @@ -93,6 +94,11 @@ public class ServerWebExchangeDecorator implements ServerWebExchange { return getDelegate().getFormData(); } + @Override + public Mono> getMultipartData() { + return getDelegate().getMultipartData(); + } + @Override public Mono> getRequestParams() { return getDelegate().getRequestParams(); diff --git a/spring-web/src/main/java/org/springframework/web/server/adapter/DefaultServerWebExchange.java b/spring-web/src/main/java/org/springframework/web/server/adapter/DefaultServerWebExchange.java index d96c248f013..46cd0da88fb 100644 --- a/spring-web/src/main/java/org/springframework/web/server/adapter/DefaultServerWebExchange.java +++ b/spring-web/src/main/java/org/springframework/web/server/adapter/DefaultServerWebExchange.java @@ -36,6 +36,7 @@ import org.springframework.http.InvalidMediaTypeException; import org.springframework.http.MediaType; import org.springframework.http.codec.HttpMessageReader; import org.springframework.http.codec.ServerCodecConfigurer; +import org.springframework.http.codec.multipart.Part; import org.springframework.http.server.reactive.ServerHttpRequest; import org.springframework.http.server.reactive.ServerHttpResponse; import org.springframework.util.Assert; @@ -48,6 +49,7 @@ import org.springframework.web.server.WebSession; import org.springframework.web.server.session.WebSessionManager; import static org.springframework.http.MediaType.*; +import static org.springframework.http.codec.multipart.MultipartHttpMessageReader.*; /** * Default implementation of {@link ServerWebExchange}. @@ -66,6 +68,10 @@ public class DefaultServerWebExchange implements ServerWebExchange { Mono.just(CollectionUtils.unmodifiableMultiValueMap(new LinkedMultiValueMap(0))) .cache(); + private static final Mono> EMPTY_MULTIPART_DATA = + Mono.just(CollectionUtils.unmodifiableMultiValueMap(new LinkedMultiValueMap(0))) + .cache(); + private final ServerHttpRequest request; @@ -77,6 +83,8 @@ public class DefaultServerWebExchange implements ServerWebExchange { private final Mono> formDataMono; + private final Mono> multipartDataMono; + private final Mono> requestParamsMono; private volatile boolean notModified; @@ -97,6 +105,7 @@ public class DefaultServerWebExchange implements ServerWebExchange { this.response = response; this.sessionMono = sessionManager.getSession(this).cache(); this.formDataMono = initFormData(request, codecConfigurer); + this.multipartDataMono = initMultipartData(request, codecConfigurer); this.requestParamsMono = initRequestParams(request, this.formDataMono); } @@ -126,6 +135,31 @@ public class DefaultServerWebExchange implements ServerWebExchange { return EMPTY_FORM_DATA; } + @SuppressWarnings("unchecked") + private static Mono> initMultipartData( + ServerHttpRequest request, ServerCodecConfigurer codecConfigurer) { + + MediaType contentType; + try { + contentType = request.getHeaders().getContentType(); + if (MULTIPART_FORM_DATA.isCompatibleWith(contentType)) { + return ((HttpMessageReader>)codecConfigurer + .getReaders() + .stream() + .filter(messageReader -> messageReader.canRead(MULTIPART_VALUE_TYPE, MULTIPART_FORM_DATA)) + .findFirst() + .orElseThrow(() -> new IllegalStateException("Could not find HttpMessageReader that supports " + MULTIPART_FORM_DATA))) + .readMono(FORM_DATA_VALUE_TYPE, request, Collections.emptyMap()) + .switchIfEmpty(EMPTY_MULTIPART_DATA) + .cache(); + } + } + catch (InvalidMediaTypeException ex) { + // Ignore + } + return EMPTY_MULTIPART_DATA; + } + private static Mono> initRequestParams( ServerHttpRequest request, Mono> formDataMono) { @@ -184,6 +218,11 @@ public class DefaultServerWebExchange implements ServerWebExchange { return this.formDataMono; } + @Override + public Mono> getMultipartData() { + return this.multipartDataMono; + } + @Override public Mono> getRequestParams() { return this.requestParamsMono; diff --git a/spring-web/src/test/java/org/springframework/http/server/reactive/MultipartIntegrationTests.java b/spring-web/src/test/java/org/springframework/http/server/reactive/MultipartIntegrationTests.java new file mode 100644 index 00000000000..d7f8bb80d37 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/server/reactive/MultipartIntegrationTests.java @@ -0,0 +1,120 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server.reactive; + +import java.net.URI; +import java.util.Optional; + +import static org.junit.Assert.*; +import org.junit.Test; +import reactor.core.publisher.Mono; + +import org.springframework.core.io.ClassPathResource; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.http.HttpEntity; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.RequestEntity; +import org.springframework.http.ResponseEntity; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.web.client.RestTemplate; +import org.springframework.http.codec.multipart.Part; +import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.server.WebHandler; +import org.springframework.web.server.adapter.HttpWebHandlerAdapter; + +public class MultipartIntegrationTests extends AbstractHttpHandlerIntegrationTests { + + @Override + protected HttpHandler createHttpHandler() { + HttpWebHandlerAdapter handler = new HttpWebHandlerAdapter(new CheckRequestHandler()); + return handler; + } + + @Test + public void getFormParts() throws Exception { + RestTemplate restTemplate = new RestTemplate(); + RequestEntity> request = RequestEntity + .post(new URI("http://localhost:" + port + "/form-parts")) + .contentType(MediaType.MULTIPART_FORM_DATA) + .body(generateBody()); + ResponseEntity response = restTemplate.exchange(request, Void.class); + assertEquals(HttpStatus.OK, response.getStatusCode()); + } + + private MultiValueMap generateBody() { + HttpHeaders fooHeaders = new HttpHeaders(); + fooHeaders.setContentType(MediaType.TEXT_PLAIN); + ClassPathResource fooResource = new ClassPathResource("org/springframework/http/codec/multipart/foo.txt"); + HttpEntity fooPart = new HttpEntity<>(fooResource, fooHeaders); + HttpEntity barPart = new HttpEntity<>("bar"); + MultiValueMap parts = new LinkedMultiValueMap<>(); + parts.add("fooPart", fooPart); + parts.add("barPart", barPart); + return parts; + } + + public static class CheckRequestHandler implements WebHandler { + + @Override + public Mono handle(ServerWebExchange exchange) { + + if (exchange.getRequest().getURI().getPath().equals("/form-parts")) { + return assertGetFormParts(exchange); + } + return Mono.error(new AssertionError()); + } + + private Mono assertGetFormParts(ServerWebExchange exchange) { + return exchange + .getMultipartData() + .doOnNext(parts -> { + assertEquals(2, parts.size()); + assertTrue(parts.containsKey("fooPart")); + assertFooPart(parts.getFirst("fooPart")); + assertTrue(parts.containsKey("barPart")); + assertBarPart(parts.getFirst("barPart")); + }) + .then(); + } + + private void assertFooPart(Part part) { + assertEquals("fooPart", part.getName()); + Optional filename = part.getFilename(); + assertTrue(filename.isPresent()); + assertEquals("foo.txt", filename.get()); + DataBuffer buffer = part + .getContent() + .reduce((s1, s2) -> s1.write(s2)) + .block(); + assertEquals(12, buffer.readableByteCount()); + byte[] byteContent = new byte[12]; + buffer.read(byteContent); + assertEquals("Lorem\nIpsum\n", new String(byteContent)); + } + + private void assertBarPart(Part part) { + assertEquals("barPart", part.getName()); + Optional filename = part.getFilename(); + assertFalse(filename.isPresent()); + assertEquals("bar", part.getContentAsString().block()); + } + } + +} \ No newline at end of file