diff --git a/spring-core/src/main/java/org/springframework/core/MethodParameter.java b/spring-core/src/main/java/org/springframework/core/MethodParameter.java index 952482c7ed..5a54a6c2fd 100644 --- a/spring-core/src/main/java/org/springframework/core/MethodParameter.java +++ b/spring-core/src/main/java/org/springframework/core/MethodParameter.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2014 the original author or authors. + * Copyright 2002-2015 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -309,6 +309,7 @@ public class MethodParameter { /** * Return the generic type of the method/constructor parameter. * @return the parameter type (never {@code null}) + * @since 3.0 */ public Type getGenericParameterType() { if (this.genericParameterType == null) { @@ -324,6 +325,12 @@ public class MethodParameter { return this.genericParameterType; } + /** + * Return the nested type of the method/constructor parameter. + * @return the parameter type (never {@code null}) + * @see #getNestingLevel() + * @since 3.1 + */ public Class getNestedParameterType() { if (this.nestingLevel > 1) { Type type = getGenericParameterType(); @@ -350,6 +357,29 @@ public class MethodParameter { } } + /** + * Return the nested generic type of the method/constructor parameter. + * @return the parameter type (never {@code null}) + * @see #getNestingLevel() + * @since 4.2 + */ + public Type getNestedGenericParameterType() { + if (this.nestingLevel > 1) { + Type type = getGenericParameterType(); + for (int i = 2; i <= this.nestingLevel; i++) { + if (type instanceof ParameterizedType) { + Type[] args = ((ParameterizedType) type).getActualTypeArguments(); + Integer index = getTypeIndexForLevel(i); + type = args[index != null ? index : args.length - 1]; + } + } + return type; + } + else { + return getGenericParameterType(); + } + } + /** * Return the annotations associated with the target method/constructor itself. */ diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/AbstractMessageConverterMethodArgumentResolver.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/AbstractMessageConverterMethodArgumentResolver.java index 2815515219..335810df72 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/AbstractMessageConverterMethodArgumentResolver.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/AbstractMessageConverterMethodArgumentResolver.java @@ -104,7 +104,7 @@ public abstract class AbstractMessageConverterMethodArgumentResolver implements * from the given HttpInputMessage. * @param the expected type of the argument value to be created * @param inputMessage the HTTP input message representing the current request - * @param methodParam the method argument + * @param methodParam the method parameter descriptor (may be {@code null}) * @param targetType the type of object to create, not necessarily the same as * the method parameter type (e.g. for {@code HttpEntity} method * parameter the target type is String) @@ -113,8 +113,8 @@ public abstract class AbstractMessageConverterMethodArgumentResolver implements * @throws HttpMediaTypeNotSupportedException if no suitable message converter is found */ @SuppressWarnings("unchecked") - protected Object readWithMessageConverters(HttpInputMessage inputMessage, - MethodParameter methodParam, Type targetType) throws IOException, HttpMediaTypeNotSupportedException { + protected Object readWithMessageConverters(HttpInputMessage inputMessage, MethodParameter methodParam, + Type targetType) throws IOException, HttpMediaTypeNotSupportedException { MediaType contentType; try { @@ -127,7 +127,13 @@ public abstract class AbstractMessageConverterMethodArgumentResolver implements contentType = MediaType.APPLICATION_OCTET_STREAM; } - Class contextClass = methodParam.getContainingClass(); + Class contextClass = (methodParam != null ? methodParam.getContainingClass() : null); + Class targetClass = (targetType instanceof Class ? (Class) targetType : null); + if (targetClass == null) { + ResolvableType resolvableType = (methodParam != null ? + ResolvableType.forMethodParameter(methodParam) : ResolvableType.forType(targetType)); + targetClass = (Class) resolvableType.resolve(); + } for (HttpMessageConverter converter : this.messageConverters) { if (converter instanceof GenericHttpMessageConverter) { @@ -140,14 +146,14 @@ public abstract class AbstractMessageConverterMethodArgumentResolver implements return genericConverter.read(targetType, contextClass, inputMessage); } } - Class targetClass = (Class) - ResolvableType.forMethodParameter(methodParam, targetType).resolve(Object.class); - if (converter.canRead(targetClass, contentType)) { - if (logger.isDebugEnabled()) { - logger.debug("Reading [" + targetClass.getName() + "] as \"" + - contentType + "\" using [" + converter + "]"); + else if (targetClass != null) { + if (converter.canRead(targetClass, contentType)) { + if (logger.isDebugEnabled()) { + logger.debug("Reading [" + targetClass.getName() + "] as \"" + + contentType + "\" using [" + converter + "]"); + } + return ((HttpMessageConverter) converter).read(targetClass, inputMessage); } - return ((HttpMessageConverter) converter).read(targetClass, inputMessage); } } diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/RequestPartMethodArgumentResolver.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/RequestPartMethodArgumentResolver.java index 4a5f15c9a4..dfb7522801 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/RequestPartMethodArgumentResolver.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/RequestPartMethodArgumentResolver.java @@ -20,6 +20,7 @@ import java.lang.annotation.Annotation; import java.util.ArrayList; import java.util.Collection; import java.util.List; +import java.util.Optional; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.Part; @@ -80,12 +81,13 @@ public class RequestPartMethodArgumentResolver extends AbstractMessageConverterM super(messageConverters); } + /** * Supports the following: *
    - *
  • Annotated with {@code @RequestPart} - *
  • Of type {@link MultipartFile} unless annotated with {@code @RequestParam}. - *
  • Of type {@code javax.servlet.http.Part} unless annotated with {@code @RequestParam}. + *
  • Annotated with {@code @RequestPart} + *
  • Of type {@link MultipartFile} unless annotated with {@code @RequestParam}. + *
  • Of type {@code javax.servlet.http.Part} unless annotated with {@code @RequestParam}. *
*/ @Override @@ -117,12 +119,19 @@ public class RequestPartMethodArgumentResolver extends AbstractMessageConverterM assertIsMultipartRequest(servletRequest); MultipartHttpServletRequest multipartRequest = - WebUtils.getNativeRequest(servletRequest, MultipartHttpServletRequest.class); + WebUtils.getNativeRequest(servletRequest, MultipartHttpServletRequest.class); + + Class paramType = parameter.getParameterType(); + boolean optional = paramType.getName().equals("java.util.Optional"); + if (optional) { + parameter.increaseNestingLevel(); + paramType = parameter.getNestedParameterType(); + } String partName = getPartName(parameter); Object arg; - if (MultipartFile.class.equals(parameter.getParameterType())) { + if (MultipartFile.class.equals(paramType)) { Assert.notNull(multipartRequest, "Expected MultipartHttpServletRequest: is a MultipartResolver configured?"); arg = multipartRequest.getFile(partName); } @@ -135,7 +144,7 @@ public class RequestPartMethodArgumentResolver extends AbstractMessageConverterM List files = multipartRequest.getFiles(partName); arg = files.toArray(new MultipartFile[files.size()]); } - else if ("javax.servlet.http.Part".equals(parameter.getParameterType().getName())) { + else if ("javax.servlet.http.Part".equals(paramType.getName())) { assertIsMultipartRequest(servletRequest); arg = servletRequest.getPart(partName); } @@ -150,7 +159,7 @@ public class RequestPartMethodArgumentResolver extends AbstractMessageConverterM else { try { HttpInputMessage inputMessage = new RequestPartServletServerHttpRequest(servletRequest, partName); - arg = readWithMessageConverters(inputMessage, parameter, parameter.getParameterType()); + arg = readWithMessageConverters(inputMessage, parameter, parameter.getNestedGenericParameterType()); WebDataBinder binder = binderFactory.createBinder(request, arg, partName); if (arg != null) { validate(binder, parameter); @@ -164,11 +173,14 @@ public class RequestPartMethodArgumentResolver extends AbstractMessageConverterM } RequestPart annot = parameter.getParameterAnnotation(RequestPart.class); - boolean isRequired = (annot == null || annot.required()); + boolean isRequired = ((annot == null || annot.required()) && !optional); if (arg == null && isRequired) { throw new MissingServletRequestPartException(partName); } + if (optional) { + arg = Optional.ofNullable(arg); + } return arg; } @@ -180,41 +192,44 @@ public class RequestPartMethodArgumentResolver extends AbstractMessageConverterM } } - private String getPartName(MethodParameter parameter) { - RequestPart annot = parameter.getParameterAnnotation(RequestPart.class); + private String getPartName(MethodParameter param) { + RequestPart annot = param.getParameterAnnotation(RequestPart.class); String partName = (annot != null ? annot.value() : ""); if (partName.length() == 0) { - partName = parameter.getParameterName(); - Assert.notNull(partName, "Request part name for argument type [" + parameter.getParameterType().getName() + - "] not specified, and parameter name information not found in class file either."); + partName = param.getParameterName(); + if (partName == null) { + throw new IllegalArgumentException("Request part name for argument type [" + + param.getNestedParameterType().getName() + + "] not specified, and parameter name information not found in class file either."); + } } return partName; } - private boolean isMultipartFileCollection(MethodParameter parameter) { - Class collectionType = getCollectionParameterType(parameter); + private boolean isMultipartFileCollection(MethodParameter param) { + Class collectionType = getCollectionParameterType(param); return (collectionType != null && collectionType.equals(MultipartFile.class)); } - private boolean isMultipartFileArray(MethodParameter parameter) { - Class paramType = parameter.getParameterType().getComponentType(); + private boolean isMultipartFileArray(MethodParameter param) { + Class paramType = param.getNestedParameterType().getComponentType(); return (paramType != null && MultipartFile.class.equals(paramType)); } - private boolean isPartCollection(MethodParameter parameter) { - Class collectionType = getCollectionParameterType(parameter); + private boolean isPartCollection(MethodParameter param) { + Class collectionType = getCollectionParameterType(param); return (collectionType != null && "javax.servlet.http.Part".equals(collectionType.getName())); } - private boolean isPartArray(MethodParameter parameter) { - Class paramType = parameter.getParameterType().getComponentType(); + private boolean isPartArray(MethodParameter param) { + Class paramType = param.getNestedParameterType().getComponentType(); return (paramType != null && "javax.servlet.http.Part".equals(paramType.getName())); } - private Class getCollectionParameterType(MethodParameter parameter) { - Class paramType = parameter.getParameterType(); + private Class getCollectionParameterType(MethodParameter param) { + Class paramType = param.getNestedParameterType(); if (Collection.class.equals(paramType) || List.class.isAssignableFrom(paramType)){ - Class valueType = GenericCollectionTypeResolver.getCollectionParameterType(parameter); + Class valueType = GenericCollectionTypeResolver.getCollectionParameterType(param); if (valueType != null) { return valueType; } @@ -228,13 +243,13 @@ public class RequestPartMethodArgumentResolver extends AbstractMessageConverterM * Spring's {@link org.springframework.validation.annotation.Validated}, * and custom annotations whose name starts with "Valid". * @param binder the DataBinder to be used - * @param parameter the method parameter + * @param param the method parameter * @throws MethodArgumentNotValidException in case of a binding error which * is meant to be fatal (i.e. without a declared {@link Errors} parameter) * @see #isBindingErrorFatal */ - protected void validate(WebDataBinder binder, MethodParameter parameter) throws MethodArgumentNotValidException { - Annotation[] annotations = parameter.getParameterAnnotations(); + protected void validate(WebDataBinder binder, MethodParameter param) throws MethodArgumentNotValidException { + Annotation[] annotations = param.getParameterAnnotations(); for (Annotation ann : annotations) { Validated validatedAnn = AnnotationUtils.getAnnotation(ann, Validated.class); if (validatedAnn != null || ann.annotationType().getSimpleName().startsWith("Valid")) { @@ -243,8 +258,8 @@ public class RequestPartMethodArgumentResolver extends AbstractMessageConverterM binder.validate(validationHints); BindingResult bindingResult = binder.getBindingResult(); if (bindingResult.hasErrors()) { - if (isBindingErrorFatal(parameter)) { - throw new MethodArgumentNotValidException(parameter, bindingResult); + if (isBindingErrorFatal(param)) { + throw new MethodArgumentNotValidException(param, bindingResult); } } } diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/RequestPartMethodArgumentResolverTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/RequestPartMethodArgumentResolverTests.java index 1fd3b01fac..a55c4ad90f 100644 --- a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/RequestPartMethodArgumentResolverTests.java +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/RequestPartMethodArgumentResolverTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2014 the original author or authors. + * Copyright 2002-2015 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,11 +16,11 @@ package org.springframework.web.servlet.mvc.method.annotation; -import java.io.IOException; import java.lang.reflect.Method; import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.Optional; import javax.servlet.http.Part; import javax.validation.Valid; import javax.validation.constraints.NotNull; @@ -82,6 +82,9 @@ public class RequestPartMethodArgumentResolverTests { private MethodParameter paramPartList; private MethodParameter paramPartArray; private MethodParameter paramRequestParamAnnot; + private MethodParameter optionalMultipartFile; + private MethodParameter optionalPart; + private MethodParameter optionalRequestPart; private NativeWebRequest webRequest; @@ -89,14 +92,14 @@ public class RequestPartMethodArgumentResolverTests { private MockHttpServletResponse servletResponse; + @SuppressWarnings("unchecked") @Before public void setUp() throws Exception { - Method method = getClass().getMethod("handle", SimpleBean.class, SimpleBean.class, SimpleBean.class, MultipartFile.class, List.class, MultipartFile[].class, - Integer.TYPE, MultipartFile.class, Part.class, List.class, - Part[].class, MultipartFile.class); + Integer.TYPE, MultipartFile.class, Part.class, List.class, Part[].class, + MultipartFile.class, Optional.class, Optional.class, Optional.class); paramRequestPart = new MethodParameter(method, 0); paramRequestPart.initParameterNameDiscovery(new LocalVariableTableParameterNameDiscoverer()); @@ -113,6 +116,11 @@ public class RequestPartMethodArgumentResolverTests { paramPartList = new MethodParameter(method, 9); paramPartArray = new MethodParameter(method, 10); paramRequestParamAnnot = new MethodParameter(method, 11); + optionalMultipartFile = new MethodParameter(method, 12); + optionalMultipartFile.initParameterNameDiscovery(new LocalVariableTableParameterNameDiscoverer()); + optionalPart = new MethodParameter(method, 13); + optionalPart.initParameterNameDiscovery(new LocalVariableTableParameterNameDiscoverer()); + optionalRequestPart = new MethodParameter(method, 14); messageConverter = mock(HttpMessageConverter.class); given(messageConverter.getSupportedMediaTypes()).willReturn(Collections.singletonList(MediaType.TEXT_PLAIN)); @@ -129,6 +137,7 @@ public class RequestPartMethodArgumentResolverTests { webRequest = new ServletWebRequest(multipartRequest, servletResponse); } + @Test public void supportsParameter() { assertTrue("RequestPart parameter not supported", resolver.supportsParameter(paramRequestPart)); @@ -139,8 +148,8 @@ public class RequestPartMethodArgumentResolverTests { assertTrue("MultipartFile parameter not supported", resolver.supportsParameter(paramMultipartFile)); assertTrue("List parameter not supported", resolver.supportsParameter(paramMultipartFileList)); assertTrue("MultipartFile[] parameter not supported", resolver.supportsParameter(paramMultipartFileArray)); - assertFalse("non-RequestPart parameter supported", resolver.supportsParameter(paramInt)); - assertFalse("@RequestParam args not supported", resolver.supportsParameter(paramRequestParamAnnot)); + assertFalse("non-RequestPart parameter should not be supported", resolver.supportsParameter(paramInt)); + assertFalse("@RequestParam args should not be supported", resolver.supportsParameter(paramRequestParamAnnot)); } @Test @@ -243,21 +252,27 @@ public class RequestPartMethodArgumentResolverTests { testResolveArgument(new SimpleBean("foo"), paramNamedRequestPart); } + @Test + public void resolveNamedRequestPartNotPresent() throws Exception { + testResolveArgument(null, paramNamedRequestPart); + } + @Test public void resolveRequestPartNotValid() throws Exception { try { testResolveArgument(new SimpleBean(null), paramValidRequestPart); fail("Expected exception"); - } catch (MethodArgumentNotValidException e) { - assertEquals("requestPart", e.getBindingResult().getObjectName()); - assertEquals(1, e.getBindingResult().getErrorCount()); - assertNotNull(e.getBindingResult().getFieldError("name")); + } + catch (MethodArgumentNotValidException ex) { + assertEquals("requestPart", ex.getBindingResult().getObjectName()); + assertEquals(1, ex.getBindingResult().getErrorCount()); + assertNotNull(ex.getBindingResult().getFieldError("name")); } } @Test public void resolveRequestPartValid() throws Exception { - testResolveArgument(new SimpleBean("foo"), paramNamedRequestPart); + testResolveArgument(new SimpleBean("foo"), paramValidRequestPart); } @Test @@ -265,8 +280,9 @@ public class RequestPartMethodArgumentResolverTests { try { testResolveArgument(null, paramValidRequestPart); fail("Expected exception"); - } catch (MissingServletRequestPartException e) { - assertEquals("requestPart", e.getRequestPartName()); + } + catch (MissingServletRequestPartException ex) { + assertEquals("requestPart", ex.getRequestPartName()); } } @@ -275,27 +291,100 @@ public class RequestPartMethodArgumentResolverTests { testResolveArgument(new SimpleBean("foo"), paramValidRequestPart); } - @Test(expected=MultipartException.class) + @Test(expected = MultipartException.class) public void isMultipartRequest() throws Exception { MockHttpServletRequest request = new MockHttpServletRequest(); resolver.resolveArgument(paramMultipartFile, new ModelAndViewContainer(), new ServletWebRequest(request), null); - fail("Expected exception"); } - // SPR-9079 - - @Test + @Test // SPR-9079 public void isMultipartRequestPut() throws Exception { this.multipartRequest.setMethod("PUT"); Object actual = resolver.resolveArgument(paramMultipartFile, null, webRequest, null); - assertNotNull(actual); assertSame(multipartFile1, actual); } - private void testResolveArgument(SimpleBean argValue, MethodParameter parameter) throws IOException, Exception { - MediaType contentType = MediaType.TEXT_PLAIN; + @Test + public void resolveOptionalMultipartFileArgument() throws Exception { + MockMultipartHttpServletRequest request = new MockMultipartHttpServletRequest(); + MultipartFile expected = new MockMultipartFile("optionalMultipartFile", "Hello World".getBytes()); + request.addFile(expected); + webRequest = new ServletWebRequest(request); - given(messageConverter.canRead(SimpleBean.class, contentType)).willReturn(true); + Object result = resolver.resolveArgument(optionalMultipartFile, null, webRequest, null); + + assertTrue(result instanceof Optional); + assertEquals("Invalid result", expected, ((Optional) result).get()); + } + + @Test + public void resolveOptionalMultipartFileArgumentNotPresent() throws Exception { + MockMultipartHttpServletRequest request = new MockMultipartHttpServletRequest(); + webRequest = new ServletWebRequest(request); + + Object result = resolver.resolveArgument(optionalMultipartFile, null, webRequest, null); + + assertTrue(result instanceof Optional); + assertFalse("Invalid result", ((Optional) result).isPresent()); + } + + @Test + public void resolveOptionalPartArgument() throws Exception { + MockPart expected = new MockPart("optionalPart", "Hello World".getBytes()); + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setMethod("POST"); + request.setContentType("multipart/form-data"); + request.addPart(expected); + webRequest = new ServletWebRequest(request); + + Object result = resolver.resolveArgument(optionalPart, null, webRequest, null); + + assertTrue(result instanceof Optional); + assertEquals("Invalid result", expected, ((Optional) result).get()); + } + + @Test + public void resolveOptionalPartArgumentNotPresent() throws Exception { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setMethod("POST"); + request.setContentType("multipart/form-data"); + webRequest = new ServletWebRequest(request); + + Object result = resolver.resolveArgument(optionalPart, null, webRequest, null); + + assertTrue(result instanceof Optional); + assertFalse("Invalid result", ((Optional) result).isPresent()); + } + + @Test + public void resolveOptionalRequestPart() throws Exception { + SimpleBean simpleBean = new SimpleBean("foo"); + + given(messageConverter.canRead(SimpleBean.class, MediaType.TEXT_PLAIN)).willReturn(true); + given(messageConverter.read(eq(SimpleBean.class), isA(RequestPartServletServerHttpRequest.class))).willReturn(simpleBean); + + ModelAndViewContainer mavContainer = new ModelAndViewContainer(); + Object actualValue = resolver.resolveArgument(optionalRequestPart, mavContainer, webRequest, new ValidatingBinderFactory()); + + assertEquals("Invalid argument value", Optional.of(simpleBean), actualValue); + assertFalse("The requestHandled flag shouldn't change", mavContainer.isRequestHandled()); + } + + @Test + public void resolveOptionalRequestPartNotPresent() throws Exception { + given(messageConverter.canRead(SimpleBean.class, MediaType.TEXT_PLAIN)).willReturn(true); + given(messageConverter.read(eq(SimpleBean.class), isA(RequestPartServletServerHttpRequest.class))).willReturn(null); + + ModelAndViewContainer mavContainer = new ModelAndViewContainer(); + Object actualValue = resolver.resolveArgument(optionalRequestPart, mavContainer, webRequest, new ValidatingBinderFactory()); + + assertEquals("Invalid argument value", Optional.empty(), actualValue); + assertFalse("The requestHandled flag shouldn't change", mavContainer.isRequestHandled()); + } + + + private void testResolveArgument(SimpleBean argValue, MethodParameter parameter) throws Exception { + given(messageConverter.canRead(SimpleBean.class, MediaType.TEXT_PLAIN)).willReturn(true); given(messageConverter.read(eq(SimpleBean.class), isA(RequestPartServletServerHttpRequest.class))).willReturn(argValue); ModelAndViewContainer mavContainer = new ModelAndViewContainer(); @@ -305,6 +394,7 @@ public class RequestPartMethodArgumentResolverTests { assertFalse("The requestHandled flag shouldn't change", mavContainer.isRequestHandled()); } + private static class SimpleBean { @NotNull @@ -320,7 +410,9 @@ public class RequestPartMethodArgumentResolverTests { } } + private final class ValidatingBinderFactory implements WebDataBinderFactory { + @Override public WebDataBinder createBinder(NativeWebRequest webRequest, Object target, String objectName) throws Exception { LocalValidatorFactoryBean validator = new LocalValidatorFactoryBean(); @@ -331,6 +423,7 @@ public class RequestPartMethodArgumentResolverTests { } } + public void handle(@RequestPart SimpleBean requestPart, @RequestPart(value="requestPart", required=false) SimpleBean namedRequestPart, @Valid @RequestPart("requestPart") SimpleBean validRequestPart, @@ -342,7 +435,10 @@ public class RequestPartMethodArgumentResolverTests { Part part, @RequestPart("part") List partList, @RequestPart("part") Part[] partArray, - @RequestParam MultipartFile requestParamAnnot) { + @RequestParam MultipartFile requestParamAnnot, + Optional optionalMultipartFile, + Optional optionalPart, + @RequestPart("requestPart") Optional optionalRequestPart) { } }