@RequestPart supports java.util.Optional

Issue: SPR-12644
This commit is contained in:
Juergen Hoeller 2015-02-18 16:17:07 +01:00
parent 61cc3b5bff
commit 6ebac00f32
4 changed files with 212 additions and 65 deletions

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2014 the original author or authors. * Copyright 2002-2015 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -309,6 +309,7 @@ public class MethodParameter {
/** /**
* Return the generic type of the method/constructor parameter. * Return the generic type of the method/constructor parameter.
* @return the parameter type (never {@code null}) * @return the parameter type (never {@code null})
* @since 3.0
*/ */
public Type getGenericParameterType() { public Type getGenericParameterType() {
if (this.genericParameterType == null) { if (this.genericParameterType == null) {
@ -324,6 +325,12 @@ public class MethodParameter {
return this.genericParameterType; return this.genericParameterType;
} }
/**
* Return the nested type of the method/constructor parameter.
* @return the parameter type (never {@code null})
* @see #getNestingLevel()
* @since 3.1
*/
public Class<?> getNestedParameterType() { public Class<?> getNestedParameterType() {
if (this.nestingLevel > 1) { if (this.nestingLevel > 1) {
Type type = getGenericParameterType(); Type type = getGenericParameterType();
@ -350,6 +357,29 @@ public class MethodParameter {
} }
} }
/**
* Return the nested generic type of the method/constructor parameter.
* @return the parameter type (never {@code null})
* @see #getNestingLevel()
* @since 4.2
*/
public Type getNestedGenericParameterType() {
if (this.nestingLevel > 1) {
Type type = getGenericParameterType();
for (int i = 2; i <= this.nestingLevel; i++) {
if (type instanceof ParameterizedType) {
Type[] args = ((ParameterizedType) type).getActualTypeArguments();
Integer index = getTypeIndexForLevel(i);
type = args[index != null ? index : args.length - 1];
}
}
return type;
}
else {
return getGenericParameterType();
}
}
/** /**
* Return the annotations associated with the target method/constructor itself. * Return the annotations associated with the target method/constructor itself.
*/ */

View File

@ -104,7 +104,7 @@ public abstract class AbstractMessageConverterMethodArgumentResolver implements
* from the given HttpInputMessage. * from the given HttpInputMessage.
* @param <T> the expected type of the argument value to be created * @param <T> the expected type of the argument value to be created
* @param inputMessage the HTTP input message representing the current request * @param inputMessage the HTTP input message representing the current request
* @param methodParam the method argument * @param methodParam the method parameter descriptor (may be {@code null})
* @param targetType the type of object to create, not necessarily the same as * @param targetType the type of object to create, not necessarily the same as
* the method parameter type (e.g. for {@code HttpEntity<String>} method * the method parameter type (e.g. for {@code HttpEntity<String>} method
* parameter the target type is String) * parameter the target type is String)
@ -113,8 +113,8 @@ public abstract class AbstractMessageConverterMethodArgumentResolver implements
* @throws HttpMediaTypeNotSupportedException if no suitable message converter is found * @throws HttpMediaTypeNotSupportedException if no suitable message converter is found
*/ */
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
protected <T> Object readWithMessageConverters(HttpInputMessage inputMessage, protected <T> Object readWithMessageConverters(HttpInputMessage inputMessage, MethodParameter methodParam,
MethodParameter methodParam, Type targetType) throws IOException, HttpMediaTypeNotSupportedException { Type targetType) throws IOException, HttpMediaTypeNotSupportedException {
MediaType contentType; MediaType contentType;
try { try {
@ -127,7 +127,13 @@ public abstract class AbstractMessageConverterMethodArgumentResolver implements
contentType = MediaType.APPLICATION_OCTET_STREAM; contentType = MediaType.APPLICATION_OCTET_STREAM;
} }
Class<?> contextClass = methodParam.getContainingClass(); Class<?> contextClass = (methodParam != null ? methodParam.getContainingClass() : null);
Class<T> targetClass = (targetType instanceof Class<?> ? (Class<T>) targetType : null);
if (targetClass == null) {
ResolvableType resolvableType = (methodParam != null ?
ResolvableType.forMethodParameter(methodParam) : ResolvableType.forType(targetType));
targetClass = (Class<T>) resolvableType.resolve();
}
for (HttpMessageConverter<?> converter : this.messageConverters) { for (HttpMessageConverter<?> converter : this.messageConverters) {
if (converter instanceof GenericHttpMessageConverter) { if (converter instanceof GenericHttpMessageConverter) {
@ -140,14 +146,14 @@ public abstract class AbstractMessageConverterMethodArgumentResolver implements
return genericConverter.read(targetType, contextClass, inputMessage); return genericConverter.read(targetType, contextClass, inputMessage);
} }
} }
Class<T> targetClass = (Class<T>) else if (targetClass != null) {
ResolvableType.forMethodParameter(methodParam, targetType).resolve(Object.class); if (converter.canRead(targetClass, contentType)) {
if (converter.canRead(targetClass, contentType)) { if (logger.isDebugEnabled()) {
if (logger.isDebugEnabled()) { logger.debug("Reading [" + targetClass.getName() + "] as \"" +
logger.debug("Reading [" + targetClass.getName() + "] as \"" + contentType + "\" using [" + converter + "]");
contentType + "\" using [" + converter + "]"); }
return ((HttpMessageConverter<T>) converter).read(targetClass, inputMessage);
} }
return ((HttpMessageConverter<T>) converter).read(targetClass, inputMessage);
} }
} }

View File

@ -20,6 +20,7 @@ import java.lang.annotation.Annotation;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collection; import java.util.Collection;
import java.util.List; import java.util.List;
import java.util.Optional;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.Part; import javax.servlet.http.Part;
@ -80,12 +81,13 @@ public class RequestPartMethodArgumentResolver extends AbstractMessageConverterM
super(messageConverters); super(messageConverters);
} }
/** /**
* Supports the following: * Supports the following:
* <ul> * <ul>
* <li>Annotated with {@code @RequestPart} * <li>Annotated with {@code @RequestPart}
* <li>Of type {@link MultipartFile} unless annotated with {@code @RequestParam}. * <li>Of type {@link MultipartFile} unless annotated with {@code @RequestParam}.
* <li>Of type {@code javax.servlet.http.Part} unless annotated with {@code @RequestParam}. * <li>Of type {@code javax.servlet.http.Part} unless annotated with {@code @RequestParam}.
* </ul> * </ul>
*/ */
@Override @Override
@ -117,12 +119,19 @@ public class RequestPartMethodArgumentResolver extends AbstractMessageConverterM
assertIsMultipartRequest(servletRequest); assertIsMultipartRequest(servletRequest);
MultipartHttpServletRequest multipartRequest = MultipartHttpServletRequest multipartRequest =
WebUtils.getNativeRequest(servletRequest, MultipartHttpServletRequest.class); WebUtils.getNativeRequest(servletRequest, MultipartHttpServletRequest.class);
Class<?> paramType = parameter.getParameterType();
boolean optional = paramType.getName().equals("java.util.Optional");
if (optional) {
parameter.increaseNestingLevel();
paramType = parameter.getNestedParameterType();
}
String partName = getPartName(parameter); String partName = getPartName(parameter);
Object arg; Object arg;
if (MultipartFile.class.equals(parameter.getParameterType())) { if (MultipartFile.class.equals(paramType)) {
Assert.notNull(multipartRequest, "Expected MultipartHttpServletRequest: is a MultipartResolver configured?"); Assert.notNull(multipartRequest, "Expected MultipartHttpServletRequest: is a MultipartResolver configured?");
arg = multipartRequest.getFile(partName); arg = multipartRequest.getFile(partName);
} }
@ -135,7 +144,7 @@ public class RequestPartMethodArgumentResolver extends AbstractMessageConverterM
List<MultipartFile> files = multipartRequest.getFiles(partName); List<MultipartFile> files = multipartRequest.getFiles(partName);
arg = files.toArray(new MultipartFile[files.size()]); arg = files.toArray(new MultipartFile[files.size()]);
} }
else if ("javax.servlet.http.Part".equals(parameter.getParameterType().getName())) { else if ("javax.servlet.http.Part".equals(paramType.getName())) {
assertIsMultipartRequest(servletRequest); assertIsMultipartRequest(servletRequest);
arg = servletRequest.getPart(partName); arg = servletRequest.getPart(partName);
} }
@ -150,7 +159,7 @@ public class RequestPartMethodArgumentResolver extends AbstractMessageConverterM
else { else {
try { try {
HttpInputMessage inputMessage = new RequestPartServletServerHttpRequest(servletRequest, partName); HttpInputMessage inputMessage = new RequestPartServletServerHttpRequest(servletRequest, partName);
arg = readWithMessageConverters(inputMessage, parameter, parameter.getParameterType()); arg = readWithMessageConverters(inputMessage, parameter, parameter.getNestedGenericParameterType());
WebDataBinder binder = binderFactory.createBinder(request, arg, partName); WebDataBinder binder = binderFactory.createBinder(request, arg, partName);
if (arg != null) { if (arg != null) {
validate(binder, parameter); validate(binder, parameter);
@ -164,11 +173,14 @@ public class RequestPartMethodArgumentResolver extends AbstractMessageConverterM
} }
RequestPart annot = parameter.getParameterAnnotation(RequestPart.class); RequestPart annot = parameter.getParameterAnnotation(RequestPart.class);
boolean isRequired = (annot == null || annot.required()); boolean isRequired = ((annot == null || annot.required()) && !optional);
if (arg == null && isRequired) { if (arg == null && isRequired) {
throw new MissingServletRequestPartException(partName); throw new MissingServletRequestPartException(partName);
} }
if (optional) {
arg = Optional.ofNullable(arg);
}
return arg; return arg;
} }
@ -180,41 +192,44 @@ public class RequestPartMethodArgumentResolver extends AbstractMessageConverterM
} }
} }
private String getPartName(MethodParameter parameter) { private String getPartName(MethodParameter param) {
RequestPart annot = parameter.getParameterAnnotation(RequestPart.class); RequestPart annot = param.getParameterAnnotation(RequestPart.class);
String partName = (annot != null ? annot.value() : ""); String partName = (annot != null ? annot.value() : "");
if (partName.length() == 0) { if (partName.length() == 0) {
partName = parameter.getParameterName(); partName = param.getParameterName();
Assert.notNull(partName, "Request part name for argument type [" + parameter.getParameterType().getName() + if (partName == null) {
"] not specified, and parameter name information not found in class file either."); throw new IllegalArgumentException("Request part name for argument type [" +
param.getNestedParameterType().getName() +
"] not specified, and parameter name information not found in class file either.");
}
} }
return partName; return partName;
} }
private boolean isMultipartFileCollection(MethodParameter parameter) { private boolean isMultipartFileCollection(MethodParameter param) {
Class<?> collectionType = getCollectionParameterType(parameter); Class<?> collectionType = getCollectionParameterType(param);
return (collectionType != null && collectionType.equals(MultipartFile.class)); return (collectionType != null && collectionType.equals(MultipartFile.class));
} }
private boolean isMultipartFileArray(MethodParameter parameter) { private boolean isMultipartFileArray(MethodParameter param) {
Class<?> paramType = parameter.getParameterType().getComponentType(); Class<?> paramType = param.getNestedParameterType().getComponentType();
return (paramType != null && MultipartFile.class.equals(paramType)); return (paramType != null && MultipartFile.class.equals(paramType));
} }
private boolean isPartCollection(MethodParameter parameter) { private boolean isPartCollection(MethodParameter param) {
Class<?> collectionType = getCollectionParameterType(parameter); Class<?> collectionType = getCollectionParameterType(param);
return (collectionType != null && "javax.servlet.http.Part".equals(collectionType.getName())); return (collectionType != null && "javax.servlet.http.Part".equals(collectionType.getName()));
} }
private boolean isPartArray(MethodParameter parameter) { private boolean isPartArray(MethodParameter param) {
Class<?> paramType = parameter.getParameterType().getComponentType(); Class<?> paramType = param.getNestedParameterType().getComponentType();
return (paramType != null && "javax.servlet.http.Part".equals(paramType.getName())); return (paramType != null && "javax.servlet.http.Part".equals(paramType.getName()));
} }
private Class<?> getCollectionParameterType(MethodParameter parameter) { private Class<?> getCollectionParameterType(MethodParameter param) {
Class<?> paramType = parameter.getParameterType(); Class<?> paramType = param.getNestedParameterType();
if (Collection.class.equals(paramType) || List.class.isAssignableFrom(paramType)){ if (Collection.class.equals(paramType) || List.class.isAssignableFrom(paramType)){
Class<?> valueType = GenericCollectionTypeResolver.getCollectionParameterType(parameter); Class<?> valueType = GenericCollectionTypeResolver.getCollectionParameterType(param);
if (valueType != null) { if (valueType != null) {
return valueType; return valueType;
} }
@ -228,13 +243,13 @@ public class RequestPartMethodArgumentResolver extends AbstractMessageConverterM
* Spring's {@link org.springframework.validation.annotation.Validated}, * Spring's {@link org.springframework.validation.annotation.Validated},
* and custom annotations whose name starts with "Valid". * and custom annotations whose name starts with "Valid".
* @param binder the DataBinder to be used * @param binder the DataBinder to be used
* @param parameter the method parameter * @param param the method parameter
* @throws MethodArgumentNotValidException in case of a binding error which * @throws MethodArgumentNotValidException in case of a binding error which
* is meant to be fatal (i.e. without a declared {@link Errors} parameter) * is meant to be fatal (i.e. without a declared {@link Errors} parameter)
* @see #isBindingErrorFatal * @see #isBindingErrorFatal
*/ */
protected void validate(WebDataBinder binder, MethodParameter parameter) throws MethodArgumentNotValidException { protected void validate(WebDataBinder binder, MethodParameter param) throws MethodArgumentNotValidException {
Annotation[] annotations = parameter.getParameterAnnotations(); Annotation[] annotations = param.getParameterAnnotations();
for (Annotation ann : annotations) { for (Annotation ann : annotations) {
Validated validatedAnn = AnnotationUtils.getAnnotation(ann, Validated.class); Validated validatedAnn = AnnotationUtils.getAnnotation(ann, Validated.class);
if (validatedAnn != null || ann.annotationType().getSimpleName().startsWith("Valid")) { if (validatedAnn != null || ann.annotationType().getSimpleName().startsWith("Valid")) {
@ -243,8 +258,8 @@ public class RequestPartMethodArgumentResolver extends AbstractMessageConverterM
binder.validate(validationHints); binder.validate(validationHints);
BindingResult bindingResult = binder.getBindingResult(); BindingResult bindingResult = binder.getBindingResult();
if (bindingResult.hasErrors()) { if (bindingResult.hasErrors()) {
if (isBindingErrorFatal(parameter)) { if (isBindingErrorFatal(param)) {
throw new MethodArgumentNotValidException(parameter, bindingResult); throw new MethodArgumentNotValidException(param, bindingResult);
} }
} }
} }

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2014 the original author or authors. * Copyright 2002-2015 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -16,11 +16,11 @@
package org.springframework.web.servlet.mvc.method.annotation; package org.springframework.web.servlet.mvc.method.annotation;
import java.io.IOException;
import java.lang.reflect.Method; import java.lang.reflect.Method;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Optional;
import javax.servlet.http.Part; import javax.servlet.http.Part;
import javax.validation.Valid; import javax.validation.Valid;
import javax.validation.constraints.NotNull; import javax.validation.constraints.NotNull;
@ -82,6 +82,9 @@ public class RequestPartMethodArgumentResolverTests {
private MethodParameter paramPartList; private MethodParameter paramPartList;
private MethodParameter paramPartArray; private MethodParameter paramPartArray;
private MethodParameter paramRequestParamAnnot; private MethodParameter paramRequestParamAnnot;
private MethodParameter optionalMultipartFile;
private MethodParameter optionalPart;
private MethodParameter optionalRequestPart;
private NativeWebRequest webRequest; private NativeWebRequest webRequest;
@ -89,14 +92,14 @@ public class RequestPartMethodArgumentResolverTests {
private MockHttpServletResponse servletResponse; private MockHttpServletResponse servletResponse;
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
@Before @Before
public void setUp() throws Exception { public void setUp() throws Exception {
Method method = getClass().getMethod("handle", SimpleBean.class, SimpleBean.class, Method method = getClass().getMethod("handle", SimpleBean.class, SimpleBean.class,
SimpleBean.class, MultipartFile.class, List.class, MultipartFile[].class, SimpleBean.class, MultipartFile.class, List.class, MultipartFile[].class,
Integer.TYPE, MultipartFile.class, Part.class, List.class, Integer.TYPE, MultipartFile.class, Part.class, List.class, Part[].class,
Part[].class, MultipartFile.class); MultipartFile.class, Optional.class, Optional.class, Optional.class);
paramRequestPart = new MethodParameter(method, 0); paramRequestPart = new MethodParameter(method, 0);
paramRequestPart.initParameterNameDiscovery(new LocalVariableTableParameterNameDiscoverer()); paramRequestPart.initParameterNameDiscovery(new LocalVariableTableParameterNameDiscoverer());
@ -113,6 +116,11 @@ public class RequestPartMethodArgumentResolverTests {
paramPartList = new MethodParameter(method, 9); paramPartList = new MethodParameter(method, 9);
paramPartArray = new MethodParameter(method, 10); paramPartArray = new MethodParameter(method, 10);
paramRequestParamAnnot = new MethodParameter(method, 11); paramRequestParamAnnot = new MethodParameter(method, 11);
optionalMultipartFile = new MethodParameter(method, 12);
optionalMultipartFile.initParameterNameDiscovery(new LocalVariableTableParameterNameDiscoverer());
optionalPart = new MethodParameter(method, 13);
optionalPart.initParameterNameDiscovery(new LocalVariableTableParameterNameDiscoverer());
optionalRequestPart = new MethodParameter(method, 14);
messageConverter = mock(HttpMessageConverter.class); messageConverter = mock(HttpMessageConverter.class);
given(messageConverter.getSupportedMediaTypes()).willReturn(Collections.singletonList(MediaType.TEXT_PLAIN)); given(messageConverter.getSupportedMediaTypes()).willReturn(Collections.singletonList(MediaType.TEXT_PLAIN));
@ -129,6 +137,7 @@ public class RequestPartMethodArgumentResolverTests {
webRequest = new ServletWebRequest(multipartRequest, servletResponse); webRequest = new ServletWebRequest(multipartRequest, servletResponse);
} }
@Test @Test
public void supportsParameter() { public void supportsParameter() {
assertTrue("RequestPart parameter not supported", resolver.supportsParameter(paramRequestPart)); assertTrue("RequestPart parameter not supported", resolver.supportsParameter(paramRequestPart));
@ -139,8 +148,8 @@ public class RequestPartMethodArgumentResolverTests {
assertTrue("MultipartFile parameter not supported", resolver.supportsParameter(paramMultipartFile)); assertTrue("MultipartFile parameter not supported", resolver.supportsParameter(paramMultipartFile));
assertTrue("List<MultipartFile> parameter not supported", resolver.supportsParameter(paramMultipartFileList)); assertTrue("List<MultipartFile> parameter not supported", resolver.supportsParameter(paramMultipartFileList));
assertTrue("MultipartFile[] parameter not supported", resolver.supportsParameter(paramMultipartFileArray)); assertTrue("MultipartFile[] parameter not supported", resolver.supportsParameter(paramMultipartFileArray));
assertFalse("non-RequestPart parameter supported", resolver.supportsParameter(paramInt)); assertFalse("non-RequestPart parameter should not be supported", resolver.supportsParameter(paramInt));
assertFalse("@RequestParam args not supported", resolver.supportsParameter(paramRequestParamAnnot)); assertFalse("@RequestParam args should not be supported", resolver.supportsParameter(paramRequestParamAnnot));
} }
@Test @Test
@ -243,21 +252,27 @@ public class RequestPartMethodArgumentResolverTests {
testResolveArgument(new SimpleBean("foo"), paramNamedRequestPart); testResolveArgument(new SimpleBean("foo"), paramNamedRequestPart);
} }
@Test
public void resolveNamedRequestPartNotPresent() throws Exception {
testResolveArgument(null, paramNamedRequestPart);
}
@Test @Test
public void resolveRequestPartNotValid() throws Exception { public void resolveRequestPartNotValid() throws Exception {
try { try {
testResolveArgument(new SimpleBean(null), paramValidRequestPart); testResolveArgument(new SimpleBean(null), paramValidRequestPart);
fail("Expected exception"); fail("Expected exception");
} catch (MethodArgumentNotValidException e) { }
assertEquals("requestPart", e.getBindingResult().getObjectName()); catch (MethodArgumentNotValidException ex) {
assertEquals(1, e.getBindingResult().getErrorCount()); assertEquals("requestPart", ex.getBindingResult().getObjectName());
assertNotNull(e.getBindingResult().getFieldError("name")); assertEquals(1, ex.getBindingResult().getErrorCount());
assertNotNull(ex.getBindingResult().getFieldError("name"));
} }
} }
@Test @Test
public void resolveRequestPartValid() throws Exception { public void resolveRequestPartValid() throws Exception {
testResolveArgument(new SimpleBean("foo"), paramNamedRequestPart); testResolveArgument(new SimpleBean("foo"), paramValidRequestPart);
} }
@Test @Test
@ -265,8 +280,9 @@ public class RequestPartMethodArgumentResolverTests {
try { try {
testResolveArgument(null, paramValidRequestPart); testResolveArgument(null, paramValidRequestPart);
fail("Expected exception"); fail("Expected exception");
} catch (MissingServletRequestPartException e) { }
assertEquals("requestPart", e.getRequestPartName()); catch (MissingServletRequestPartException ex) {
assertEquals("requestPart", ex.getRequestPartName());
} }
} }
@ -275,27 +291,100 @@ public class RequestPartMethodArgumentResolverTests {
testResolveArgument(new SimpleBean("foo"), paramValidRequestPart); testResolveArgument(new SimpleBean("foo"), paramValidRequestPart);
} }
@Test(expected=MultipartException.class) @Test(expected = MultipartException.class)
public void isMultipartRequest() throws Exception { public void isMultipartRequest() throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest request = new MockHttpServletRequest();
resolver.resolveArgument(paramMultipartFile, new ModelAndViewContainer(), new ServletWebRequest(request), null); resolver.resolveArgument(paramMultipartFile, new ModelAndViewContainer(), new ServletWebRequest(request), null);
fail("Expected exception");
} }
// SPR-9079 @Test // SPR-9079
@Test
public void isMultipartRequestPut() throws Exception { public void isMultipartRequestPut() throws Exception {
this.multipartRequest.setMethod("PUT"); this.multipartRequest.setMethod("PUT");
Object actual = resolver.resolveArgument(paramMultipartFile, null, webRequest, null); Object actual = resolver.resolveArgument(paramMultipartFile, null, webRequest, null);
assertNotNull(actual);
assertSame(multipartFile1, actual); assertSame(multipartFile1, actual);
} }
private void testResolveArgument(SimpleBean argValue, MethodParameter parameter) throws IOException, Exception { @Test
MediaType contentType = MediaType.TEXT_PLAIN; public void resolveOptionalMultipartFileArgument() throws Exception {
MockMultipartHttpServletRequest request = new MockMultipartHttpServletRequest();
MultipartFile expected = new MockMultipartFile("optionalMultipartFile", "Hello World".getBytes());
request.addFile(expected);
webRequest = new ServletWebRequest(request);
given(messageConverter.canRead(SimpleBean.class, contentType)).willReturn(true); Object result = resolver.resolveArgument(optionalMultipartFile, null, webRequest, null);
assertTrue(result instanceof Optional);
assertEquals("Invalid result", expected, ((Optional) result).get());
}
@Test
public void resolveOptionalMultipartFileArgumentNotPresent() throws Exception {
MockMultipartHttpServletRequest request = new MockMultipartHttpServletRequest();
webRequest = new ServletWebRequest(request);
Object result = resolver.resolveArgument(optionalMultipartFile, null, webRequest, null);
assertTrue(result instanceof Optional);
assertFalse("Invalid result", ((Optional) result).isPresent());
}
@Test
public void resolveOptionalPartArgument() throws Exception {
MockPart expected = new MockPart("optionalPart", "Hello World".getBytes());
MockHttpServletRequest request = new MockHttpServletRequest();
request.setMethod("POST");
request.setContentType("multipart/form-data");
request.addPart(expected);
webRequest = new ServletWebRequest(request);
Object result = resolver.resolveArgument(optionalPart, null, webRequest, null);
assertTrue(result instanceof Optional);
assertEquals("Invalid result", expected, ((Optional) result).get());
}
@Test
public void resolveOptionalPartArgumentNotPresent() throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest();
request.setMethod("POST");
request.setContentType("multipart/form-data");
webRequest = new ServletWebRequest(request);
Object result = resolver.resolveArgument(optionalPart, null, webRequest, null);
assertTrue(result instanceof Optional);
assertFalse("Invalid result", ((Optional) result).isPresent());
}
@Test
public void resolveOptionalRequestPart() throws Exception {
SimpleBean simpleBean = new SimpleBean("foo");
given(messageConverter.canRead(SimpleBean.class, MediaType.TEXT_PLAIN)).willReturn(true);
given(messageConverter.read(eq(SimpleBean.class), isA(RequestPartServletServerHttpRequest.class))).willReturn(simpleBean);
ModelAndViewContainer mavContainer = new ModelAndViewContainer();
Object actualValue = resolver.resolveArgument(optionalRequestPart, mavContainer, webRequest, new ValidatingBinderFactory());
assertEquals("Invalid argument value", Optional.of(simpleBean), actualValue);
assertFalse("The requestHandled flag shouldn't change", mavContainer.isRequestHandled());
}
@Test
public void resolveOptionalRequestPartNotPresent() throws Exception {
given(messageConverter.canRead(SimpleBean.class, MediaType.TEXT_PLAIN)).willReturn(true);
given(messageConverter.read(eq(SimpleBean.class), isA(RequestPartServletServerHttpRequest.class))).willReturn(null);
ModelAndViewContainer mavContainer = new ModelAndViewContainer();
Object actualValue = resolver.resolveArgument(optionalRequestPart, mavContainer, webRequest, new ValidatingBinderFactory());
assertEquals("Invalid argument value", Optional.empty(), actualValue);
assertFalse("The requestHandled flag shouldn't change", mavContainer.isRequestHandled());
}
private void testResolveArgument(SimpleBean argValue, MethodParameter parameter) throws Exception {
given(messageConverter.canRead(SimpleBean.class, MediaType.TEXT_PLAIN)).willReturn(true);
given(messageConverter.read(eq(SimpleBean.class), isA(RequestPartServletServerHttpRequest.class))).willReturn(argValue); given(messageConverter.read(eq(SimpleBean.class), isA(RequestPartServletServerHttpRequest.class))).willReturn(argValue);
ModelAndViewContainer mavContainer = new ModelAndViewContainer(); ModelAndViewContainer mavContainer = new ModelAndViewContainer();
@ -305,6 +394,7 @@ public class RequestPartMethodArgumentResolverTests {
assertFalse("The requestHandled flag shouldn't change", mavContainer.isRequestHandled()); assertFalse("The requestHandled flag shouldn't change", mavContainer.isRequestHandled());
} }
private static class SimpleBean { private static class SimpleBean {
@NotNull @NotNull
@ -320,7 +410,9 @@ public class RequestPartMethodArgumentResolverTests {
} }
} }
private final class ValidatingBinderFactory implements WebDataBinderFactory { private final class ValidatingBinderFactory implements WebDataBinderFactory {
@Override @Override
public WebDataBinder createBinder(NativeWebRequest webRequest, Object target, String objectName) throws Exception { public WebDataBinder createBinder(NativeWebRequest webRequest, Object target, String objectName) throws Exception {
LocalValidatorFactoryBean validator = new LocalValidatorFactoryBean(); LocalValidatorFactoryBean validator = new LocalValidatorFactoryBean();
@ -331,6 +423,7 @@ public class RequestPartMethodArgumentResolverTests {
} }
} }
public void handle(@RequestPart SimpleBean requestPart, public void handle(@RequestPart SimpleBean requestPart,
@RequestPart(value="requestPart", required=false) SimpleBean namedRequestPart, @RequestPart(value="requestPart", required=false) SimpleBean namedRequestPart,
@Valid @RequestPart("requestPart") SimpleBean validRequestPart, @Valid @RequestPart("requestPart") SimpleBean validRequestPart,
@ -342,7 +435,10 @@ public class RequestPartMethodArgumentResolverTests {
Part part, Part part,
@RequestPart("part") List<Part> partList, @RequestPart("part") List<Part> partList,
@RequestPart("part") Part[] partArray, @RequestPart("part") Part[] partArray,
@RequestParam MultipartFile requestParamAnnot) { @RequestParam MultipartFile requestParamAnnot,
Optional<MultipartFile> optionalMultipartFile,
Optional<Part> optionalPart,
@RequestPart("requestPart") Optional<SimpleBean> optionalRequestPart) {
} }
} }