Unwrap if necessary for MultipartHttpServletRequest

Before this commit RequestPartServletServerHttpRequest simply did an
instanceof check for MultipartHttpServletRequest. That hasn't failed
because request wrapping typically happens in filters before the
DispatcherServlet calls the MultipartResolver.

With Spring MVC Test and the Spring Security integraiton however,
this order is reversed since there we prepare the multipart request
upfront, i.e. there is no actual parsing.

The commit unwraps the request if necessary.

Issue: SPR-13317
This commit is contained in:
Rossen Stoyanchev 2015-08-21 10:28:08 -04:00
parent 22948bd7f0
commit 473dd5e9e8
3 changed files with 96 additions and 16 deletions

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2013 the original author or authors.
* Copyright 2002-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -17,44 +17,102 @@
package org.springframework.test.web.servlet.samples.standalone;
import java.io.IOException;
import java.nio.charset.Charset;
import java.util.Collections;
import java.util.Map;
import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import javax.servlet.http.HttpServletResponse;
import org.junit.Test;
import org.springframework.mock.web.MockMultipartFile;
import org.springframework.stereotype.Controller;
import org.springframework.test.web.servlet.MockMvc;
import org.springframework.ui.Model;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestMethod;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RequestPart;
import org.springframework.web.filter.OncePerRequestFilter;
import org.springframework.web.multipart.MultipartFile;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.*;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.*;
import static org.springframework.test.web.servlet.setup.MockMvcBuilders.*;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.fileUpload;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.model;
import static org.springframework.test.web.servlet.setup.MockMvcBuilders.standaloneSetup;
/**
* @author Rossen Stoyanchev
*/
public class FileUploadControllerTests {
private static final Charset CHARSET = Charset.forName("UTF-8");
@Test
public void readString() throws Exception {
MockMultipartFile file = new MockMultipartFile("file", "orig", null, "bar".getBytes());
standaloneSetup(new FileUploadController()).build()
.perform(fileUpload("/fileupload").file(file))
.andExpect(model().attribute("message", "File 'orig' uploaded successfully"));
public void multipartRequest() throws Exception {
byte[] fileContent = "bar".getBytes(CHARSET);
MockMultipartFile filePart = new MockMultipartFile("file", "orig", null, fileContent);
byte[] json = "{\"name\":\"yeeeah\"}".getBytes(CHARSET);
MockMultipartFile jsonPart = new MockMultipartFile("json", "json", "application/json", json);
MockMvc mockMvc = standaloneSetup(new MultipartController()).build();
mockMvc.perform(fileUpload("/test").file(filePart).file(jsonPart))
.andExpect(model().attribute("fileContent", fileContent))
.andExpect(model().attribute("jsonContent", Collections.singletonMap("name", "yeeeah")));
}
// SPR-13317
@Test
public void multipartRequestWrapped() throws Exception {
byte[] json = "{\"name\":\"yeeeah\"}".getBytes(CHARSET);
MockMultipartFile jsonPart = new MockMultipartFile("json", "json", "application/json", json);
Filter filter = new RequestWrappingFilter();
MockMvc mockMvc = standaloneSetup(new MultipartController()).addFilter(filter).build();
Map<String, String> jsonMap = Collections.singletonMap("name", "yeeeah");
mockMvc.perform(fileUpload("/testJson").file(jsonPart)).andExpect(model().attribute("json", jsonMap));
}
@SuppressWarnings("unused")
@Controller
private static class FileUploadController {
private static class MultipartController {
@RequestMapping(value = "/test", method = RequestMethod.POST)
public String processMultipart(@RequestParam MultipartFile file,
@RequestPart Map<String, String> json, Model model) throws IOException {
model.addAttribute("jsonContent", json);
model.addAttribute("fileContent", file.getBytes());
@RequestMapping(value="/fileupload", method=RequestMethod.POST)
public String processUpload(@RequestParam MultipartFile file, Model model) throws IOException {
model.addAttribute("message", "File '" + file.getOriginalFilename() + "' uploaded successfully");
return "redirect:/index";
}
@RequestMapping(value = "/testJson", method = RequestMethod.POST)
public String processMultipart(@RequestPart Map<String, String> json, Model model) {
model.addAttribute("json", json);
return "redirect:/index";
}
}
}
private static class RequestWrappingFilter extends OncePerRequestFilter {
@Override
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response,
FilterChain filterChain) throws IOException, ServletException {
request = new HttpServletRequestWrapper(request);
filterChain.doFilter(request, response);
}
}
}

View File

@ -31,6 +31,8 @@ import org.springframework.web.multipart.MultipartException;
import org.springframework.web.multipart.MultipartFile;
import org.springframework.web.multipart.MultipartHttpServletRequest;
import org.springframework.web.multipart.MultipartResolver;
import org.springframework.web.util.WebUtils;
/**
* {@link ServerHttpRequest} implementation that accesses one part of a multipart
@ -81,8 +83,9 @@ public class RequestPartServletServerHttpRequest extends ServletServerHttpReques
}
private static MultipartHttpServletRequest asMultipartRequest(HttpServletRequest request) {
if (request instanceof MultipartHttpServletRequest) {
return (MultipartHttpServletRequest) request;
MultipartHttpServletRequest unwrapped = WebUtils.getNativeRequest(request, MultipartHttpServletRequest.class);
if (unwrapped != null) {
return unwrapped;
}
else if (ClassUtils.hasMethod(HttpServletRequest.class, "getParts")) {
// Servlet 3.0 available ..
@ -91,11 +94,13 @@ public class RequestPartServletServerHttpRequest extends ServletServerHttpReques
throw new IllegalArgumentException("Expected MultipartHttpServletRequest: is a MultipartResolver configured?");
}
@Override
public HttpHeaders getHeaders() {
return this.headers;
}
@Override
public InputStream getBody() throws IOException {
if (this.multipartRequest instanceof StandardMultipartHttpServletRequest) {

View File

@ -19,6 +19,9 @@ package org.springframework.web.multipart.support;
import java.net.URI;
import java.nio.charset.Charset;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import org.junit.Test;
import org.springframework.http.HttpHeaders;
@ -86,6 +89,20 @@ public class RequestPartServletServerHttpRequestTests {
assertArrayEquals(bytes, result);
}
// SPR-13317
@Test
public void getBodyWithWrappedRequest() throws Exception {
byte[] bytes = "content".getBytes("UTF-8");
MultipartFile part = new MockMultipartFile("part", "", "application/json", bytes);
this.mockRequest.addFile(part);
HttpServletRequest wrapped = new HttpServletRequestWrapper(this.mockRequest);
ServerHttpRequest request = new RequestPartServletServerHttpRequest(wrapped, "part");
byte[] result = FileCopyUtils.copyToByteArray(request.getBody());
assertArrayEquals(bytes, result);
}
// SPR-13096
@Test