MockMvc applies StandardMultipartHttpServletRequest wrapper

This is necessary to correctly process multipart requests and resolve
@RequestPart arguments and MultipartFile arguments.

Closes gh-25602
This commit is contained in:
Rossen Stoyanchev 2020-08-19 09:01:12 +01:00
parent 50b20c2bb7
commit 443e9ee618
3 changed files with 12 additions and 15 deletions

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2018 the original author or authors. * Copyright 2002-2020 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -35,6 +35,7 @@ import org.springframework.web.context.request.async.CallableProcessingIntercept
import org.springframework.web.context.request.async.DeferredResult; import org.springframework.web.context.request.async.DeferredResult;
import org.springframework.web.context.request.async.DeferredResultProcessingInterceptor; import org.springframework.web.context.request.async.DeferredResultProcessingInterceptor;
import org.springframework.web.context.request.async.WebAsyncUtils; import org.springframework.web.context.request.async.WebAsyncUtils;
import org.springframework.web.multipart.support.StandardMultipartHttpServletRequest;
import org.springframework.web.servlet.DispatcherServlet; import org.springframework.web.servlet.DispatcherServlet;
import org.springframework.web.servlet.HandlerExecutionChain; import org.springframework.web.servlet.HandlerExecutionChain;
import org.springframework.web.servlet.ModelAndView; import org.springframework.web.servlet.ModelAndView;
@ -67,6 +68,10 @@ final class TestDispatcherServlet extends DispatcherServlet {
protected void service(HttpServletRequest request, HttpServletResponse response) protected void service(HttpServletRequest request, HttpServletResponse response)
throws ServletException, IOException { throws ServletException, IOException {
if (!request.getParts().isEmpty()) {
request = new StandardMultipartHttpServletRequest(request);
}
registerAsyncResultInterceptors(request); registerAsyncResultInterceptors(request);
super.service(request, response); super.service(request, response);
@ -80,8 +85,9 @@ final class TestDispatcherServlet extends DispatcherServlet {
MockHttpServletRequest mockRequest = WebUtils.getNativeRequest(request, MockHttpServletRequest.class); MockHttpServletRequest mockRequest = WebUtils.getNativeRequest(request, MockHttpServletRequest.class);
Assert.notNull(mockRequest, "Expected MockHttpServletRequest"); Assert.notNull(mockRequest, "Expected MockHttpServletRequest");
asyncContext = (MockAsyncContext) mockRequest.getAsyncContext(); asyncContext = (MockAsyncContext) mockRequest.getAsyncContext();
String requestClassName = request.getClass().getName();
Assert.notNull(asyncContext, () -> Assert.notNull(asyncContext, () ->
"Outer request wrapper " + request.getClass().getName() + " has an AsyncContext," + "Outer request wrapper " + requestClassName + " has an AsyncContext," +
"but it is not a MockAsyncContext, while the nested " + "but it is not a MockAsyncContext, while the nested " +
mockRequest.getClass().getName() + " does not have an AsyncContext at all."); mockRequest.getClass().getName() + " does not have an AsyncContext at all.");
} }

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2017 the original author or authors. * Copyright 2002-2020 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -33,7 +33,6 @@ import org.springframework.mock.web.MockMultipartHttpServletRequest;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap; import org.springframework.util.MultiValueMap;
import org.springframework.web.multipart.support.StandardMultipartHttpServletRequest;
/** /**
* Default builder for {@link MockMultipartHttpServletRequest}. * Default builder for {@link MockMultipartHttpServletRequest}.
@ -138,17 +137,9 @@ public class MockMultipartHttpServletRequestBuilder extends MockHttpServletReque
*/ */
@Override @Override
protected final MockHttpServletRequest createServletRequest(ServletContext servletContext) { protected final MockHttpServletRequest createServletRequest(ServletContext servletContext) {
MockMultipartHttpServletRequest request = new MockMultipartHttpServletRequest(servletContext); MockMultipartHttpServletRequest request = new MockMultipartHttpServletRequest(servletContext);
this.files.stream().forEach(request::addFile); this.files.forEach(request::addFile);
this.parts.values().stream().flatMap(Collection::stream).forEach(request::addPart); this.parts.values().stream().flatMap(Collection::stream).forEach(request::addPart);
if (!this.parts.isEmpty()) {
new StandardMultipartHttpServletRequest(request)
.getMultiFileMap().values().stream().flatMap(Collection::stream)
.forEach(request::addFile);
}
return request; return request;
} }

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2019 the original author or authors. * Copyright 2002-2020 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -230,7 +230,7 @@ public class MultipartControllerTests {
MockPart filePart = new MockPart("file", "orig", fileContent); MockPart filePart = new MockPart("file", "orig", fileContent);
byte[] json = "{\"name\":\"yeeeah\"}".getBytes(StandardCharsets.UTF_8); byte[] json = "{\"name\":\"yeeeah\"}".getBytes(StandardCharsets.UTF_8);
MockPart jsonPart = new MockPart("json", "json", json); MockPart jsonPart = new MockPart("json", json);
jsonPart.getHeaders().setContentType(MediaType.APPLICATION_JSON); jsonPart.getHeaders().setContentType(MediaType.APPLICATION_JSON);
standaloneSetup(new MultipartController()).build() standaloneSetup(new MultipartController()).build()