Add @Request(Param/Part) support for multipart requests

Issue: SPR-14546
This commit is contained in:
Sebastien Deleuze 2017-04-28 16:12:15 +02:00
parent 4bfd04b3c5
commit f2caaa9195
7 changed files with 316 additions and 4 deletions

View File

@ -262,6 +262,7 @@ public abstract class AbstractHandlerMethodMapping<T> extends AbstractHandlerMap
try {
// Ensure form data is parsed for "params" conditions...
return exchange.getRequestParams()
.then(exchange.getMultipartData())
.then(Mono.defer(() -> {
HandlerMethod handlerMethod = null;
try {

View File

@ -133,6 +133,7 @@ class ControllerMethodResolver {
// Annotation-based...
registrar.add(new RequestParamMethodArgumentResolver(beanFactory, reactiveRegistry, false));
registrar.add(new RequestPartMethodArgumentResolver(beanFactory, reactiveRegistry, false));
registrar.add(new RequestParamMapMethodArgumentResolver(reactiveRegistry));
registrar.add(new PathVariableMethodArgumentResolver(beanFactory, reactiveRegistry));
registrar.add(new PathVariableMapMethodArgumentResolver(reactiveRegistry));

View File

@ -21,6 +21,8 @@ import java.util.Optional;
import org.springframework.core.MethodParameter;
import org.springframework.core.ReactiveAdapterRegistry;
import org.springframework.core.ResolvableType;
import org.springframework.http.codec.multipart.Part;
import org.springframework.util.Assert;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils;
@ -42,6 +44,7 @@ import org.springframework.web.server.ServerWebExchange;
* request parameters have multiple values.
*
* @author Rossen Stoyanchev
* @author Sebastien Deleuze
* @since 5.0
* @see RequestParamMethodArgumentResolver
*/
@ -67,12 +70,17 @@ public class RequestParamMapMethodArgumentResolver extends HandlerMethodArgument
public Optional<Object> resolveArgumentValue(MethodParameter methodParameter,
BindingContext context, ServerWebExchange exchange) {
Class<?> paramType = methodParameter.getParameterType();
boolean isMultiValueMap = MultiValueMap.class.isAssignableFrom(paramType);
ResolvableType paramType = ResolvableType.forType(methodParameter.getGenericParameterType());
boolean isMultiValueMap = MultiValueMap.class.isAssignableFrom(paramType.getRawClass());
if (paramType.getGeneric(1).getRawClass() == Part.class) {
MultiValueMap<String, Part> requestParts = exchange.getMultipartData().subscribe().peek();
Assert.notNull(requestParts, "Expected multipart data (if any) to be parsed.");
return Optional.of(isMultiValueMap ? requestParts : requestParts.toSingleValueMap());
}
MultiValueMap<String, String> requestParams = exchange.getRequestParams().subscribe().peek();
Assert.notNull(requestParams, "Expected form data (if any) to be parsed.");
return Optional.of(isMultiValueMap ? requestParams : requestParams.toSingleValueMap());
}

View File

@ -25,6 +25,7 @@ import org.springframework.beans.factory.config.ConfigurableBeanFactory;
import org.springframework.core.MethodParameter;
import org.springframework.core.ReactiveAdapterRegistry;
import org.springframework.core.convert.converter.Converter;
import org.springframework.http.codec.multipart.Part;
import org.springframework.util.Assert;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils;
@ -102,7 +103,7 @@ public class RequestParamMethodArgumentResolver extends AbstractNamedValueSyncAr
protected Optional<Object> resolveNamedValue(String name, MethodParameter parameter,
ServerWebExchange exchange) {
List<String> paramValues = getRequestParams(exchange).get(name);
List<?> paramValues = parameter.getParameter().getType() == Part.class ? getMultipartData(exchange).get(name) : getRequestParams(exchange).get(name);
Object result = null;
if (paramValues != null) {
result = (paramValues.size() == 1 ? paramValues.get(0) : paramValues);
@ -116,6 +117,12 @@ public class RequestParamMethodArgumentResolver extends AbstractNamedValueSyncAr
return params;
}
private MultiValueMap<String, Part> getMultipartData(ServerWebExchange exchange) {
MultiValueMap<String, Part> params = exchange.getMultipartData().subscribe().peek();
Assert.notNull(params, "Expected multipart data (if any) to be parsed.");
return params;
}
@Override
protected void handleMissingValue(String name, MethodParameter parameter, ServerWebExchange exchange) {
String type = parameter.getNestedParameterType().getSimpleName();

View File

@ -0,0 +1,128 @@
/*
* Copyright 2002-2017 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.reactive.result.method.annotation;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.config.ConfigurableBeanFactory;
import org.springframework.core.MethodParameter;
import org.springframework.core.ReactiveAdapterRegistry;
import org.springframework.core.convert.converter.Converter;
import org.springframework.http.codec.multipart.Part;
import org.springframework.util.Assert;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RequestPart;
import org.springframework.web.bind.annotation.ValueConstants;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.ServerWebInputException;
/**
* Resolver for method arguments annotated with @{@link RequestPart}.
*
* @author Sebastien Deleuze
* @since 5.0
* @see RequestParamMapMethodArgumentResolver
*/
public class RequestPartMethodArgumentResolver extends AbstractNamedValueSyncArgumentResolver {
private final boolean useDefaultResolution;
/**
* Class constructor with a default resolution mode flag.
* @param factory a bean factory used for resolving ${...} placeholder
* and #{...} SpEL expressions in default values, or {@code null} if default
* values are not expected to contain expressions
* @param registry for checking reactive type wrappers
* @param useDefaultResolution in default resolution mode a method argument
* that is a simple type, as defined in {@link BeanUtils#isSimpleProperty},
* is treated as a request parameter even if it isn't annotated, the
* request parameter name is derived from the method parameter name.
*/
public RequestPartMethodArgumentResolver(
ConfigurableBeanFactory factory, ReactiveAdapterRegistry registry, boolean useDefaultResolution) {
super(factory, registry);
this.useDefaultResolution = useDefaultResolution;
}
@Override
public boolean supportsParameter(MethodParameter param) {
if (checkAnnotatedParamNoReactiveWrapper(param, RequestPart.class, this::singleParam)) {
return true;
}
else if (this.useDefaultResolution) {
return checkParameterTypeNoReactiveWrapper(param, BeanUtils::isSimpleProperty) ||
BeanUtils.isSimpleProperty(param.nestedIfOptional().getNestedParameterType());
}
return false;
}
private boolean singleParam(RequestPart requestParam, Class<?> type) {
return !Map.class.isAssignableFrom(type) || StringUtils.hasText(requestParam.name());
}
@Override
protected NamedValueInfo createNamedValueInfo(MethodParameter parameter) {
RequestPart ann = parameter.getParameterAnnotation(RequestPart.class);
return (ann != null ? new RequestPartNamedValueInfo(ann) : new RequestPartNamedValueInfo());
}
@Override
protected Optional<Object> resolveNamedValue(String name, MethodParameter parameter,
ServerWebExchange exchange) {
List<?> paramValues = getMultipartData(exchange).get(name);
Object result = null;
if (paramValues != null) {
result = (paramValues.size() == 1 ? paramValues.get(0) : paramValues);
}
return Optional.ofNullable(result);
}
private MultiValueMap<String, Part> getMultipartData(ServerWebExchange exchange) {
MultiValueMap<String, Part> params = exchange.getMultipartData().subscribe().peek();
Assert.notNull(params, "Expected multipart data (if any) to be parsed.");
return params;
}
@Override
protected void handleMissingValue(String name, MethodParameter parameter, ServerWebExchange exchange) {
String type = parameter.getNestedParameterType().getSimpleName();
String reason = "Required " + type + " parameter '" + name + "' is not present";
throw new ServerWebInputException(reason, parameter);
}
private static class RequestPartNamedValueInfo extends NamedValueInfo {
RequestPartNamedValueInfo() {
super("", false, ValueConstants.DEFAULT_NONE);
}
RequestPartNamedValueInfo(RequestPart annotation) {
super(annotation.name(), annotation.required(), ValueConstants.DEFAULT_NONE);
}
}
}

View File

@ -91,6 +91,7 @@ public class ControllerMethodResolverTests {
AtomicInteger index = new AtomicInteger(-1);
assertEquals(RequestParamMethodArgumentResolver.class, next(resolvers, index).getClass());
assertEquals(RequestPartMethodArgumentResolver.class, next(resolvers, index).getClass());
assertEquals(RequestParamMapMethodArgumentResolver.class, next(resolvers, index).getClass());
assertEquals(PathVariableMethodArgumentResolver.class, next(resolvers, index).getClass());
assertEquals(PathVariableMapMethodArgumentResolver.class, next(resolvers, index).getClass());
@ -129,6 +130,7 @@ public class ControllerMethodResolverTests {
AtomicInteger index = new AtomicInteger(-1);
assertEquals(RequestParamMethodArgumentResolver.class, next(resolvers, index).getClass());
assertEquals(RequestPartMethodArgumentResolver.class, next(resolvers, index).getClass());
assertEquals(RequestParamMapMethodArgumentResolver.class, next(resolvers, index).getClass());
assertEquals(PathVariableMethodArgumentResolver.class, next(resolvers, index).getClass());
assertEquals(PathVariableMapMethodArgumentResolver.class, next(resolvers, index).getClass());
@ -165,6 +167,7 @@ public class ControllerMethodResolverTests {
AtomicInteger index = new AtomicInteger(-1);
assertEquals(RequestParamMethodArgumentResolver.class, next(resolvers, index).getClass());
assertEquals(RequestPartMethodArgumentResolver.class, next(resolvers, index).getClass());
assertEquals(RequestParamMapMethodArgumentResolver.class, next(resolvers, index).getClass());
assertEquals(PathVariableMethodArgumentResolver.class, next(resolvers, index).getClass());
assertEquals(PathVariableMapMethodArgumentResolver.class, next(resolvers, index).getClass());
@ -195,6 +198,7 @@ public class ControllerMethodResolverTests {
AtomicInteger index = new AtomicInteger(-1);
assertEquals(RequestParamMethodArgumentResolver.class, next(resolvers, index).getClass());
assertEquals(RequestPartMethodArgumentResolver.class, next(resolvers, index).getClass());
assertEquals(RequestParamMapMethodArgumentResolver.class, next(resolvers, index).getClass());
assertEquals(PathVariableMethodArgumentResolver.class, next(resolvers, index).getClass());
assertEquals(PathVariableMapMethodArgumentResolver.class, next(resolvers, index).getClass());

View File

@ -0,0 +1,163 @@
/*
* Copyright 2002-2017 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.reactive.result.method.annotation;
import java.util.Map;
import org.junit.Before;
import org.junit.Test;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;
import org.springframework.context.annotation.AnnotationConfigApplicationContext;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.io.ClassPathResource;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.codec.multipart.Part;
import org.springframework.http.server.reactive.AbstractHttpHandlerIntegrationTests;
import org.springframework.http.server.reactive.HttpHandler;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RequestPart;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.reactive.DispatcherHandler;
import org.springframework.web.reactive.config.EnableWebFlux;
import org.springframework.web.reactive.function.BodyInserters;
import org.springframework.web.reactive.function.client.ClientResponse;
import org.springframework.web.reactive.function.client.WebClient;
import org.springframework.web.server.adapter.WebHttpHandlerBuilder;
import static org.junit.Assert.assertEquals;
public class MultipartIntegrationTests extends AbstractHttpHandlerIntegrationTests {
private AnnotationConfigApplicationContext wac;
private WebClient webClient;
@Override
@Before
public void setup() throws Exception {
super.setup();
this.webClient = WebClient.create("http://localhost:" + this.port);
}
@Override
protected HttpHandler createHttpHandler() {
this.wac = new AnnotationConfigApplicationContext();
this.wac.register(TestConfiguration.class);
this.wac.refresh();
return WebHttpHandlerBuilder.webHandler(new DispatcherHandler(this.wac)).build();
}
@Test
public void map() {
test("/map");
}
@Test
public void multiValueMap() {
test("/multivaluemap");
}
@Test
public void partParam() {
test("/partparam");
}
@Test
public void part() {
test("/part");
}
private void test(String uri) {
Mono<ClientResponse> result = webClient
.post()
.uri(uri)
.contentType(MediaType.MULTIPART_FORM_DATA)
.body(BodyInserters.fromMultipartData(generateBody()))
.exchange();
StepVerifier
.create(result)
.consumeNextWith(response -> assertEquals(HttpStatus.OK, response.statusCode()))
.verifyComplete();
}
private MultiValueMap<String, Object> generateBody() {
HttpHeaders fooHeaders = new HttpHeaders();
fooHeaders.setContentType(MediaType.TEXT_PLAIN);
ClassPathResource fooResource = new ClassPathResource("org/springframework/http/codec/multipart/foo.txt");
HttpEntity<ClassPathResource> fooPart = new HttpEntity<>(fooResource, fooHeaders);
HttpEntity<String> barPart = new HttpEntity<>("bar");
MultiValueMap<String, Object> parts = new LinkedMultiValueMap<>();
parts.add("fooPart", fooPart);
parts.add("barPart", barPart);
return parts;
}
@RestController
@SuppressWarnings("unused")
static class MultipartController {
@PostMapping("/map")
void map(@RequestParam Map<String, Part> parts) {
assertEquals(2, parts.size());
assertEquals("foo.txt", parts.get("fooPart").getFilename().get());
assertEquals("bar", parts.get("barPart").getContentAsString().block());
}
@PostMapping("/multivaluemap")
void multiValueMap(@RequestParam MultiValueMap<String, Part> parts) {
Map<String, Part> map = parts.toSingleValueMap();
assertEquals(2, map.size());
assertEquals("foo.txt", map.get("fooPart").getFilename().get());
assertEquals("bar", map.get("barPart").getContentAsString().block());
}
@PostMapping("/partparam")
void partParam(@RequestParam Part fooPart) {
assertEquals("foo.txt", fooPart.getFilename().get());
}
@PostMapping("/part")
void part(@RequestPart Part fooPart) {
assertEquals("foo.txt", fooPart.getFilename().get());
}
}
@Configuration
@EnableWebFlux
@SuppressWarnings("unused")
static class TestConfiguration {
@Bean
public MultipartController multipartController() {
return new MultipartController();
}
}
}