SPR-8001 Recognize case when MultipartRequest is null and argument is of type MultipartFile and raise helpful exception.

This commit is contained in:
Rossen Stoyanchev 2011-07-18 13:49:47 +00:00
parent b8c723d080
commit 2568a3c2c6
3 changed files with 115 additions and 51 deletions

View File

@ -136,7 +136,10 @@ public class RequestPartMethodArgumentResolver extends AbstractMessageConverterM
} }
} }
checkMissingRequiredValue(arg, partName, parameter); if (arg == null) {
handleMissingValue(partName, parameter);
}
return arg; return arg;
} }
@ -170,19 +173,18 @@ public class RequestPartMethodArgumentResolver extends AbstractMessageConverterM
} }
/** /**
* Raises a {@link ServletRequestBindingException} if the method parameter is required * Invoked if the resolved argument value is {@code null}. The default implementation raises
* and the resolved argument value is null. * a {@link ServletRequestBindingException} if the method parameter is required.
* @param partName the name used to look up the request part
* @param param the method argument
*/ */
protected void checkMissingRequiredValue(Object argumentValue, String partName, MethodParameter parameter) protected void handleMissingValue(String partName, MethodParameter param) throws ServletRequestBindingException {
throws ServletRequestBindingException { RequestPart annot = param.getParameterAnnotation(RequestPart.class);
if (argumentValue == null) { boolean isRequired = (annot != null) ? annot.required() : true;
RequestPart annot = parameter.getParameterAnnotation(RequestPart.class); if (isRequired) {
boolean isRequired = (annot != null) ? annot.required() : true; String paramType = param.getParameterType().getName();
if (isRequired) { throw new ServletRequestBindingException(
String paramType = parameter.getParameterType().getName(); "Missing request part '" + partName + "' for method parameter type [" + paramType + "]");
throw new ServletRequestBindingException(
"Missing request part '" + partName + "' for method parameter type [" + paramType + "]");
}
} }
} }

View File

@ -17,6 +17,7 @@
package org.springframework.web.method.annotation.support; package org.springframework.web.method.annotation.support;
import java.beans.PropertyEditor; import java.beans.PropertyEditor;
import java.util.Collection;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
@ -25,6 +26,7 @@ import javax.servlet.http.HttpServletRequest;
import org.springframework.beans.BeanUtils; import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.config.ConfigurableBeanFactory; import org.springframework.beans.factory.config.ConfigurableBeanFactory;
import org.springframework.core.GenericCollectionTypeResolver;
import org.springframework.core.MethodParameter; import org.springframework.core.MethodParameter;
import org.springframework.core.convert.converter.Converter; import org.springframework.core.convert.converter.Converter;
import org.springframework.util.StringUtils; import org.springframework.util.StringUtils;
@ -35,6 +37,7 @@ import org.springframework.web.bind.annotation.ValueConstants;
import org.springframework.web.context.request.NativeWebRequest; import org.springframework.web.context.request.NativeWebRequest;
import org.springframework.web.multipart.MultipartFile; import org.springframework.web.multipart.MultipartFile;
import org.springframework.web.multipart.MultipartHttpServletRequest; import org.springframework.web.multipart.MultipartHttpServletRequest;
import org.springframework.web.multipart.MultipartRequest;
import org.springframework.web.multipart.MultipartResolver; import org.springframework.web.multipart.MultipartResolver;
import org.springframework.web.util.WebUtils; import org.springframework.web.util.WebUtils;
@ -124,29 +127,59 @@ public class RequestParamMethodArgumentResolver extends AbstractNamedValueMethod
@Override @Override
protected Object resolveName(String name, MethodParameter parameter, NativeWebRequest webRequest) throws Exception { protected Object resolveName(String name, MethodParameter parameter, NativeWebRequest webRequest) throws Exception {
Object arg;
HttpServletRequest servletRequest = webRequest.getNativeRequest(HttpServletRequest.class); HttpServletRequest servletRequest = webRequest.getNativeRequest(HttpServletRequest.class);
MultipartHttpServletRequest multipartRequest = MultipartHttpServletRequest multipartRequest =
WebUtils.getNativeRequest(servletRequest, MultipartHttpServletRequest.class); WebUtils.getNativeRequest(servletRequest, MultipartHttpServletRequest.class);
if (multipartRequest != null) { if (MultipartFile.class.equals(parameter.getParameterType())) {
List<MultipartFile> files = multipartRequest.getFiles(name); assertMultipartRequest(multipartRequest, webRequest);
if (!files.isEmpty()) { arg = multipartRequest.getFile(name);
return (files.size() == 1 ? files.get(0) : files); }
else if (isMultipartFileCollection(parameter)) {
assertMultipartRequest(multipartRequest, webRequest);
arg = multipartRequest.getFiles(name);
}
else if ("javax.servlet.http.Part".equals(parameter.getParameterType().getName())) {
arg = servletRequest.getPart(name);
}
else {
arg = null;
if (multipartRequest != null) {
List<MultipartFile> files = multipartRequest.getFiles(name);
if (!files.isEmpty()) {
arg = (files.size() == 1 ? files.get(0) : files);
}
}
if (arg == null) {
String[] paramValues = webRequest.getParameterValues(name);
if (paramValues != null) {
arg = paramValues.length == 1 ? paramValues[0] : paramValues;
}
} }
} }
if ("javax.servlet.http.Part".equals(parameter.getParameterType().getName())) { return arg;
return servletRequest.getPart(name); }
private void assertMultipartRequest(MultipartHttpServletRequest multipartRequest, NativeWebRequest request) {
if (multipartRequest == null) {
throw new IllegalStateException("Current request is not of type [" + MultipartRequest.class.getName()
+ "]: " + request + ". Do you have a MultipartResolver configured?");
} }
}
String[] paramValues = webRequest.getParameterValues(name);
if (paramValues != null) { private boolean isMultipartFileCollection(MethodParameter parameter) {
return paramValues.length == 1 ? paramValues[0] : paramValues; Class<?> paramType = parameter.getParameterType();
} if (Collection.class.equals(paramType) || List.class.isAssignableFrom(paramType)){
else { Class<?> valueType = GenericCollectionTypeResolver.getCollectionParameterType(parameter);
return null; if (valueType != null && valueType.equals(MultipartFile.class)) {
return true;
}
} }
return false;
} }
@Override @Override

