MockMvc applies StandardMultipartHttpServletRequest wrapper
This is necessary to correctly process multipart requests and resolve @RequestPart arguments and MultipartFile arguments. Closes gh-25602
This commit is contained in:
		
							parent
							
								
									50b20c2bb7
								
							
						
					
					
						commit
						443e9ee618
					
				| 
						 | 
				
			
			@ -1,5 +1,5 @@
 | 
			
		|||
/*
 | 
			
		||||
 * Copyright 2002-2018 the original author or authors.
 | 
			
		||||
 * Copyright 2002-2020 the original author or authors.
 | 
			
		||||
 *
 | 
			
		||||
 * Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
 * you may not use this file except in compliance with the License.
 | 
			
		||||
| 
						 | 
				
			
			@ -35,6 +35,7 @@ import org.springframework.web.context.request.async.CallableProcessingIntercept
 | 
			
		|||
import org.springframework.web.context.request.async.DeferredResult;
 | 
			
		||||
import org.springframework.web.context.request.async.DeferredResultProcessingInterceptor;
 | 
			
		||||
import org.springframework.web.context.request.async.WebAsyncUtils;
 | 
			
		||||
import org.springframework.web.multipart.support.StandardMultipartHttpServletRequest;
 | 
			
		||||
import org.springframework.web.servlet.DispatcherServlet;
 | 
			
		||||
import org.springframework.web.servlet.HandlerExecutionChain;
 | 
			
		||||
import org.springframework.web.servlet.ModelAndView;
 | 
			
		||||
| 
						 | 
				
			
			@ -67,6 +68,10 @@ final class TestDispatcherServlet extends DispatcherServlet {
 | 
			
		|||
	protected void service(HttpServletRequest request, HttpServletResponse response)
 | 
			
		||||
			throws ServletException, IOException {
 | 
			
		||||
 | 
			
		||||
		if (!request.getParts().isEmpty()) {
 | 
			
		||||
			request = new StandardMultipartHttpServletRequest(request);
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		registerAsyncResultInterceptors(request);
 | 
			
		||||
 | 
			
		||||
		super.service(request, response);
 | 
			
		||||
| 
						 | 
				
			
			@ -80,8 +85,9 @@ final class TestDispatcherServlet extends DispatcherServlet {
 | 
			
		|||
				MockHttpServletRequest mockRequest = WebUtils.getNativeRequest(request, MockHttpServletRequest.class);
 | 
			
		||||
				Assert.notNull(mockRequest, "Expected MockHttpServletRequest");
 | 
			
		||||
				asyncContext = (MockAsyncContext) mockRequest.getAsyncContext();
 | 
			
		||||
				String requestClassName = request.getClass().getName();
 | 
			
		||||
				Assert.notNull(asyncContext, () ->
 | 
			
		||||
						"Outer request wrapper " + request.getClass().getName() + " has an AsyncContext," +
 | 
			
		||||
						"Outer request wrapper " + requestClassName + " has an AsyncContext," +
 | 
			
		||||
								"but it is not a MockAsyncContext, while the nested " +
 | 
			
		||||
								mockRequest.getClass().getName() + " does not have an AsyncContext at all.");
 | 
			
		||||
			}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,5 +1,5 @@
 | 
			
		|||
/*
 | 
			
		||||
 * Copyright 2002-2017 the original author or authors.
 | 
			
		||||
 * Copyright 2002-2020 the original author or authors.
 | 
			
		||||
 *
 | 
			
		||||
 * Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
 * you may not use this file except in compliance with the License.
 | 
			
		||||
| 
						 | 
				
			
			@ -33,7 +33,6 @@ import org.springframework.mock.web.MockMultipartHttpServletRequest;
 | 
			
		|||
import org.springframework.util.Assert;
 | 
			
		||||
import org.springframework.util.LinkedMultiValueMap;
 | 
			
		||||
import org.springframework.util.MultiValueMap;
 | 
			
		||||
import org.springframework.web.multipart.support.StandardMultipartHttpServletRequest;
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * Default builder for {@link MockMultipartHttpServletRequest}.
 | 
			
		||||
| 
						 | 
				
			
			@ -138,17 +137,9 @@ public class MockMultipartHttpServletRequestBuilder extends MockHttpServletReque
 | 
			
		|||
	 */
 | 
			
		||||
	@Override
 | 
			
		||||
	protected final MockHttpServletRequest createServletRequest(ServletContext servletContext) {
 | 
			
		||||
 | 
			
		||||
		MockMultipartHttpServletRequest request = new MockMultipartHttpServletRequest(servletContext);
 | 
			
		||||
		this.files.stream().forEach(request::addFile);
 | 
			
		||||
		this.files.forEach(request::addFile);
 | 
			
		||||
		this.parts.values().stream().flatMap(Collection::stream).forEach(request::addPart);
 | 
			
		||||
 | 
			
		||||
		if (!this.parts.isEmpty()) {
 | 
			
		||||
			new StandardMultipartHttpServletRequest(request)
 | 
			
		||||
					.getMultiFileMap().values().stream().flatMap(Collection::stream)
 | 
			
		||||
					.forEach(request::addFile);
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		return request;
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,5 +1,5 @@
 | 
			
		|||
/*
 | 
			
		||||
 * Copyright 2002-2019 the original author or authors.
 | 
			
		||||
 * Copyright 2002-2020 the original author or authors.
 | 
			
		||||
 *
 | 
			
		||||
 * Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
 * you may not use this file except in compliance with the License.
 | 
			
		||||
| 
						 | 
				
			
			@ -230,7 +230,7 @@ public class MultipartControllerTests {
 | 
			
		|||
		MockPart filePart = new MockPart("file", "orig", fileContent);
 | 
			
		||||
 | 
			
		||||
		byte[] json = "{\"name\":\"yeeeah\"}".getBytes(StandardCharsets.UTF_8);
 | 
			
		||||
		MockPart jsonPart = new MockPart("json", "json", json);
 | 
			
		||||
		MockPart jsonPart = new MockPart("json", json);
 | 
			
		||||
		jsonPart.getHeaders().setContentType(MediaType.APPLICATION_JSON);
 | 
			
		||||
 | 
			
		||||
		standaloneSetup(new MultipartController()).build()
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue