diff --git a/spring-test/src/main/java/org/springframework/mock/web/MockMultipartHttpServletRequest.java b/spring-test/src/main/java/org/springframework/mock/web/MockMultipartHttpServletRequest.java index f9607f5bb2..15153dcea3 100644 --- a/spring-test/src/main/java/org/springframework/mock/web/MockMultipartHttpServletRequest.java +++ b/spring-test/src/main/java/org/springframework/mock/web/MockMultipartHttpServletRequest.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,6 +17,7 @@ package org.springframework.mock.web; import java.io.IOException; +import java.util.ArrayList; import java.util.Collections; import java.util.Enumeration; import java.util.Iterator; @@ -33,6 +34,7 @@ import org.springframework.lang.Nullable; import org.springframework.util.Assert; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; +import org.springframework.web.multipart.MultipartException; import org.springframework.web.multipart.MultipartFile; import org.springframework.web.multipart.MultipartHttpServletRequest; @@ -155,15 +157,28 @@ public class MockMultipartHttpServletRequest extends MockHttpServletRequest impl @Override public HttpHeaders getMultipartHeaders(String paramOrFileName) { - String contentType = getMultipartContentType(paramOrFileName); - if (contentType != null) { + MultipartFile file = getFile(paramOrFileName); + if (file != null) { HttpHeaders headers = new HttpHeaders(); - headers.add(HttpHeaders.CONTENT_TYPE, contentType); + if (file.getContentType() != null) { + headers.add(HttpHeaders.CONTENT_TYPE, file.getContentType()); + } return headers; } - else { - return null; + try { + Part part = getPart(paramOrFileName); + if (part != null) { + HttpHeaders headers = new HttpHeaders(); + for (String headerName : part.getHeaderNames()) { + headers.put(headerName, new ArrayList<>(part.getHeaders(headerName))); + } + return headers; + } } + catch (Throwable ex) { + throw new MultipartException("Could not access multipart servlet request", ex); + } + return null; } } diff --git a/spring-web/src/main/java/org/springframework/web/multipart/MultipartHttpServletRequest.java b/spring-web/src/main/java/org/springframework/web/multipart/MultipartHttpServletRequest.java index c083b42fdc..7788b1e704 100644 --- a/spring-web/src/main/java/org/springframework/web/multipart/MultipartHttpServletRequest.java +++ b/spring-web/src/main/java/org/springframework/web/multipart/MultipartHttpServletRequest.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2011 the original author or authors. + * Copyright 2002-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -60,9 +60,10 @@ public interface MultipartHttpServletRequest extends HttpServletRequest, Multipa HttpHeaders getRequestHeaders(); /** - * Return the headers associated with the specified part of the multipart request. - *
If the underlying implementation supports access to headers, then all headers are returned. - * Otherwise, the returned headers will include a 'Content-Type' header at the very least. + * Return the headers for the specified part of the multipart request. + *
If the underlying implementation supports access to part headers,
+ * then all headers are returned. Otherwise, e.g. for a file upload, the
+ * returned headers may expose a 'Content-Type' if available.
*/
@Nullable
HttpHeaders getMultipartHeaders(String paramOrFileName);
diff --git a/spring-web/src/testFixtures/java/org/springframework/web/testfixture/servlet/MockMultipartHttpServletRequest.java b/spring-web/src/testFixtures/java/org/springframework/web/testfixture/servlet/MockMultipartHttpServletRequest.java
index ace2e125c5..c36eb3ac7b 100644
--- a/spring-web/src/testFixtures/java/org/springframework/web/testfixture/servlet/MockMultipartHttpServletRequest.java
+++ b/spring-web/src/testFixtures/java/org/springframework/web/testfixture/servlet/MockMultipartHttpServletRequest.java
@@ -1,5 +1,5 @@
/*
- * Copyright 2002-2020 the original author or authors.
+ * Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -17,6 +17,7 @@
package org.springframework.web.testfixture.servlet;
import java.io.IOException;
+import java.util.ArrayList;
import java.util.Collections;
import java.util.Enumeration;
import java.util.Iterator;
@@ -33,6 +34,7 @@ import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
+import org.springframework.web.multipart.MultipartException;
import org.springframework.web.multipart.MultipartFile;
import org.springframework.web.multipart.MultipartHttpServletRequest;
@@ -155,15 +157,28 @@ public class MockMultipartHttpServletRequest extends MockHttpServletRequest impl
@Override
public HttpHeaders getMultipartHeaders(String paramOrFileName) {
- String contentType = getMultipartContentType(paramOrFileName);
- if (contentType != null) {
+ MultipartFile file = getFile(paramOrFileName);
+ if (file != null) {
HttpHeaders headers = new HttpHeaders();
- headers.add(HttpHeaders.CONTENT_TYPE, contentType);
+ if (file.getContentType() != null) {
+ headers.add(HttpHeaders.CONTENT_TYPE, file.getContentType());
+ }
return headers;
}
- else {
- return null;
+ try {
+ Part part = getPart(paramOrFileName);
+ if (part != null) {
+ HttpHeaders headers = new HttpHeaders();
+ for (String headerName : part.getHeaderNames()) {
+ headers.put(headerName, new ArrayList<>(part.getHeaders(headerName)));
+ }
+ return headers;
+ }
}
+ catch (Throwable ex) {
+ throw new MultipartException("Could not access multipart servlet request", ex);
+ }
+ return null;
}
}
diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/RequestPartMethodArgumentResolverTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/RequestPartMethodArgumentResolverTests.java
index f35a3817f5..9cf34618a6 100644
--- a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/RequestPartMethodArgumentResolverTests.java
+++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/RequestPartMethodArgumentResolverTests.java
@@ -1,5 +1,5 @@
/*
- * Copyright 2002-2019 the original author or authors.
+ * Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -36,6 +36,7 @@ import org.springframework.core.annotation.SynthesizingMethodParameter;
import org.springframework.http.HttpInputMessage;
import org.springframework.http.MediaType;
import org.springframework.http.converter.HttpMessageConverter;
+import org.springframework.http.converter.StringHttpMessageConverter;
import org.springframework.lang.Nullable;
import org.springframework.util.ReflectionUtils;
import org.springframework.validation.BindingResult;
@@ -51,6 +52,7 @@ import org.springframework.web.method.support.ModelAndViewContainer;
import org.springframework.web.multipart.MultipartException;
import org.springframework.web.multipart.MultipartFile;
import org.springframework.web.multipart.support.MissingServletRequestPartException;
+import org.springframework.web.testfixture.method.ResolvableMethod;
import org.springframework.web.testfixture.servlet.MockHttpServletRequest;
import org.springframework.web.testfixture.servlet.MockHttpServletResponse;
import org.springframework.web.testfixture.servlet.MockMultipartFile;
@@ -311,6 +313,22 @@ public class RequestPartMethodArgumentResolverTests {
testResolveArgument(new SimpleBean("foo"), paramValidRequestPart);
}
+ @Test // gh-26501
+ public void resolveRequestPartWithoutContentType() throws Exception {
+ MockMultipartHttpServletRequest servletRequest = new MockMultipartHttpServletRequest();
+ servletRequest.addPart(new MockPart("requestPartString", "part value".getBytes(StandardCharsets.UTF_8)));
+ ServletWebRequest webRequest = new ServletWebRequest(servletRequest, new MockHttpServletResponse());
+
+ List> optionalMultipartFileList,
Optional
> optionalPartList,
- @RequestPart("requestPart") Optional