diff --git a/spring-test/src/main/java/org/springframework/test/web/servlet/TestDispatcherServlet.java b/spring-test/src/main/java/org/springframework/test/web/servlet/TestDispatcherServlet.java index b3f20b01d0..041cf24f6e 100644 --- a/spring-test/src/main/java/org/springframework/test/web/servlet/TestDispatcherServlet.java +++ b/spring-test/src/main/java/org/springframework/test/web/servlet/TestDispatcherServlet.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -35,6 +35,7 @@ import org.springframework.web.context.request.async.CallableProcessingIntercept import org.springframework.web.context.request.async.DeferredResult; import org.springframework.web.context.request.async.DeferredResultProcessingInterceptor; import org.springframework.web.context.request.async.WebAsyncUtils; +import org.springframework.web.multipart.support.StandardMultipartHttpServletRequest; import org.springframework.web.servlet.DispatcherServlet; import org.springframework.web.servlet.HandlerExecutionChain; import org.springframework.web.servlet.ModelAndView; @@ -67,6 +68,10 @@ final class TestDispatcherServlet extends DispatcherServlet { protected void service(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException { + if (!request.getParts().isEmpty()) { + request = new StandardMultipartHttpServletRequest(request); + } + registerAsyncResultInterceptors(request); super.service(request, response); @@ -80,8 +85,9 @@ final class TestDispatcherServlet extends DispatcherServlet { MockHttpServletRequest mockRequest = WebUtils.getNativeRequest(request, MockHttpServletRequest.class); Assert.notNull(mockRequest, "Expected MockHttpServletRequest"); asyncContext = (MockAsyncContext) mockRequest.getAsyncContext(); + String requestClassName = request.getClass().getName(); Assert.notNull(asyncContext, () -> - "Outer request wrapper " + request.getClass().getName() + " has an AsyncContext," + + "Outer request wrapper " + requestClassName + " has an AsyncContext," + "but it is not a MockAsyncContext, while the nested " + mockRequest.getClass().getName() + " does not have an AsyncContext at all."); } diff --git a/spring-test/src/main/java/org/springframework/test/web/servlet/request/MockMultipartHttpServletRequestBuilder.java b/spring-test/src/main/java/org/springframework/test/web/servlet/request/MockMultipartHttpServletRequestBuilder.java index 51aae3c42a..dea33261c2 100644 --- a/spring-test/src/main/java/org/springframework/test/web/servlet/request/MockMultipartHttpServletRequestBuilder.java +++ b/spring-test/src/main/java/org/springframework/test/web/servlet/request/MockMultipartHttpServletRequestBuilder.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2017 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -33,7 +33,6 @@ import org.springframework.mock.web.MockMultipartHttpServletRequest; import org.springframework.util.Assert; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; -import org.springframework.web.multipart.support.StandardMultipartHttpServletRequest; /** * Default builder for {@link MockMultipartHttpServletRequest}. @@ -138,17 +137,9 @@ public class MockMultipartHttpServletRequestBuilder extends MockHttpServletReque */ @Override protected final MockHttpServletRequest createServletRequest(ServletContext servletContext) { - MockMultipartHttpServletRequest request = new MockMultipartHttpServletRequest(servletContext); - this.files.stream().forEach(request::addFile); + this.files.forEach(request::addFile); this.parts.values().stream().flatMap(Collection::stream).forEach(request::addPart); - - if (!this.parts.isEmpty()) { - new StandardMultipartHttpServletRequest(request) - .getMultiFileMap().values().stream().flatMap(Collection::stream) - .forEach(request::addFile); - } - return request; } diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/MultipartControllerTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/MultipartControllerTests.java index a1b4558a6e..a00dd2774e 100644 --- a/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/MultipartControllerTests.java +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/MultipartControllerTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -230,7 +230,7 @@ public class MultipartControllerTests { MockPart filePart = new MockPart("file", "orig", fileContent); byte[] json = "{\"name\":\"yeeeah\"}".getBytes(StandardCharsets.UTF_8); - MockPart jsonPart = new MockPart("json", "json", json); + MockPart jsonPart = new MockPart("json", json); jsonPart.getHeaders().setContentType(MediaType.APPLICATION_JSON); standaloneSetup(new MultipartController()).build()