diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/FileUploadControllerTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/FileUploadControllerTests.java index 40407f72c13..094c1799b9f 100644 --- a/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/FileUploadControllerTests.java +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/FileUploadControllerTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2013 the original author or authors. + * Copyright 2002-2015 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,44 +17,102 @@ package org.springframework.test.web.servlet.samples.standalone; import java.io.IOException; +import java.nio.charset.Charset; +import java.util.Collections; +import java.util.Map; +import javax.servlet.Filter; +import javax.servlet.FilterChain; +import javax.servlet.ServletException; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletRequestWrapper; +import javax.servlet.http.HttpServletResponse; import org.junit.Test; import org.springframework.mock.web.MockMultipartFile; import org.springframework.stereotype.Controller; +import org.springframework.test.web.servlet.MockMvc; import org.springframework.ui.Model; import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RequestMethod; import org.springframework.web.bind.annotation.RequestParam; +import org.springframework.web.bind.annotation.RequestPart; +import org.springframework.web.filter.OncePerRequestFilter; import org.springframework.web.multipart.MultipartFile; -import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.*; -import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.*; -import static org.springframework.test.web.servlet.setup.MockMvcBuilders.*; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.fileUpload; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.model; +import static org.springframework.test.web.servlet.setup.MockMvcBuilders.standaloneSetup; /** * @author Rossen Stoyanchev */ public class FileUploadControllerTests { + private static final Charset CHARSET = Charset.forName("UTF-8"); + + @Test - public void readString() throws Exception { - MockMultipartFile file = new MockMultipartFile("file", "orig", null, "bar".getBytes()); - standaloneSetup(new FileUploadController()).build() - .perform(fileUpload("/fileupload").file(file)) - .andExpect(model().attribute("message", "File 'orig' uploaded successfully")); + public void multipartRequest() throws Exception { + + byte[] fileContent = "bar".getBytes(CHARSET); + MockMultipartFile filePart = new MockMultipartFile("file", "orig", null, fileContent); + + byte[] json = "{\"name\":\"yeeeah\"}".getBytes(CHARSET); + MockMultipartFile jsonPart = new MockMultipartFile("json", "json", "application/json", json); + + MockMvc mockMvc = standaloneSetup(new MultipartController()).build(); + mockMvc.perform(fileUpload("/test").file(filePart).file(jsonPart)) + .andExpect(model().attribute("fileContent", fileContent)) + .andExpect(model().attribute("jsonContent", Collections.singletonMap("name", "yeeeah"))); + } + + // SPR-13317 + + @Test + public void multipartRequestWrapped() throws Exception { + + byte[] json = "{\"name\":\"yeeeah\"}".getBytes(CHARSET); + MockMultipartFile jsonPart = new MockMultipartFile("json", "json", "application/json", json); + + Filter filter = new RequestWrappingFilter(); + MockMvc mockMvc = standaloneSetup(new MultipartController()).addFilter(filter).build(); + + Map jsonMap = Collections.singletonMap("name", "yeeeah"); + mockMvc.perform(fileUpload("/testJson").file(jsonPart)).andExpect(model().attribute("json", jsonMap)); } + @SuppressWarnings("unused") @Controller - private static class FileUploadController { + private static class MultipartController { + + @RequestMapping(value = "/test", method = RequestMethod.POST) + public String processMultipart(@RequestParam MultipartFile file, + @RequestPart Map json, Model model) throws IOException { + + model.addAttribute("jsonContent", json); + model.addAttribute("fileContent", file.getBytes()); - @RequestMapping(value="/fileupload", method=RequestMethod.POST) - public String processUpload(@RequestParam MultipartFile file, Model model) throws IOException { - model.addAttribute("message", "File '" + file.getOriginalFilename() + "' uploaded successfully"); return "redirect:/index"; } + @RequestMapping(value = "/testJson", method = RequestMethod.POST) + public String processMultipart(@RequestPart Map json, Model model) { + model.addAttribute("json", json); + return "redirect:/index"; + } } -} + private static class RequestWrappingFilter extends OncePerRequestFilter { + + @Override + protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, + FilterChain filterChain) throws IOException, ServletException { + + request = new HttpServletRequestWrapper(request); + filterChain.doFilter(request, response); + } + } + +} \ No newline at end of file diff --git a/spring-web/src/main/java/org/springframework/web/multipart/support/RequestPartServletServerHttpRequest.java b/spring-web/src/main/java/org/springframework/web/multipart/support/RequestPartServletServerHttpRequest.java index 7e89d7c08fd..b26637e632f 100644 --- a/spring-web/src/main/java/org/springframework/web/multipart/support/RequestPartServletServerHttpRequest.java +++ b/spring-web/src/main/java/org/springframework/web/multipart/support/RequestPartServletServerHttpRequest.java @@ -31,6 +31,8 @@ import org.springframework.web.multipart.MultipartException; import org.springframework.web.multipart.MultipartFile; import org.springframework.web.multipart.MultipartHttpServletRequest; import org.springframework.web.multipart.MultipartResolver; +import org.springframework.web.util.WebUtils; + /** * {@link ServerHttpRequest} implementation that accesses one part of a multipart @@ -81,8 +83,9 @@ public class RequestPartServletServerHttpRequest extends ServletServerHttpReques } private static MultipartHttpServletRequest asMultipartRequest(HttpServletRequest request) { - if (request instanceof MultipartHttpServletRequest) { - return (MultipartHttpServletRequest) request; + MultipartHttpServletRequest unwrapped = WebUtils.getNativeRequest(request, MultipartHttpServletRequest.class); + if (unwrapped != null) { + return unwrapped; } else if (ClassUtils.hasMethod(HttpServletRequest.class, "getParts")) { // Servlet 3.0 available .. @@ -91,11 +94,13 @@ public class RequestPartServletServerHttpRequest extends ServletServerHttpReques throw new IllegalArgumentException("Expected MultipartHttpServletRequest: is a MultipartResolver configured?"); } + @Override public HttpHeaders getHeaders() { return this.headers; } + @Override public InputStream getBody() throws IOException { if (this.multipartRequest instanceof StandardMultipartHttpServletRequest) { diff --git a/spring-web/src/test/java/org/springframework/web/multipart/support/RequestPartServletServerHttpRequestTests.java b/spring-web/src/test/java/org/springframework/web/multipart/support/RequestPartServletServerHttpRequestTests.java index f9bff37de01..f821dab48c1 100644 --- a/spring-web/src/test/java/org/springframework/web/multipart/support/RequestPartServletServerHttpRequestTests.java +++ b/spring-web/src/test/java/org/springframework/web/multipart/support/RequestPartServletServerHttpRequestTests.java @@ -19,6 +19,9 @@ package org.springframework.web.multipart.support; import java.net.URI; import java.nio.charset.Charset; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletRequestWrapper; + import org.junit.Test; import org.springframework.http.HttpHeaders; @@ -86,6 +89,20 @@ public class RequestPartServletServerHttpRequestTests { assertArrayEquals(bytes, result); } + // SPR-13317 + + @Test + public void getBodyWithWrappedRequest() throws Exception { + byte[] bytes = "content".getBytes("UTF-8"); + MultipartFile part = new MockMultipartFile("part", "", "application/json", bytes); + this.mockRequest.addFile(part); + HttpServletRequest wrapped = new HttpServletRequestWrapper(this.mockRequest); + ServerHttpRequest request = new RequestPartServletServerHttpRequest(wrapped, "part"); + + byte[] result = FileCopyUtils.copyToByteArray(request.getBody()); + assertArrayEquals(bytes, result); + } + // SPR-13096 @Test