@RequestPart supports java.util.Optional

Issue: SPR-12644
This commit is contained in:
Juergen Hoeller 2015-02-18 16:17:07 +01:00
parent 61cc3b5bff
commit 6ebac00f32
4 changed files with 212 additions and 65 deletions

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2014 the original author or authors.
* Copyright 2002-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -309,6 +309,7 @@ public class MethodParameter {
/**
* Return the generic type of the method/constructor parameter.
* @return the parameter type (never {@code null})
* @since 3.0
*/
public Type getGenericParameterType() {
if (this.genericParameterType == null) {
@ -324,6 +325,12 @@ public class MethodParameter {
return this.genericParameterType;
}
/**
* Return the nested type of the method/constructor parameter.
* @return the parameter type (never {@code null})
* @see #getNestingLevel()
* @since 3.1
*/
public Class<?> getNestedParameterType() {
if (this.nestingLevel > 1) {
Type type = getGenericParameterType();
@ -350,6 +357,29 @@ public class MethodParameter {
}
}
/**
* Return the nested generic type of the method/constructor parameter.
* @return the parameter type (never {@code null})
* @see #getNestingLevel()
* @since 4.2
*/
public Type getNestedGenericParameterType() {
if (this.nestingLevel > 1) {
Type type = getGenericParameterType();
for (int i = 2; i <= this.nestingLevel; i++) {
if (type instanceof ParameterizedType) {
Type[] args = ((ParameterizedType) type).getActualTypeArguments();
Integer index = getTypeIndexForLevel(i);
type = args[index != null ? index : args.length - 1];
}
}
return type;
}
else {
return getGenericParameterType();
}
}
/**
* Return the annotations associated with the target method/constructor itself.
*/

View File

@ -104,7 +104,7 @@ public abstract class AbstractMessageConverterMethodArgumentResolver implements
* from the given HttpInputMessage.
* @param <T> the expected type of the argument value to be created
* @param inputMessage the HTTP input message representing the current request
* @param methodParam the method argument
* @param methodParam the method parameter descriptor (may be {@code null})
* @param targetType the type of object to create, not necessarily the same as
* the method parameter type (e.g. for {@code HttpEntity<String>} method
* parameter the target type is String)
@ -113,8 +113,8 @@ public abstract class AbstractMessageConverterMethodArgumentResolver implements
* @throws HttpMediaTypeNotSupportedException if no suitable message converter is found
*/
@SuppressWarnings("unchecked")
protected <T> Object readWithMessageConverters(HttpInputMessage inputMessage,
MethodParameter methodParam, Type targetType) throws IOException, HttpMediaTypeNotSupportedException {
protected <T> Object readWithMessageConverters(HttpInputMessage inputMessage, MethodParameter methodParam,
Type targetType) throws IOException, HttpMediaTypeNotSupportedException {
MediaType contentType;
try {
@ -127,7 +127,13 @@ public abstract class AbstractMessageConverterMethodArgumentResolver implements
contentType = MediaType.APPLICATION_OCTET_STREAM;
}
Class<?> contextClass = methodParam.getContainingClass();
Class<?> contextClass = (methodParam != null ? methodParam.getContainingClass() : null);
Class<T> targetClass = (targetType instanceof Class<?> ? (Class<T>) targetType : null);
if (targetClass == null) {
ResolvableType resolvableType = (methodParam != null ?
ResolvableType.forMethodParameter(methodParam) : ResolvableType.forType(targetType));
targetClass = (Class<T>) resolvableType.resolve();
}
for (HttpMessageConverter<?> converter : this.messageConverters) {
if (converter instanceof GenericHttpMessageConverter) {
@ -140,14 +146,14 @@ public abstract class AbstractMessageConverterMethodArgumentResolver implements
return genericConverter.read(targetType, contextClass, inputMessage);
}
}
Class<T> targetClass = (Class<T>)
ResolvableType.forMethodParameter(methodParam, targetType).resolve(Object.class);
if (converter.canRead(targetClass, contentType)) {
if (logger.isDebugEnabled()) {
logger.debug("Reading [" + targetClass.getName() + "] as \"" +
contentType + "\" using [" + converter + "]");
else if (targetClass != null) {
if (converter.canRead(targetClass, contentType)) {
if (logger.isDebugEnabled()) {
logger.debug("Reading [" + targetClass.getName() + "] as \"" +
contentType + "\" using [" + converter + "]");
}
return ((HttpMessageConverter<T>) converter).read(targetClass, inputMessage);
}
return ((HttpMessageConverter<T>) converter).read(targetClass, inputMessage);
}
}

View File

@ -20,6 +20,7 @@ import java.lang.annotation.Annotation;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Optional;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.Part;
@ -80,12 +81,13 @@ public class RequestPartMethodArgumentResolver extends AbstractMessageConverterM
super(messageConverters);
}
/**
* Supports the following:
* <ul>
* <li>Annotated with {@code @RequestPart}
* <li>Of type {@link MultipartFile} unless annotated with {@code @RequestParam}.
* <li>Of type {@code javax.servlet.http.Part} unless annotated with {@code @RequestParam}.
* <li>Annotated with {@code @RequestPart}
* <li>Of type {@link MultipartFile} unless annotated with {@code @RequestParam}.
* <li>Of type {@code javax.servlet.http.Part} unless annotated with {@code @RequestParam}.
* </ul>
*/
@Override
@ -117,12 +119,19 @@ public class RequestPartMethodArgumentResolver extends AbstractMessageConverterM
assertIsMultipartRequest(servletRequest);
MultipartHttpServletRequest multipartRequest =
WebUtils.getNativeRequest(servletRequest, MultipartHttpServletRequest.class);
WebUtils.getNativeRequest(servletRequest, MultipartHttpServletRequest.class);
Class<?> paramType = parameter.getParameterType();
boolean optional = paramType.getName().equals("java.util.Optional");
if (optional) {
parameter.increaseNestingLevel();
paramType = parameter.getNestedParameterType();
}
String partName = getPartName(parameter);
Object arg;
if (MultipartFile.class.equals(parameter.getParameterType())) {
if (MultipartFile.class.equals(paramType)) {
Assert.notNull(multipartRequest, "Expected MultipartHttpServletRequest: is a MultipartResolver configured?");
arg = multipartRequest.getFile(partName);
}
@ -135,7 +144,7 @@ public class RequestPartMethodArgumentResolver extends AbstractMessageConverterM
List<MultipartFile> files = multipartRequest.getFiles(partName);
arg = files.toArray(new MultipartFile[files.size()]);
}
else if ("javax.servlet.http.Part".equals(parameter.getParameterType().getName())) {
else if ("javax.servlet.http.Part".equals(paramType.getName())) {
assertIsMultipartRequest(servletRequest);
arg = servletRequest.getPart(partName);
}
@ -150,7 +159,7 @@ public class RequestPartMethodArgumentResolver extends AbstractMessageConverterM
else {
try {
HttpInputMessage inputMessage = new RequestPartServletServerHttpRequest(servletRequest, partName);
arg = readWithMessageConverters(inputMessage, parameter, parameter.getParameterType());
arg = readWithMessageConverters(inputMessage, parameter, parameter.getNestedGenericParameterType());
WebDataBinder binder = binderFactory.createBinder(request, arg, partName);
if (arg != null) {
validate(binder, parameter);
@ -164,11 +173,14 @@ public class RequestPartMethodArgumentResolver extends AbstractMessageConverterM
}
RequestPart annot = parameter.getParameterAnnotation(RequestPart.class);
boolean isRequired = (annot == null || annot.required());
boolean isRequired = ((annot == null || annot.required()) && !optional);
if (arg == null && isRequired) {
throw new MissingServletRequestPartException(partName);
}
if (optional) {
arg = Optional.ofNullable(arg);
}
return arg;
}
@ -180,41 +192,44 @@ public class RequestPartMethodArgumentResolver extends AbstractMessageConverterM
}
}
private String getPartName(MethodParameter parameter) {
RequestPart annot = parameter.getParameterAnnotation(RequestPart.class);
private String getPartName(MethodParameter param) {
RequestPart annot = param.getParameterAnnotation(RequestPart.class);
String partName = (annot != null ? annot.value() : "");
if (partName.length() == 0) {
partName = parameter.getParameterName();
Assert.notNull(partName, "Request part name for argument type [" + parameter.getParameterType().getName() +
"] not specified, and parameter name information not found in class file either.");
partName = param.getParameterName();
if (partName == null) {
throw new IllegalArgumentException("Request part name for argument type [" +
param.getNestedParameterType().getName() +
"] not specified, and parameter name information not found in class file either.");
}
}
return partName;
}
private boolean isMultipartFileCollection(MethodParameter parameter) {
Class<?> collectionType = getCollectionParameterType(parameter);
private boolean isMultipartFileCollection(MethodParameter param) {
Class<?> collectionType = getCollectionParameterType(param);
return (collectionType != null && collectionType.equals(MultipartFile.class));
}
private boolean isMultipartFileArray(MethodParameter parameter) {
Class<?> paramType = parameter.getParameterType().getComponentType();
private boolean isMultipartFileArray(MethodParameter param) {
Class<?> paramType = param.getNestedParameterType().getComponentType();
return (paramType != null && MultipartFile.class.equals(paramType));
}
private boolean isPartCollection(MethodParameter parameter) {
Class<?> collectionType = getCollectionParameterType(parameter);
private boolean isPartCollection(MethodParameter param) {
Class<?> collectionType = getCollectionParameterType(param);
return (collectionType != null && "javax.servlet.http.Part".equals(collectionType.getName()));
}
private boolean isPartArray(MethodParameter parameter) {
Class<?> paramType = parameter.getParameterType().getComponentType();
private boolean isPartArray(MethodParameter param) {
Class<?> paramType = param.getNestedParameterType().getComponentType();
return (paramType != null && "javax.servlet.http.Part".equals(paramType.getName()));
}
private Class<?> getCollectionParameterType(MethodParameter parameter) {
Class<?> paramType = parameter.getParameterType();
private Class<?> getCollectionParameterType(MethodParameter param) {
Class<?> paramType = param.getNestedParameterType();
if (Collection.class.equals(paramType) || List.class.isAssignableFrom(paramType)){
Class<?> valueType = GenericCollectionTypeResolver.getCollectionParameterType(parameter);
Class<?> valueType = GenericCollectionTypeResolver.getCollectionParameterType(param);
if (valueType != null) {
return valueType;
}
@ -228,13 +243,13 @@ public class RequestPartMethodArgumentResolver extends AbstractMessageConverterM
* Spring's {@link org.springframework.validation.annotation.Validated},
* and custom annotations whose name starts with "Valid".
* @param binder the DataBinder to be used
* @param parameter the method parameter
* @param param the method parameter
* @throws MethodArgumentNotValidException in case of a binding error which
* is meant to be fatal (i.e. without a declared {@link Errors} parameter)
* @see #isBindingErrorFatal
*/
protected void validate(WebDataBinder binder, MethodParameter parameter) throws MethodArgumentNotValidException {
Annotation[] annotations = parameter.getParameterAnnotations();
protected void validate(WebDataBinder binder, MethodParameter param) throws MethodArgumentNotValidException {
Annotation[] annotations = param.getParameterAnnotations();
for (Annotation ann : annotations) {
Validated validatedAnn = AnnotationUtils.getAnnotation(ann, Validated.class);
if (validatedAnn != null || ann.annotationType().getSimpleName().startsWith("Valid")) {
@ -243,8 +258,8 @@ public class RequestPartMethodArgumentResolver extends AbstractMessageConverterM
binder.validate(validationHints);
BindingResult bindingResult = binder.getBindingResult();
if (bindingResult.hasErrors()) {
if (isBindingErrorFatal(parameter)) {
throw new MethodArgumentNotValidException(parameter, bindingResult);
if (isBindingErrorFatal(param)) {
throw new MethodArgumentNotValidException(param, bindingResult);
}
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2014 the original author or authors.
* Copyright 2002-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,11 +16,11 @@
package org.springframework.web.servlet.mvc.method.annotation;
import java.io.IOException;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import javax.servlet.http.Part;
import javax.validation.Valid;
import javax.validation.constraints.NotNull;
@ -82,6 +82,9 @@ public class RequestPartMethodArgumentResolverTests {
private MethodParameter paramPartList;
private MethodParameter paramPartArray;
private MethodParameter paramRequestParamAnnot;
private MethodParameter optionalMultipartFile;
private MethodParameter optionalPart;
private MethodParameter optionalRequestPart;
private NativeWebRequest webRequest;
@ -89,14 +92,14 @@ public class RequestPartMethodArgumentResolverTests {
private MockHttpServletResponse servletResponse;
@SuppressWarnings("unchecked")
@Before
public void setUp() throws Exception {
Method method = getClass().getMethod("handle", SimpleBean.class, SimpleBean.class,
SimpleBean.class, MultipartFile.class, List.class, MultipartFile[].class,
Integer.TYPE, MultipartFile.class, Part.class, List.class,
Part[].class, MultipartFile.class);
Integer.TYPE, MultipartFile.class, Part.class, List.class, Part[].class,
MultipartFile.class, Optional.class, Optional.class, Optional.class);
paramRequestPart = new MethodParameter(method, 0);
paramRequestPart.initParameterNameDiscovery(new LocalVariableTableParameterNameDiscoverer());
@ -113,6 +116,11 @@ public class RequestPartMethodArgumentResolverTests {
paramPartList = new MethodParameter(method, 9);
paramPartArray = new MethodParameter(method, 10);
paramRequestParamAnnot = new MethodParameter(method, 11);
optionalMultipartFile = new MethodParameter(method, 12);
optionalMultipartFile.initParameterNameDiscovery(new LocalVariableTableParameterNameDiscoverer());
optionalPart = new MethodParameter(method, 13);
optionalPart.initParameterNameDiscovery(new LocalVariableTableParameterNameDiscoverer());
optionalRequestPart = new MethodParameter(method, 14);
messageConverter = mock(HttpMessageConverter.class);
given(messageConverter.getSupportedMediaTypes()).willReturn(Collections.singletonList(MediaType.TEXT_PLAIN));
@ -129,6 +137,7 @@ public class RequestPartMethodArgumentResolverTests {
webRequest = new ServletWebRequest(multipartRequest, servletResponse);
}
@Test
public void supportsParameter() {
assertTrue("RequestPart parameter not supported", resolver.supportsParameter(paramRequestPart));
@ -139,8 +148,8 @@ public class RequestPartMethodArgumentResolverTests {
assertTrue("MultipartFile parameter not supported", resolver.supportsParameter(paramMultipartFile));
assertTrue("List<MultipartFile> parameter not supported", resolver.supportsParameter(paramMultipartFileList));
assertTrue("MultipartFile[] parameter not supported", resolver.supportsParameter(paramMultipartFileArray));
assertFalse("non-RequestPart parameter supported", resolver.supportsParameter(paramInt));
assertFalse("@RequestParam args not supported", resolver.supportsParameter(paramRequestParamAnnot));
assertFalse("non-RequestPart parameter should not be supported", resolver.supportsParameter(paramInt));
assertFalse("@RequestParam args should not be supported", resolver.supportsParameter(paramRequestParamAnnot));
}
@Test
@ -243,21 +252,27 @@ public class RequestPartMethodArgumentResolverTests {
testResolveArgument(new SimpleBean("foo"), paramNamedRequestPart);
}
@Test
public void resolveNamedRequestPartNotPresent() throws Exception {
testResolveArgument(null, paramNamedRequestPart);
}
@Test
public void resolveRequestPartNotValid() throws Exception {
try {
testResolveArgument(new SimpleBean(null), paramValidRequestPart);
fail("Expected exception");
} catch (MethodArgumentNotValidException e) {
assertEquals("requestPart", e.getBindingResult().getObjectName());
assertEquals(1, e.getBindingResult().getErrorCount());
assertNotNull(e.getBindingResult().getFieldError("name"));
}
catch (MethodArgumentNotValidException ex) {
assertEquals("requestPart", ex.getBindingResult().getObjectName());
assertEquals(1, ex.getBindingResult().getErrorCount());
assertNotNull(ex.getBindingResult().getFieldError("name"));
}
}
@Test
public void resolveRequestPartValid() throws Exception {
testResolveArgument(new SimpleBean("foo"), paramNamedRequestPart);
testResolveArgument(new SimpleBean("foo"), paramValidRequestPart);
}
@Test
@ -265,8 +280,9 @@ public class RequestPartMethodArgumentResolverTests {
try {
testResolveArgument(null, paramValidRequestPart);
fail("Expected exception");
} catch (MissingServletRequestPartException e) {
assertEquals("requestPart", e.getRequestPartName());
}
catch (MissingServletRequestPartException ex) {
assertEquals("requestPart", ex.getRequestPartName());
}
}
@ -275,27 +291,100 @@ public class RequestPartMethodArgumentResolverTests {
testResolveArgument(new SimpleBean("foo"), paramValidRequestPart);
}
@Test(expected=MultipartException.class)
@Test(expected = MultipartException.class)
public void isMultipartRequest() throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest();
resolver.resolveArgument(paramMultipartFile, new ModelAndViewContainer(), new ServletWebRequest(request), null);
fail("Expected exception");
}
// SPR-9079
@Test
@Test // SPR-9079
public void isMultipartRequestPut() throws Exception {
this.multipartRequest.setMethod("PUT");
Object actual = resolver.resolveArgument(paramMultipartFile, null, webRequest, null);
assertNotNull(actual);
assertSame(multipartFile1, actual);
}
private void testResolveArgument(SimpleBean argValue, MethodParameter parameter) throws IOException, Exception {
MediaType contentType = MediaType.TEXT_PLAIN;
@Test
public void resolveOptionalMultipartFileArgument() throws Exception {
MockMultipartHttpServletRequest request = new MockMultipartHttpServletRequest();
MultipartFile expected = new MockMultipartFile("optionalMultipartFile", "Hello World".getBytes());
request.addFile(expected);
webRequest = new ServletWebRequest(request);
given(messageConverter.canRead(SimpleBean.class, contentType)).willReturn(true);
Object result = resolver.resolveArgument(optionalMultipartFile, null, webRequest, null);
assertTrue(result instanceof Optional);
assertEquals("Invalid result", expected, ((Optional) result).get());
}
@Test
public void resolveOptionalMultipartFileArgumentNotPresent() throws Exception {
MockMultipartHttpServletRequest request = new MockMultipartHttpServletRequest();
webRequest = new ServletWebRequest(request);
Object result = resolver.resolveArgument(optionalMultipartFile, null, webRequest, null);
assertTrue(result instanceof Optional);
assertFalse("Invalid result", ((Optional) result).isPresent());
}
@Test
public void resolveOptionalPartArgument() throws Exception {
MockPart expected = new MockPart("optionalPart", "Hello World".getBytes());
MockHttpServletRequest request = new MockHttpServletRequest();
request.setMethod("POST");
request.setContentType("multipart/form-data");
request.addPart(expected);
webRequest = new ServletWebRequest(request);
Object result = resolver.resolveArgument(optionalPart, null, webRequest, null);
assertTrue(result instanceof Optional);
assertEquals("Invalid result", expected, ((Optional) result).get());
}
@Test
public void resolveOptionalPartArgumentNotPresent() throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest();
request.setMethod("POST");
request.setContentType("multipart/form-data");
webRequest = new ServletWebRequest(request);
Object result = resolver.resolveArgument(optionalPart, null, webRequest, null);
assertTrue(result instanceof Optional);
assertFalse("Invalid result", ((Optional) result).isPresent());
}
@Test
public void resolveOptionalRequestPart() throws Exception {
SimpleBean simpleBean = new SimpleBean("foo");
given(messageConverter.canRead(SimpleBean.class, MediaType.TEXT_PLAIN)).willReturn(true);
given(messageConverter.read(eq(SimpleBean.class), isA(RequestPartServletServerHttpRequest.class))).willReturn(simpleBean);
ModelAndViewContainer mavContainer = new ModelAndViewContainer();
Object actualValue = resolver.resolveArgument(optionalRequestPart, mavContainer, webRequest, new ValidatingBinderFactory());
assertEquals("Invalid argument value", Optional.of(simpleBean), actualValue);
assertFalse("The requestHandled flag shouldn't change", mavContainer.isRequestHandled());
}
@Test
public void resolveOptionalRequestPartNotPresent() throws Exception {
given(messageConverter.canRead(SimpleBean.class, MediaType.TEXT_PLAIN)).willReturn(true);
given(messageConverter.read(eq(SimpleBean.class), isA(RequestPartServletServerHttpRequest.class))).willReturn(null);
ModelAndViewContainer mavContainer = new ModelAndViewContainer();
Object actualValue = resolver.resolveArgument(optionalRequestPart, mavContainer, webRequest, new ValidatingBinderFactory());
assertEquals("Invalid argument value", Optional.empty(), actualValue);
assertFalse("The requestHandled flag shouldn't change", mavContainer.isRequestHandled());
}
private void testResolveArgument(SimpleBean argValue, MethodParameter parameter) throws Exception {
given(messageConverter.canRead(SimpleBean.class, MediaType.TEXT_PLAIN)).willReturn(true);
given(messageConverter.read(eq(SimpleBean.class), isA(RequestPartServletServerHttpRequest.class))).willReturn(argValue);
ModelAndViewContainer mavContainer = new ModelAndViewContainer();
@ -305,6 +394,7 @@ public class RequestPartMethodArgumentResolverTests {
assertFalse("The requestHandled flag shouldn't change", mavContainer.isRequestHandled());
}
private static class SimpleBean {
@NotNull
@ -320,7 +410,9 @@ public class RequestPartMethodArgumentResolverTests {
}
}
private final class ValidatingBinderFactory implements WebDataBinderFactory {
@Override
public WebDataBinder createBinder(NativeWebRequest webRequest, Object target, String objectName) throws Exception {
LocalValidatorFactoryBean validator = new LocalValidatorFactoryBean();
@ -331,6 +423,7 @@ public class RequestPartMethodArgumentResolverTests {
}
}
public void handle(@RequestPart SimpleBean requestPart,
@RequestPart(value="requestPart", required=false) SimpleBean namedRequestPart,
@Valid @RequestPart("requestPart") SimpleBean validRequestPart,
@ -342,7 +435,10 @@ public class RequestPartMethodArgumentResolverTests {
Part part,
@RequestPart("part") List<Part> partList,
@RequestPart("part") Part[] partArray,
@RequestParam MultipartFile requestParamAnnot) {
@RequestParam MultipartFile requestParamAnnot,
Optional<MultipartFile> optionalMultipartFile,
Optional<Part> optionalPart,
@RequestPart("requestPart") Optional<SimpleBean> optionalRequestPart) {
}
}