diff --git a/spring-test/src/main/java/org/springframework/test/web/servlet/assertj/MockMvcTester.java b/spring-test/src/main/java/org/springframework/test/web/servlet/assertj/MockMvcTester.java index 736219b8a27..a1ce85d696a 100644 --- a/spring-test/src/main/java/org/springframework/test/web/servlet/assertj/MockMvcTester.java +++ b/spring-test/src/main/java/org/springframework/test/web/servlet/assertj/MockMvcTester.java @@ -31,10 +31,12 @@ import org.springframework.http.converter.GenericHttpMessageConverter; import org.springframework.http.converter.HttpMessageConverter; import org.springframework.lang.Nullable; import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockMultipartHttpServletRequest; import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.MvcResult; import org.springframework.test.web.servlet.RequestBuilder; import org.springframework.test.web.servlet.request.AbstractMockHttpServletRequestBuilder; +import org.springframework.test.web.servlet.request.AbstractMockMultipartHttpServletRequestBuilder; import org.springframework.test.web.servlet.request.MockMvcRequestBuilders; import org.springframework.test.web.servlet.setup.DefaultMockMvcBuilder; import org.springframework.test.web.servlet.setup.MockMvcBuilders; @@ -389,8 +391,42 @@ public final class MockMvcTester { public final class MockMvcRequestBuilder extends AbstractMockHttpServletRequestBuilder implements AssertProvider { + private final HttpMethod httpMethod; + private MockMvcRequestBuilder(HttpMethod httpMethod) { super(httpMethod); + this.httpMethod = httpMethod; + } + + /** + * Enable file upload support using multipart. + * @return a {@link MockMultipartMvcRequestBuilder} with the settings + * configured thus far + */ + public MockMultipartMvcRequestBuilder multipart() { + return new MockMultipartMvcRequestBuilder(this); + } + + public MvcTestResult exchange() { + return perform(this); + } + + @Override + public MvcTestResultAssert assertThat() { + return new MvcTestResultAssert(exchange(), MockMvcTester.this.jsonMessageConverter); + } + } + + /** + * A builder for {@link MockMultipartHttpServletRequest} that supports AssertJ. + */ + public final class MockMultipartMvcRequestBuilder + extends AbstractMockMultipartHttpServletRequestBuilder + implements AssertProvider { + + private MockMultipartMvcRequestBuilder(MockMvcRequestBuilder currentBuilder) { + super(currentBuilder.httpMethod); + merge(currentBuilder); } public MvcTestResult exchange() { diff --git a/spring-test/src/main/java/org/springframework/test/web/servlet/request/AbstractMockMultipartHttpServletRequestBuilder.java b/spring-test/src/main/java/org/springframework/test/web/servlet/request/AbstractMockMultipartHttpServletRequestBuilder.java new file mode 100644 index 00000000000..c766b1c4c06 --- /dev/null +++ b/spring-test/src/main/java/org/springframework/test/web/servlet/request/AbstractMockMultipartHttpServletRequestBuilder.java @@ -0,0 +1,159 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.servlet.request; + +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; + +import jakarta.servlet.ServletContext; +import jakarta.servlet.http.Part; + +import org.springframework.http.HttpMethod; +import org.springframework.http.MediaType; +import org.springframework.lang.Nullable; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockMultipartFile; +import org.springframework.mock.web.MockMultipartHttpServletRequest; +import org.springframework.util.Assert; +import org.springframework.util.FileCopyUtils; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; + +/** + * Base builder for {@link MockMultipartHttpServletRequest}. + * + * @author Rossen Stoyanchev + * @author Arjen Poutsma + * @author Stephane Nicoll + * @since 6.2 + * @param a self reference to the builder type + */ +public abstract class AbstractMockMultipartHttpServletRequestBuilder> + extends AbstractMockHttpServletRequestBuilder { + + private final List files = new ArrayList<>(); + + private final MultiValueMap parts = new LinkedMultiValueMap<>(); + + + protected AbstractMockMultipartHttpServletRequestBuilder(HttpMethod httpMethod) { + super(httpMethod); + } + + /** + * Add a new {@link MockMultipartFile} with the given content. + * @param name the name of the file + * @param content the content of the file + */ + public B file(String name, byte[] content) { + this.files.add(new MockMultipartFile(name, content)); + return self(); + } + + /** + * Add the given {@link MockMultipartFile}. + * @param file the multipart file + */ + public B file(MockMultipartFile file) { + this.files.add(file); + return self(); + } + + /** + * Add {@link Part} components to the request. + * @param parts one or more parts to add + * @since 5.0 + */ + public B part(Part... parts) { + Assert.notEmpty(parts, "'parts' must not be empty"); + for (Part part : parts) { + this.parts.add(part.getName(), part); + } + return self(); + } + + @Override + public Object merge(@Nullable Object parent) { + if (parent == null) { + return this; + } + if (parent instanceof AbstractMockHttpServletRequestBuilder) { + super.merge(parent); + if (parent instanceof AbstractMockMultipartHttpServletRequestBuilder parentBuilder) { + this.files.addAll(parentBuilder.files); + parentBuilder.parts.keySet().forEach(name -> + this.parts.putIfAbsent(name, parentBuilder.parts.get(name))); + } + } + else { + throw new IllegalArgumentException("Cannot merge with [" + parent.getClass().getName() + "]"); + } + return this; + } + + /** + * Create a new {@link MockMultipartHttpServletRequest} based on the + * supplied {@code ServletContext} and the {@code MockMultipartFiles} + * added to this builder. + */ + @Override + protected final MockHttpServletRequest createServletRequest(ServletContext servletContext) { + MockMultipartHttpServletRequest request = new MockMultipartHttpServletRequest(servletContext); + Charset defaultCharset = (request.getCharacterEncoding() != null ? + Charset.forName(request.getCharacterEncoding()) : StandardCharsets.UTF_8); + + this.files.forEach(request::addFile); + this.parts.values().stream().flatMap(Collection::stream).forEach(part -> { + request.addPart(part); + try { + String name = part.getName(); + String filename = part.getSubmittedFileName(); + InputStream is = part.getInputStream(); + if (filename != null) { + request.addFile(new MockMultipartFile(name, filename, part.getContentType(), is)); + } + else { + InputStreamReader reader = new InputStreamReader(is, getCharsetOrDefault(part, defaultCharset)); + String value = FileCopyUtils.copyToString(reader); + request.addParameter(part.getName(), value); + } + } + catch (IOException ex) { + throw new IllegalStateException("Failed to read content for part " + part.getName(), ex); + } + }); + + return request; + } + + private Charset getCharsetOrDefault(Part part, Charset defaultCharset) { + if (part.getContentType() != null) { + MediaType mediaType = MediaType.parseMediaType(part.getContentType()); + if (mediaType.getCharset() != null) { + return mediaType.getCharset(); + } + } + return defaultCharset; + } + +} diff --git a/spring-test/src/main/java/org/springframework/test/web/servlet/request/MockMultipartHttpServletRequestBuilder.java b/spring-test/src/main/java/org/springframework/test/web/servlet/request/MockMultipartHttpServletRequestBuilder.java index 36a4a548c3e..e043deda853 100644 --- a/spring-test/src/main/java/org/springframework/test/web/servlet/request/MockMultipartHttpServletRequestBuilder.java +++ b/spring-test/src/main/java/org/springframework/test/web/servlet/request/MockMultipartHttpServletRequestBuilder.java @@ -16,42 +16,22 @@ package org.springframework.test.web.servlet.request; -import java.io.IOException; -import java.io.InputStream; -import java.io.InputStreamReader; import java.net.URI; -import java.nio.charset.Charset; -import java.nio.charset.StandardCharsets; -import java.util.ArrayList; -import java.util.Collection; -import java.util.List; - -import jakarta.servlet.ServletContext; -import jakarta.servlet.http.Part; import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; -import org.springframework.lang.Nullable; -import org.springframework.mock.web.MockHttpServletRequest; -import org.springframework.mock.web.MockMultipartFile; import org.springframework.mock.web.MockMultipartHttpServletRequest; -import org.springframework.util.Assert; -import org.springframework.util.FileCopyUtils; -import org.springframework.util.LinkedMultiValueMap; -import org.springframework.util.MultiValueMap; /** * Default builder for {@link MockMultipartHttpServletRequest}. * * @author Rossen Stoyanchev * @author Arjen Poutsma + * @author Stephane Nicoll * @since 3.2 */ -public class MockMultipartHttpServletRequestBuilder extends AbstractMockHttpServletRequestBuilder { - - private final List files = new ArrayList<>(); - - private final MultiValueMap parts = new LinkedMultiValueMap<>(); +public class MockMultipartHttpServletRequestBuilder + extends AbstractMockMultipartHttpServletRequestBuilder { /** @@ -98,101 +78,4 @@ public class MockMultipartHttpServletRequestBuilder extends AbstractMockHttpServ super.contentType(MediaType.MULTIPART_FORM_DATA); } - - /** - * Add a new {@link MockMultipartFile} with the given content. - * @param name the name of the file - * @param content the content of the file - */ - public MockMultipartHttpServletRequestBuilder file(String name, byte[] content) { - this.files.add(new MockMultipartFile(name, content)); - return this; - } - - /** - * Add the given {@link MockMultipartFile}. - * @param file the multipart file - */ - public MockMultipartHttpServletRequestBuilder file(MockMultipartFile file) { - this.files.add(file); - return this; - } - - /** - * Add {@link Part} components to the request. - * @param parts one or more parts to add - * @since 5.0 - */ - public MockMultipartHttpServletRequestBuilder part(Part... parts) { - Assert.notEmpty(parts, "'parts' must not be empty"); - for (Part part : parts) { - this.parts.add(part.getName(), part); - } - return this; - } - - @Override - public Object merge(@Nullable Object parent) { - if (parent == null) { - return this; - } - if (parent instanceof AbstractMockHttpServletRequestBuilder) { - super.merge(parent); - if (parent instanceof MockMultipartHttpServletRequestBuilder parentBuilder) { - this.files.addAll(parentBuilder.files); - parentBuilder.parts.keySet().forEach(name -> - this.parts.putIfAbsent(name, parentBuilder.parts.get(name))); - } - } - else { - throw new IllegalArgumentException("Cannot merge with [" + parent.getClass().getName() + "]"); - } - return this; - } - - /** - * Create a new {@link MockMultipartHttpServletRequest} based on the - * supplied {@code ServletContext} and the {@code MockMultipartFiles} - * added to this builder. - */ - @Override - protected final MockHttpServletRequest createServletRequest(ServletContext servletContext) { - MockMultipartHttpServletRequest request = new MockMultipartHttpServletRequest(servletContext); - Charset defaultCharset = (request.getCharacterEncoding() != null ? - Charset.forName(request.getCharacterEncoding()) : StandardCharsets.UTF_8); - - this.files.forEach(request::addFile); - this.parts.values().stream().flatMap(Collection::stream).forEach(part -> { - request.addPart(part); - try { - String name = part.getName(); - String filename = part.getSubmittedFileName(); - InputStream is = part.getInputStream(); - if (filename != null) { - request.addFile(new MockMultipartFile(name, filename, part.getContentType(), is)); - } - else { - InputStreamReader reader = new InputStreamReader(is, getCharsetOrDefault(part, defaultCharset)); - String value = FileCopyUtils.copyToString(reader); - request.addParameter(part.getName(), value); - } - } - catch (IOException ex) { - throw new IllegalStateException("Failed to read content for part " + part.getName(), ex); - } - }); - - return request; - } - - private Charset getCharsetOrDefault(Part part, Charset defaultCharset) { - if (part.getContentType() != null) { - MediaType mediaType = MediaType.parseMediaType(part.getContentType()); - if (mediaType.getCharset() != null) { - return mediaType.getCharset(); - } - } - return defaultCharset; - } - } diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/assertj/MockMvcTesterIntegrationTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/assertj/MockMvcTesterIntegrationTests.java index ae1b2c99287..2c2951d5ef1 100644 --- a/spring-test/src/test/java/org/springframework/test/web/servlet/assertj/MockMvcTesterIntegrationTests.java +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/assertj/MockMvcTesterIntegrationTests.java @@ -20,6 +20,7 @@ import java.io.ByteArrayOutputStream; import java.io.PrintStream; import java.nio.charset.StandardCharsets; import java.util.Collections; +import java.util.HashMap; import java.util.List; import java.util.Locale; import java.util.Map; @@ -31,6 +32,7 @@ import jakarta.servlet.ServletException; import jakarta.servlet.http.Cookie; import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletResponse; +import jakarta.servlet.http.Part; import jakarta.validation.Valid; import jakarta.validation.constraints.Size; import org.junit.jupiter.api.AfterEach; @@ -44,6 +46,8 @@ import org.springframework.core.io.ClassPathResource; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; +import org.springframework.mock.web.MockMultipartFile; +import org.springframework.mock.web.MockPart; import org.springframework.stereotype.Controller; import org.springframework.test.context.junit.jupiter.web.SpringJUnitWebConfig; import org.springframework.test.web.Person; @@ -54,13 +58,19 @@ import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.ModelAttribute; import org.springframework.web.bind.annotation.PathVariable; import org.springframework.web.bind.annotation.PostMapping; +import org.springframework.web.bind.annotation.PutMapping; import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RequestParam; +import org.springframework.web.bind.annotation.RequestPart; import org.springframework.web.bind.annotation.RestController; import org.springframework.web.bind.annotation.SessionAttributes; import org.springframework.web.context.WebApplicationContext; +import org.springframework.web.multipart.MultipartFile; +import org.springframework.web.multipart.support.MissingServletRequestPartException; import org.springframework.web.server.ResponseStatusException; import org.springframework.web.servlet.DispatcherServlet; import org.springframework.web.servlet.HandlerInterceptor; +import org.springframework.web.servlet.ModelAndView; import org.springframework.web.servlet.config.annotation.EnableWebMvc; import org.springframework.web.servlet.mvc.support.RedirectAttributes; @@ -113,6 +123,41 @@ public class MockMvcTesterIntegrationTests { } } + @Nested + class MultipartTests { + + private final MockMultipartFile JSON_PART_FILE = new MockMultipartFile("json", "json", "application/json", """ + { + "name": "test" + }""".getBytes(StandardCharsets.UTF_8)); + + @Test + void multipartWithPut() { + MockMultipartFile part = new MockMultipartFile("file", "content.txt", null, "value".getBytes(StandardCharsets.UTF_8)); + assertThat(mvc.put().uri("/multipart-put").multipart().file(part).file(JSON_PART_FILE)) + .hasStatusOk() + .hasViewName("index") + .model().contains(entry("name", "file")); + } + + @Test + void multipartWithMissingPart() { + assertThat(mvc.put().uri("/multipart-put").multipart().file(JSON_PART_FILE)) + .hasStatus(HttpStatus.BAD_REQUEST) + .failure().isInstanceOfSatisfying(MissingServletRequestPartException.class, + ex -> assertThat(ex.getRequestPartName()).isEqualTo("file")); + } + + @Test + void multipartWithNamedPart() { + MockPart part = new MockPart("part", "content.txt", "value".getBytes(StandardCharsets.UTF_8)); + assertThat(mvc.post().uri("/part").multipart().part(part).file(JSON_PART_FILE)) + .hasStatusOk() + .hasViewName("index") + .model().contains(entry("part", "content.txt"), entry("name", "test")); + } + } + @Nested class CookieTests { @@ -516,7 +561,7 @@ public class MockMvcTesterIntegrationTests { @Configuration @EnableWebMvc @Import({ TestController.class, PersonController.class, AsyncController.class, - SessionController.class, ErrorController.class }) + MultipartController.class, SessionController.class, ErrorController.class }) static class WebConfiguration { } @@ -564,6 +609,22 @@ public class MockMvcTesterIntegrationTests { } } + @Controller + static class MultipartController { + + @PostMapping("/part") + ModelAndView part(@RequestPart Part part, @RequestPart Map json) { + Map model = new HashMap<>(json); + model.put(part.getName(), part.getSubmittedFileName()); + return new ModelAndView("index", model); + } + + @PutMapping("/multipart-put") + public ModelAndView multiPartViaHttpPut(@RequestParam MultipartFile file) { + return new ModelAndView("index", Map.of("name", file.getName())); + } + } + @Controller @SessionAttributes("locale") static class SessionController {