View File

@ -20,8 +20,11 @@ import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import java.lang.reflect.Method; import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.List;
import java.util.Map; import java.util.Map;
import javax.servlet.http.Part; import javax.servlet.http.Part;
@ -30,6 +33,7 @@ import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.springframework.core.LocalVariableTableParameterNameDiscoverer; import org.springframework.core.LocalVariableTableParameterNameDiscoverer;
import org.springframework.core.MethodParameter; import org.springframework.core.MethodParameter;
import org.springframework.core.ParameterNameDiscoverer;
import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.mock.web.MockMultipartFile; import org.springframework.mock.web.MockMultipartFile;
@ -58,7 +62,8 @@ public class RequestParamMethodArgumentResolverTests {
private MethodParameter paramMap; private MethodParameter paramMap;
private MethodParameter paramStringNotAnnot; private MethodParameter paramStringNotAnnot;
private MethodParameter paramMultipartFileNotAnnot; private MethodParameter paramMultipartFileNotAnnot;
private MethodParameter paramPartNotAnnot; private MethodParameter paramMultipartFileList;
private MethodParameter paramServlet30Part;
private NativeWebRequest webRequest; private NativeWebRequest webRequest;
@ -67,9 +72,11 @@ public class RequestParamMethodArgumentResolverTests {
@Before @Before
public void setUp() throws Exception { public void setUp() throws Exception {
resolver = new RequestParamMethodArgumentResolver(null, true); resolver = new RequestParamMethodArgumentResolver(null, true);
ParameterNameDiscoverer paramNameDiscoverer = new LocalVariableTableParameterNameDiscoverer();
Method method = getClass().getMethod("params", String.class, String[].class, Map.class, MultipartFile.class, Method method = getClass().getMethod("params", String.class, String[].class, Map.class, MultipartFile.class,
Map.class, String.class, MultipartFile.class, Part.class); Map.class, String.class, MultipartFile.class, List.class, Part.class);
paramNamedDefaultValueString = new MethodParameter(method, 0); paramNamedDefaultValueString = new MethodParameter(method, 0);
paramNamedStringArray = new MethodParameter(method, 1); paramNamedStringArray = new MethodParameter(method, 1);
@ -77,11 +84,13 @@ public class RequestParamMethodArgumentResolverTests {
paramMultiPartFile = new MethodParameter(method, 3); paramMultiPartFile = new MethodParameter(method, 3);
paramMap = new MethodParameter(method, 4); paramMap = new MethodParameter(method, 4);
paramStringNotAnnot = new MethodParameter(method, 5); paramStringNotAnnot = new MethodParameter(method, 5);
paramStringNotAnnot.initParameterNameDiscovery(new LocalVariableTableParameterNameDiscoverer()); paramStringNotAnnot.initParameterNameDiscovery(paramNameDiscoverer);
paramMultipartFileNotAnnot = new MethodParameter(method, 6); paramMultipartFileNotAnnot = new MethodParameter(method, 6);
paramMultipartFileNotAnnot.initParameterNameDiscovery(new LocalVariableTableParameterNameDiscoverer()); paramMultipartFileNotAnnot.initParameterNameDiscovery(paramNameDiscoverer);
paramPartNotAnnot = new MethodParameter(method, 7); paramMultipartFileList = new MethodParameter(method, 7);
paramPartNotAnnot.initParameterNameDiscovery(new LocalVariableTableParameterNameDiscoverer()); paramMultipartFileList.initParameterNameDiscovery(paramNameDiscoverer);
paramServlet30Part = new MethodParameter(method, 8);
paramServlet30Part.initParameterNameDiscovery(paramNameDiscoverer);
request = new MockHttpServletRequest(); request = new MockHttpServletRequest();
webRequest = new ServletWebRequest(request, new MockHttpServletResponse()); webRequest = new ServletWebRequest(request, new MockHttpServletResponse());
@ -97,14 +106,14 @@ public class RequestParamMethodArgumentResolverTests {
assertFalse("non-@RequestParam parameter supported", resolver.supportsParameter(paramMap)); assertFalse("non-@RequestParam parameter supported", resolver.supportsParameter(paramMap));
assertTrue("Simple type params supported w/o annotations", resolver.supportsParameter(paramStringNotAnnot)); assertTrue("Simple type params supported w/o annotations", resolver.supportsParameter(paramStringNotAnnot));
assertTrue("MultipartFile parameter not supported", resolver.supportsParameter(paramMultipartFileNotAnnot)); assertTrue("MultipartFile parameter not supported", resolver.supportsParameter(paramMultipartFileNotAnnot));
assertTrue("Part parameter not supported", resolver.supportsParameter(paramPartNotAnnot)); assertTrue("Part parameter not supported", resolver.supportsParameter(paramServlet30Part));
resolver = new RequestParamMethodArgumentResolver(null, false); resolver = new RequestParamMethodArgumentResolver(null, false);
assertFalse(resolver.supportsParameter(paramStringNotAnnot)); assertFalse(resolver.supportsParameter(paramStringNotAnnot));
} }
@Test @Test
public void resolveStringArgument() throws Exception { public void resolveString() throws Exception {
String expected = "foo"; String expected = "foo";
request.addParameter("name", expected); request.addParameter("name", expected);
@ -115,7 +124,7 @@ public class RequestParamMethodArgumentResolverTests {
} }
@Test @Test
public void resolveStringArrayArgument() throws Exception { public void resolveStringArray() throws Exception {
String[] expected = new String[]{"foo", "bar"}; String[] expected = new String[]{"foo", "bar"};
request.addParameter("name", expected); request.addParameter("name", expected);
@ -126,7 +135,7 @@ public class RequestParamMethodArgumentResolverTests {
} }
@Test @Test
public void resolveMultipartFileArgument() throws Exception { public void resolveMultipartFile() throws Exception {
MockMultipartHttpServletRequest request = new MockMultipartHttpServletRequest(); MockMultipartHttpServletRequest request = new MockMultipartHttpServletRequest();
MultipartFile expected = new MockMultipartFile("file", "Hello World".getBytes()); MultipartFile expected = new MockMultipartFile("file", "Hello World".getBytes());
request.addFile(expected); request.addFile(expected);
@ -139,9 +148,9 @@ public class RequestParamMethodArgumentResolverTests {
} }
@Test @Test
public void resolveMultipartFileNotAnnotArgument() throws Exception { public void resolveMultipartFileNotAnnot() throws Exception {
MockMultipartHttpServletRequest request = new MockMultipartHttpServletRequest(); MockMultipartHttpServletRequest request = new MockMultipartHttpServletRequest();
MultipartFile expected = new MockMultipartFile("paramMultipartFileNotAnnot", "Hello World".getBytes()); MultipartFile expected = new MockMultipartFile("multipartFileNotAnnot", "Hello World".getBytes());
request.addFile(expected); request.addFile(expected);
webRequest = new ServletWebRequest(request); webRequest = new ServletWebRequest(request);
@ -152,13 +161,34 @@ public class RequestParamMethodArgumentResolverTests {
} }
@Test @Test
public void resolvePartArgument() throws Exception { public void resolveMultipartFileList() throws Exception {
MockPart expected = new MockPart("paramPartNotAnnot", "Hello World".getBytes()); MockMultipartHttpServletRequest request = new MockMultipartHttpServletRequest();
MultipartFile expected1 = new MockMultipartFile("multipartFileList", "Hello World 1".getBytes());
MultipartFile expected2 = new MockMultipartFile("multipartFileList", "Hello World 2".getBytes());
request.addFile(expected1);
request.addFile(expected2);
webRequest = new ServletWebRequest(request);
Object result = resolver.resolveArgument(paramMultipartFileList, null, webRequest, null);
assertTrue(result instanceof List);
assertEquals(Arrays.asList(expected1, expected2), result);
}
@Test(expected = IllegalStateException.class)
public void missingMultipartFile() throws Exception {
resolver.resolveArgument(paramMultiPartFile, null, webRequest, null);
fail("Expected exception");
}
@Test
public void resolveServlet30Part() throws Exception {
MockPart expected = new MockPart("servlet30Part", "Hello World".getBytes());
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest request = new MockHttpServletRequest();
request.addPart(expected); request.addPart(expected);
webRequest = new ServletWebRequest(request); webRequest = new ServletWebRequest(request);
Object result = resolver.resolveArgument(paramPartNotAnnot, null, webRequest, null); Object result = resolver.resolveArgument(paramServlet30Part, null, webRequest, null);
assertTrue(result instanceof Part); assertTrue(result instanceof Part);
assertEquals("Invalid result", expected, result); assertEquals("Invalid result", expected, result);
@ -173,16 +203,14 @@ public class RequestParamMethodArgumentResolverTests {
} }
@Test(expected = MissingServletRequestParameterException.class) @Test(expected = MissingServletRequestParameterException.class)
public void notFound() throws Exception { public void missingRequestParam() throws Exception {
Object result = resolver.resolveArgument(paramNamedStringArray, null, webRequest, null); resolver.resolveArgument(paramNamedStringArray, null, webRequest, null);
fail("Expected exception");
assertTrue(result instanceof String);
assertEquals("Invalid result", "bar", result);
} }
@Test @Test
public void resolveSimpleTypeParam() throws Exception { public void resolveSimpleTypeParam() throws Exception {
request.setParameter("paramStringNotAnnot", "plainValue"); request.setParameter("stringNotAnnot", "plainValue");
Object result = resolver.resolveArgument(paramStringNotAnnot, null, webRequest, null); Object result = resolver.resolveArgument(paramStringNotAnnot, null, webRequest, null);
assertTrue(result instanceof String); assertTrue(result instanceof String);
@ -194,9 +222,10 @@ public class RequestParamMethodArgumentResolverTests {
@RequestParam("name") Map<?, ?> param3, @RequestParam("name") Map<?, ?> param3,
@RequestParam(value = "file") MultipartFile param4, @RequestParam(value = "file") MultipartFile param4,
@RequestParam Map<?, ?> param5, @RequestParam Map<?, ?> param5,
String paramStringNotAnnot, String stringNotAnnot,
MultipartFile paramMultipartFileNotAnnot, MultipartFile multipartFileNotAnnot,
Part paramPartNotAnnot) { List<MultipartFile> multipartFileList,
Part servlet30Part) {
} }
} }