Restrict fallback multipart binding to POST requests with multipart/form-data

Closes gh-26999
See gh-26826
This commit is contained in:
Juergen Hoeller 2021-07-12 17:55:49 +02:00
parent 128689e79b
commit ed27ea7aa0
4 changed files with 22 additions and 12 deletions

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2020 the original author or authors. * Copyright 2002-2021 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -20,6 +20,8 @@ import javax.servlet.ServletRequest;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import org.springframework.beans.MutablePropertyValues; import org.springframework.beans.MutablePropertyValues;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.lang.Nullable; import org.springframework.lang.Nullable;
import org.springframework.util.StringUtils; import org.springframework.util.StringUtils;
import org.springframework.validation.BindException; import org.springframework.validation.BindException;
@ -106,9 +108,9 @@ public class ServletRequestDataBinder extends WebDataBinder {
if (multipartRequest != null) { if (multipartRequest != null) {
bindMultipart(multipartRequest.getMultiFileMap(), mpvs); bindMultipart(multipartRequest.getMultiFileMap(), mpvs);
} }
else if (StringUtils.startsWithIgnoreCase(request.getContentType(), "multipart/")) { else if (StringUtils.startsWithIgnoreCase(request.getContentType(), MediaType.MULTIPART_FORM_DATA_VALUE)) {
HttpServletRequest httpServletRequest = WebUtils.getNativeRequest(request, HttpServletRequest.class); HttpServletRequest httpServletRequest = WebUtils.getNativeRequest(request, HttpServletRequest.class);
if (httpServletRequest != null) { if (httpServletRequest != null && HttpMethod.POST.matches(httpServletRequest.getMethod())) {
StandardServletPartUtils.bindParts(httpServletRequest, mpvs, isBindEmptyMultipartFiles()); StandardServletPartUtils.bindParts(httpServletRequest, mpvs, isBindEmptyMultipartFiles());
} }
} }

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2020 the original author or authors. * Copyright 2002-2021 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -19,6 +19,9 @@ package org.springframework.web.bind.support;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import org.springframework.beans.MutablePropertyValues; import org.springframework.beans.MutablePropertyValues;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.lang.Nullable; import org.springframework.lang.Nullable;
import org.springframework.util.StringUtils; import org.springframework.util.StringUtils;
import org.springframework.validation.BindException; import org.springframework.validation.BindException;
@ -107,13 +110,15 @@ public class WebRequestDataBinder extends WebDataBinder {
public void bind(WebRequest request) { public void bind(WebRequest request) {
MutablePropertyValues mpvs = new MutablePropertyValues(request.getParameterMap()); MutablePropertyValues mpvs = new MutablePropertyValues(request.getParameterMap());
if (request instanceof NativeWebRequest) { if (request instanceof NativeWebRequest) {
MultipartRequest multipartRequest = ((NativeWebRequest) request).getNativeRequest(MultipartRequest.class); NativeWebRequest nativeRequest = (NativeWebRequest) request;
MultipartRequest multipartRequest = nativeRequest.getNativeRequest(MultipartRequest.class);
if (multipartRequest != null) { if (multipartRequest != null) {
bindMultipart(multipartRequest.getMultiFileMap(), mpvs); bindMultipart(multipartRequest.getMultiFileMap(), mpvs);
} }
else if (StringUtils.startsWithIgnoreCase(request.getHeader("Content-Type"), "multipart/")) { else if (StringUtils.startsWithIgnoreCase(
HttpServletRequest servletRequest = ((NativeWebRequest) request).getNativeRequest(HttpServletRequest.class); request.getHeader(HttpHeaders.CONTENT_TYPE), MediaType.MULTIPART_FORM_DATA_VALUE)) {
if (servletRequest != null) { HttpServletRequest servletRequest = nativeRequest.getNativeRequest(HttpServletRequest.class);
if (servletRequest != null && HttpMethod.POST.matches(servletRequest.getMethod())) {
StandardServletPartUtils.bindParts(servletRequest, mpvs, isBindEmptyMultipartFiles()); StandardServletPartUtils.bindParts(servletRequest, mpvs, isBindEmptyMultipartFiles());
} }
} }

View File

@ -38,6 +38,9 @@ import org.springframework.beans.BeanInstantiationException;
import org.springframework.beans.BeanUtils; import org.springframework.beans.BeanUtils;
import org.springframework.beans.TypeMismatchException; import org.springframework.beans.TypeMismatchException;
import org.springframework.core.MethodParameter; import org.springframework.core.MethodParameter;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.lang.Nullable; import org.springframework.lang.Nullable;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.util.ObjectUtils; import org.springframework.util.ObjectUtils;
@ -349,9 +352,10 @@ public class ModelAttributeMethodProcessor implements HandlerMethodArgumentResol
return (files.size() == 1 ? files.get(0) : files); return (files.size() == 1 ? files.get(0) : files);
} }
} }
else if (StringUtils.startsWithIgnoreCase(request.getHeader("Content-Type"), "multipart/")) { else if (StringUtils.startsWithIgnoreCase(
request.getHeader(HttpHeaders.CONTENT_TYPE), MediaType.MULTIPART_FORM_DATA_VALUE)) {
HttpServletRequest servletRequest = request.getNativeRequest(HttpServletRequest.class); HttpServletRequest servletRequest = request.getNativeRequest(HttpServletRequest.class);
if (servletRequest != null) { if (servletRequest != null && HttpMethod.POST.matches(servletRequest.getMethod())) {
List<Part> parts = StandardServletPartUtils.getParts(servletRequest, paramName); List<Part> parts = StandardServletPartUtils.getParts(servletRequest, paramName);
if (!parts.isEmpty()) { if (!parts.isEmpty()) {
return (parts.size() == 1 ? parts.get(0) : parts); return (parts.size() == 1 ? parts.get(0) : parts);

View File

@ -1998,9 +1998,8 @@ public class ServletAnnotationControllerHandlerMethodTests extends AbstractServl
void dataClassBindingWithServletPart(boolean usePathPatterns) throws Exception { void dataClassBindingWithServletPart(boolean usePathPatterns) throws Exception {
initDispatcherServlet(ServletPartDataClassController.class, usePathPatterns); initDispatcherServlet(ServletPartDataClassController.class, usePathPatterns);
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest request = new MockHttpServletRequest("POST", "/bind");
request.setContentType("multipart/form-data"); request.setContentType("multipart/form-data");
request.setRequestURI("/bind");
request.addPart(new MockPart("param1", "value1".getBytes(StandardCharsets.UTF_8))); request.addPart(new MockPart("param1", "value1".getBytes(StandardCharsets.UTF_8)));
request.addParameter("param2", "true"); request.addParameter("param2", "true");
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();