diff --git a/org.springframework.web.servlet/src/main/java/org/springframework/web/servlet/mvc/method/annotation/support/RequestPartMethodArgumentResolver.java b/org.springframework.web.servlet/src/main/java/org/springframework/web/servlet/mvc/method/annotation/support/RequestPartMethodArgumentResolver.java index eae6068073c..4bf54cae7d6 100644 --- a/org.springframework.web.servlet/src/main/java/org/springframework/web/servlet/mvc/method/annotation/support/RequestPartMethodArgumentResolver.java +++ b/org.springframework.web.servlet/src/main/java/org/springframework/web/servlet/mvc/method/annotation/support/RequestPartMethodArgumentResolver.java @@ -136,7 +136,10 @@ public class RequestPartMethodArgumentResolver extends AbstractMessageConverterM } } - checkMissingRequiredValue(arg, partName, parameter); + if (arg == null) { + handleMissingValue(partName, parameter); + } + return arg; } @@ -170,19 +173,18 @@ public class RequestPartMethodArgumentResolver extends AbstractMessageConverterM } /** - * Raises a {@link ServletRequestBindingException} if the method parameter is required - * and the resolved argument value is null. + * Invoked if the resolved argument value is {@code null}. The default implementation raises + * a {@link ServletRequestBindingException} if the method parameter is required. + * @param partName the name used to look up the request part + * @param param the method argument */ - protected void checkMissingRequiredValue(Object argumentValue, String partName, MethodParameter parameter) - throws ServletRequestBindingException { - if (argumentValue == null) { - RequestPart annot = parameter.getParameterAnnotation(RequestPart.class); - boolean isRequired = (annot != null) ? annot.required() : true; - if (isRequired) { - String paramType = parameter.getParameterType().getName(); - throw new ServletRequestBindingException( - "Missing request part '" + partName + "' for method parameter type [" + paramType + "]"); - } + protected void handleMissingValue(String partName, MethodParameter param) throws ServletRequestBindingException { + RequestPart annot = param.getParameterAnnotation(RequestPart.class); + boolean isRequired = (annot != null) ? annot.required() : true; + if (isRequired) { + String paramType = param.getParameterType().getName(); + throw new ServletRequestBindingException( + "Missing request part '" + partName + "' for method parameter type [" + paramType + "]"); } } diff --git a/org.springframework.web/src/main/java/org/springframework/web/method/annotation/support/RequestParamMethodArgumentResolver.java b/org.springframework.web/src/main/java/org/springframework/web/method/annotation/support/RequestParamMethodArgumentResolver.java index 370dc478b9d..110ef0fac18 100644 --- a/org.springframework.web/src/main/java/org/springframework/web/method/annotation/support/RequestParamMethodArgumentResolver.java +++ b/org.springframework.web/src/main/java/org/springframework/web/method/annotation/support/RequestParamMethodArgumentResolver.java @@ -17,6 +17,7 @@ package org.springframework.web.method.annotation.support; import java.beans.PropertyEditor; +import java.util.Collection; import java.util.List; import java.util.Map; @@ -25,6 +26,7 @@ import javax.servlet.http.HttpServletRequest; import org.springframework.beans.BeanUtils; import org.springframework.beans.factory.config.ConfigurableBeanFactory; +import org.springframework.core.GenericCollectionTypeResolver; import org.springframework.core.MethodParameter; import org.springframework.core.convert.converter.Converter; import org.springframework.util.StringUtils; @@ -35,6 +37,7 @@ import org.springframework.web.bind.annotation.ValueConstants; import org.springframework.web.context.request.NativeWebRequest; import org.springframework.web.multipart.MultipartFile; import org.springframework.web.multipart.MultipartHttpServletRequest; +import org.springframework.web.multipart.MultipartRequest; import org.springframework.web.multipart.MultipartResolver; import org.springframework.web.util.WebUtils; @@ -124,29 +127,59 @@ public class RequestParamMethodArgumentResolver extends AbstractNamedValueMethod @Override protected Object resolveName(String name, MethodParameter parameter, NativeWebRequest webRequest) throws Exception { + + Object arg; HttpServletRequest servletRequest = webRequest.getNativeRequest(HttpServletRequest.class); MultipartHttpServletRequest multipartRequest = WebUtils.getNativeRequest(servletRequest, MultipartHttpServletRequest.class); - if (multipartRequest != null) { - List files = multipartRequest.getFiles(name); - if (!files.isEmpty()) { - return (files.size() == 1 ? files.get(0) : files); + if (MultipartFile.class.equals(parameter.getParameterType())) { + assertMultipartRequest(multipartRequest, webRequest); + arg = multipartRequest.getFile(name); + } + else if (isMultipartFileCollection(parameter)) { + assertMultipartRequest(multipartRequest, webRequest); + arg = multipartRequest.getFiles(name); + } + else if ("javax.servlet.http.Part".equals(parameter.getParameterType().getName())) { + arg = servletRequest.getPart(name); + } + else { + arg = null; + if (multipartRequest != null) { + List files = multipartRequest.getFiles(name); + if (!files.isEmpty()) { + arg = (files.size() == 1 ? files.get(0) : files); + } + } + if (arg == null) { + String[] paramValues = webRequest.getParameterValues(name); + if (paramValues != null) { + arg = paramValues.length == 1 ? paramValues[0] : paramValues; + } } } - if ("javax.servlet.http.Part".equals(parameter.getParameterType().getName())) { - return servletRequest.getPart(name); + return arg; + } + + private void assertMultipartRequest(MultipartHttpServletRequest multipartRequest, NativeWebRequest request) { + if (multipartRequest == null) { + throw new IllegalStateException("Current request is not of type [" + MultipartRequest.class.getName() + + "]: " + request + ". Do you have a MultipartResolver configured?"); } - - String[] paramValues = webRequest.getParameterValues(name); - if (paramValues != null) { - return paramValues.length == 1 ? paramValues[0] : paramValues; - } - else { - return null; + } + + private boolean isMultipartFileCollection(MethodParameter parameter) { + Class paramType = parameter.getParameterType(); + if (Collection.class.equals(paramType) || List.class.isAssignableFrom(paramType)){ + Class valueType = GenericCollectionTypeResolver.getCollectionParameterType(parameter); + if (valueType != null && valueType.equals(MultipartFile.class)) { + return true; + } } + return false; } @Override diff --git a/org.springframework.web/src/test/java/org/springframework/web/method/annotation/support/RequestParamMethodArgumentResolverTests.java b/org.springframework.web/src/test/java/org/springframework/web/method/annotation/support/RequestParamMethodArgumentResolverTests.java index a0a88b11182..3e8364ff88f 100644 --- a/org.springframework.web/src/test/java/org/springframework/web/method/annotation/support/RequestParamMethodArgumentResolverTests.java +++ b/org.springframework.web/src/test/java/org/springframework/web/method/annotation/support/RequestParamMethodArgumentResolverTests.java @@ -20,8 +20,11 @@ import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; import java.lang.reflect.Method; +import java.util.Arrays; +import java.util.List; import java.util.Map; import javax.servlet.http.Part; @@ -30,6 +33,7 @@ import org.junit.Before; import org.junit.Test; import org.springframework.core.LocalVariableTableParameterNameDiscoverer; import org.springframework.core.MethodParameter; +import org.springframework.core.ParameterNameDiscoverer; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.mock.web.MockMultipartFile; @@ -58,7 +62,8 @@ public class RequestParamMethodArgumentResolverTests { private MethodParameter paramMap; private MethodParameter paramStringNotAnnot; private MethodParameter paramMultipartFileNotAnnot; - private MethodParameter paramPartNotAnnot; + private MethodParameter paramMultipartFileList; + private MethodParameter paramServlet30Part; private NativeWebRequest webRequest; @@ -67,9 +72,11 @@ public class RequestParamMethodArgumentResolverTests { @Before public void setUp() throws Exception { resolver = new RequestParamMethodArgumentResolver(null, true); - + + ParameterNameDiscoverer paramNameDiscoverer = new LocalVariableTableParameterNameDiscoverer(); + Method method = getClass().getMethod("params", String.class, String[].class, Map.class, MultipartFile.class, - Map.class, String.class, MultipartFile.class, Part.class); + Map.class, String.class, MultipartFile.class, List.class, Part.class); paramNamedDefaultValueString = new MethodParameter(method, 0); paramNamedStringArray = new MethodParameter(method, 1); @@ -77,11 +84,13 @@ public class RequestParamMethodArgumentResolverTests { paramMultiPartFile = new MethodParameter(method, 3); paramMap = new MethodParameter(method, 4); paramStringNotAnnot = new MethodParameter(method, 5); - paramStringNotAnnot.initParameterNameDiscovery(new LocalVariableTableParameterNameDiscoverer()); + paramStringNotAnnot.initParameterNameDiscovery(paramNameDiscoverer); paramMultipartFileNotAnnot = new MethodParameter(method, 6); - paramMultipartFileNotAnnot.initParameterNameDiscovery(new LocalVariableTableParameterNameDiscoverer()); - paramPartNotAnnot = new MethodParameter(method, 7); - paramPartNotAnnot.initParameterNameDiscovery(new LocalVariableTableParameterNameDiscoverer()); + paramMultipartFileNotAnnot.initParameterNameDiscovery(paramNameDiscoverer); + paramMultipartFileList = new MethodParameter(method, 7); + paramMultipartFileList.initParameterNameDiscovery(paramNameDiscoverer); + paramServlet30Part = new MethodParameter(method, 8); + paramServlet30Part.initParameterNameDiscovery(paramNameDiscoverer); request = new MockHttpServletRequest(); webRequest = new ServletWebRequest(request, new MockHttpServletResponse()); @@ -97,14 +106,14 @@ public class RequestParamMethodArgumentResolverTests { assertFalse("non-@RequestParam parameter supported", resolver.supportsParameter(paramMap)); assertTrue("Simple type params supported w/o annotations", resolver.supportsParameter(paramStringNotAnnot)); assertTrue("MultipartFile parameter not supported", resolver.supportsParameter(paramMultipartFileNotAnnot)); - assertTrue("Part parameter not supported", resolver.supportsParameter(paramPartNotAnnot)); + assertTrue("Part parameter not supported", resolver.supportsParameter(paramServlet30Part)); resolver = new RequestParamMethodArgumentResolver(null, false); assertFalse(resolver.supportsParameter(paramStringNotAnnot)); } @Test - public void resolveStringArgument() throws Exception { + public void resolveString() throws Exception { String expected = "foo"; request.addParameter("name", expected); @@ -115,7 +124,7 @@ public class RequestParamMethodArgumentResolverTests { } @Test - public void resolveStringArrayArgument() throws Exception { + public void resolveStringArray() throws Exception { String[] expected = new String[]{"foo", "bar"}; request.addParameter("name", expected); @@ -126,7 +135,7 @@ public class RequestParamMethodArgumentResolverTests { } @Test - public void resolveMultipartFileArgument() throws Exception { + public void resolveMultipartFile() throws Exception { MockMultipartHttpServletRequest request = new MockMultipartHttpServletRequest(); MultipartFile expected = new MockMultipartFile("file", "Hello World".getBytes()); request.addFile(expected); @@ -139,9 +148,9 @@ public class RequestParamMethodArgumentResolverTests { } @Test - public void resolveMultipartFileNotAnnotArgument() throws Exception { + public void resolveMultipartFileNotAnnot() throws Exception { MockMultipartHttpServletRequest request = new MockMultipartHttpServletRequest(); - MultipartFile expected = new MockMultipartFile("paramMultipartFileNotAnnot", "Hello World".getBytes()); + MultipartFile expected = new MockMultipartFile("multipartFileNotAnnot", "Hello World".getBytes()); request.addFile(expected); webRequest = new ServletWebRequest(request); @@ -152,13 +161,34 @@ public class RequestParamMethodArgumentResolverTests { } @Test - public void resolvePartArgument() throws Exception { - MockPart expected = new MockPart("paramPartNotAnnot", "Hello World".getBytes()); + public void resolveMultipartFileList() throws Exception { + MockMultipartHttpServletRequest request = new MockMultipartHttpServletRequest(); + MultipartFile expected1 = new MockMultipartFile("multipartFileList", "Hello World 1".getBytes()); + MultipartFile expected2 = new MockMultipartFile("multipartFileList", "Hello World 2".getBytes()); + request.addFile(expected1); + request.addFile(expected2); + webRequest = new ServletWebRequest(request); + + Object result = resolver.resolveArgument(paramMultipartFileList, null, webRequest, null); + + assertTrue(result instanceof List); + assertEquals(Arrays.asList(expected1, expected2), result); + } + + @Test(expected = IllegalStateException.class) + public void missingMultipartFile() throws Exception { + resolver.resolveArgument(paramMultiPartFile, null, webRequest, null); + fail("Expected exception"); + } + + @Test + public void resolveServlet30Part() throws Exception { + MockPart expected = new MockPart("servlet30Part", "Hello World".getBytes()); MockHttpServletRequest request = new MockHttpServletRequest(); request.addPart(expected); webRequest = new ServletWebRequest(request); - Object result = resolver.resolveArgument(paramPartNotAnnot, null, webRequest, null); + Object result = resolver.resolveArgument(paramServlet30Part, null, webRequest, null); assertTrue(result instanceof Part); assertEquals("Invalid result", expected, result); @@ -173,16 +203,14 @@ public class RequestParamMethodArgumentResolverTests { } @Test(expected = MissingServletRequestParameterException.class) - public void notFound() throws Exception { - Object result = resolver.resolveArgument(paramNamedStringArray, null, webRequest, null); - - assertTrue(result instanceof String); - assertEquals("Invalid result", "bar", result); + public void missingRequestParam() throws Exception { + resolver.resolveArgument(paramNamedStringArray, null, webRequest, null); + fail("Expected exception"); } @Test public void resolveSimpleTypeParam() throws Exception { - request.setParameter("paramStringNotAnnot", "plainValue"); + request.setParameter("stringNotAnnot", "plainValue"); Object result = resolver.resolveArgument(paramStringNotAnnot, null, webRequest, null); assertTrue(result instanceof String); @@ -194,9 +222,10 @@ public class RequestParamMethodArgumentResolverTests { @RequestParam("name") Map param3, @RequestParam(value = "file") MultipartFile param4, @RequestParam Map param5, - String paramStringNotAnnot, - MultipartFile paramMultipartFileNotAnnot, - Part paramPartNotAnnot) { + String stringNotAnnot, + MultipartFile multipartFileNotAnnot, + List multipartFileList, + Part servlet30Part) { } }