Add multipart support for MockMvcTester

File uploads with MockMvc require a separate
MockHttpServletRequestBuilder implementation. This commit applies the
same change to support AssertJ on this builder, but for the multipart
version.

Any request builder can now use `multipart()` to "down cast" to a
dedicated multipart request builder that contains the settings
configured thus far.

Closes gh-33027
This commit is contained in:
Stéphane Nicoll 2024-06-17 14:30:00 +02:00
parent f2137c99e5
commit d76f37c90b
4 changed files with 260 additions and 121 deletions

View File

@ -31,10 +31,12 @@ import org.springframework.http.converter.GenericHttpMessageConverter;
import org.springframework.http.converter.HttpMessageConverter;
import org.springframework.lang.Nullable;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockMultipartHttpServletRequest;
import org.springframework.test.web.servlet.MockMvc;
import org.springframework.test.web.servlet.MvcResult;
import org.springframework.test.web.servlet.RequestBuilder;
import org.springframework.test.web.servlet.request.AbstractMockHttpServletRequestBuilder;
import org.springframework.test.web.servlet.request.AbstractMockMultipartHttpServletRequestBuilder;
import org.springframework.test.web.servlet.request.MockMvcRequestBuilders;
import org.springframework.test.web.servlet.setup.DefaultMockMvcBuilder;
import org.springframework.test.web.servlet.setup.MockMvcBuilders;
@ -389,8 +391,42 @@ public final class MockMvcTester {
public final class MockMvcRequestBuilder extends AbstractMockHttpServletRequestBuilder<MockMvcRequestBuilder>
implements AssertProvider<MvcTestResultAssert> {
private final HttpMethod httpMethod;
private MockMvcRequestBuilder(HttpMethod httpMethod) {
super(httpMethod);
this.httpMethod = httpMethod;
}
/**
* Enable file upload support using multipart.
* @return a {@link MockMultipartMvcRequestBuilder} with the settings
* configured thus far
*/
public MockMultipartMvcRequestBuilder multipart() {
return new MockMultipartMvcRequestBuilder(this);
}
public MvcTestResult exchange() {
return perform(this);
}
@Override
public MvcTestResultAssert assertThat() {
return new MvcTestResultAssert(exchange(), MockMvcTester.this.jsonMessageConverter);
}
}
/**
* A builder for {@link MockMultipartHttpServletRequest} that supports AssertJ.
*/
public final class MockMultipartMvcRequestBuilder
extends AbstractMockMultipartHttpServletRequestBuilder<MockMultipartMvcRequestBuilder>
implements AssertProvider<MvcTestResultAssert> {
private MockMultipartMvcRequestBuilder(MockMvcRequestBuilder currentBuilder) {
super(currentBuilder.httpMethod);
merge(currentBuilder);
}
public MvcTestResult exchange() {

View File

@ -0,0 +1,159 @@
/*
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.test.web.servlet.request;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import jakarta.servlet.ServletContext;
import jakarta.servlet.http.Part;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.lang.Nullable;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockMultipartFile;
import org.springframework.mock.web.MockMultipartHttpServletRequest;
import org.springframework.util.Assert;
import org.springframework.util.FileCopyUtils;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
/**
* Base builder for {@link MockMultipartHttpServletRequest}.
*
* @author Rossen Stoyanchev
* @author Arjen Poutsma
* @author Stephane Nicoll
* @since 6.2
* @param <B> a self reference to the builder type
*/
public abstract class AbstractMockMultipartHttpServletRequestBuilder<B extends AbstractMockMultipartHttpServletRequestBuilder<B>>
extends AbstractMockHttpServletRequestBuilder<B> {
private final List<MockMultipartFile> files = new ArrayList<>();
private final MultiValueMap<String, Part> parts = new LinkedMultiValueMap<>();
protected AbstractMockMultipartHttpServletRequestBuilder(HttpMethod httpMethod) {
super(httpMethod);
}
/**
* Add a new {@link MockMultipartFile} with the given content.
* @param name the name of the file
* @param content the content of the file
*/
public B file(String name, byte[] content) {
this.files.add(new MockMultipartFile(name, content));
return self();
}
/**
* Add the given {@link MockMultipartFile}.
* @param file the multipart file
*/
public B file(MockMultipartFile file) {
this.files.add(file);
return self();
}
/**
* Add {@link Part} components to the request.
* @param parts one or more parts to add
* @since 5.0
*/
public B part(Part... parts) {
Assert.notEmpty(parts, "'parts' must not be empty");
for (Part part : parts) {
this.parts.add(part.getName(), part);
}
return self();
}
@Override
public Object merge(@Nullable Object parent) {
if (parent == null) {
return this;
}
if (parent instanceof AbstractMockHttpServletRequestBuilder<?>) {
super.merge(parent);
if (parent instanceof AbstractMockMultipartHttpServletRequestBuilder<?> parentBuilder) {
this.files.addAll(parentBuilder.files);
parentBuilder.parts.keySet().forEach(name ->
this.parts.putIfAbsent(name, parentBuilder.parts.get(name)));
}
}
else {
throw new IllegalArgumentException("Cannot merge with [" + parent.getClass().getName() + "]");
}
return this;
}
/**
* Create a new {@link MockMultipartHttpServletRequest} based on the
* supplied {@code ServletContext} and the {@code MockMultipartFiles}
* added to this builder.
*/
@Override
protected final MockHttpServletRequest createServletRequest(ServletContext servletContext) {
MockMultipartHttpServletRequest request = new MockMultipartHttpServletRequest(servletContext);
Charset defaultCharset = (request.getCharacterEncoding() != null ?
Charset.forName(request.getCharacterEncoding()) : StandardCharsets.UTF_8);
this.files.forEach(request::addFile);
this.parts.values().stream().flatMap(Collection::stream).forEach(part -> {
request.addPart(part);
try {
String name = part.getName();
String filename = part.getSubmittedFileName();
InputStream is = part.getInputStream();
if (filename != null) {
request.addFile(new MockMultipartFile(name, filename, part.getContentType(), is));
}
else {
InputStreamReader reader = new InputStreamReader(is, getCharsetOrDefault(part, defaultCharset));
String value = FileCopyUtils.copyToString(reader);
request.addParameter(part.getName(), value);
}
}
catch (IOException ex) {
throw new IllegalStateException("Failed to read content for part " + part.getName(), ex);
}
});
return request;
}
private Charset getCharsetOrDefault(Part part, Charset defaultCharset) {
if (part.getContentType() != null) {
MediaType mediaType = MediaType.parseMediaType(part.getContentType());
if (mediaType.getCharset() != null) {
return mediaType.getCharset();
}
}
return defaultCharset;
}
}

View File

@ -16,42 +16,22 @@
package org.springframework.test.web.servlet.request;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.net.URI;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import jakarta.servlet.ServletContext;
import jakarta.servlet.http.Part;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.lang.Nullable;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockMultipartFile;
import org.springframework.mock.web.MockMultipartHttpServletRequest;
import org.springframework.util.Assert;
import org.springframework.util.FileCopyUtils;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
/**
* Default builder for {@link MockMultipartHttpServletRequest}.
*
* @author Rossen Stoyanchev
* @author Arjen Poutsma
* @author Stephane Nicoll
* @since 3.2
*/
public class MockMultipartHttpServletRequestBuilder extends AbstractMockHttpServletRequestBuilder<MockMultipartHttpServletRequestBuilder> {
private final List<MockMultipartFile> files = new ArrayList<>();
private final MultiValueMap<String, Part> parts = new LinkedMultiValueMap<>();
public class MockMultipartHttpServletRequestBuilder
extends AbstractMockMultipartHttpServletRequestBuilder<MockMultipartHttpServletRequestBuilder> {
/**
@ -98,101 +78,4 @@ public class MockMultipartHttpServletRequestBuilder extends AbstractMockHttpServ
super.contentType(MediaType.MULTIPART_FORM_DATA);
}
/**
* Add a new {@link MockMultipartFile} with the given content.
* @param name the name of the file
* @param content the content of the file
*/
public MockMultipartHttpServletRequestBuilder file(String name, byte[] content) {
this.files.add(new MockMultipartFile(name, content));
return this;
}
/**
* Add the given {@link MockMultipartFile}.
* @param file the multipart file
*/
public MockMultipartHttpServletRequestBuilder file(MockMultipartFile file) {
this.files.add(file);
return this;
}
/**
* Add {@link Part} components to the request.
* @param parts one or more parts to add
* @since 5.0
*/
public MockMultipartHttpServletRequestBuilder part(Part... parts) {
Assert.notEmpty(parts, "'parts' must not be empty");
for (Part part : parts) {
this.parts.add(part.getName(), part);
}
return this;
}
@Override
public Object merge(@Nullable Object parent) {
if (parent == null) {
return this;
}
if (parent instanceof AbstractMockHttpServletRequestBuilder) {
super.merge(parent);
if (parent instanceof MockMultipartHttpServletRequestBuilder parentBuilder) {
this.files.addAll(parentBuilder.files);
parentBuilder.parts.keySet().forEach(name ->
this.parts.putIfAbsent(name, parentBuilder.parts.get(name)));
}
}
else {
throw new IllegalArgumentException("Cannot merge with [" + parent.getClass().getName() + "]");
}
return this;
}
/**
* Create a new {@link MockMultipartHttpServletRequest} based on the
* supplied {@code ServletContext} and the {@code MockMultipartFiles}
* added to this builder.
*/
@Override
protected final MockHttpServletRequest createServletRequest(ServletContext servletContext) {
MockMultipartHttpServletRequest request = new MockMultipartHttpServletRequest(servletContext);
Charset defaultCharset = (request.getCharacterEncoding() != null ?
Charset.forName(request.getCharacterEncoding()) : StandardCharsets.UTF_8);
this.files.forEach(request::addFile);
this.parts.values().stream().flatMap(Collection::stream).forEach(part -> {
request.addPart(part);
try {
String name = part.getName();
String filename = part.getSubmittedFileName();
InputStream is = part.getInputStream();
if (filename != null) {
request.addFile(new MockMultipartFile(name, filename, part.getContentType(), is));
}
else {
InputStreamReader reader = new InputStreamReader(is, getCharsetOrDefault(part, defaultCharset));
String value = FileCopyUtils.copyToString(reader);
request.addParameter(part.getName(), value);
}
}
catch (IOException ex) {
throw new IllegalStateException("Failed to read content for part " + part.getName(), ex);
}
});
return request;
}
private Charset getCharsetOrDefault(Part part, Charset defaultCharset) {
if (part.getContentType() != null) {
MediaType mediaType = MediaType.parseMediaType(part.getContentType());
if (mediaType.getCharset() != null) {
return mediaType.getCharset();
}
}
return defaultCharset;
}
}

View File

@ -20,6 +20,7 @@ import java.io.ByteArrayOutputStream;
import java.io.PrintStream;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
@ -31,6 +32,7 @@ import jakarta.servlet.ServletException;
import jakarta.servlet.http.Cookie;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import jakarta.servlet.http.Part;
import jakarta.validation.Valid;
import jakarta.validation.constraints.Size;
import org.junit.jupiter.api.AfterEach;
@ -44,6 +46,8 @@ import org.springframework.core.io.ClassPathResource;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.mock.web.MockMultipartFile;
import org.springframework.mock.web.MockPart;
import org.springframework.stereotype.Controller;
import org.springframework.test.context.junit.jupiter.web.SpringJUnitWebConfig;
import org.springframework.test.web.Person;
@ -54,13 +58,19 @@ import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.ModelAttribute;
import org.springframework.web.bind.annotation.PathVariable;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.PutMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RequestPart;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.bind.annotation.SessionAttributes;
import org.springframework.web.context.WebApplicationContext;
import org.springframework.web.multipart.MultipartFile;
import org.springframework.web.multipart.support.MissingServletRequestPartException;
import org.springframework.web.server.ResponseStatusException;
import org.springframework.web.servlet.DispatcherServlet;
import org.springframework.web.servlet.HandlerInterceptor;
import org.springframework.web.servlet.ModelAndView;
import org.springframework.web.servlet.config.annotation.EnableWebMvc;
import org.springframework.web.servlet.mvc.support.RedirectAttributes;
@ -113,6 +123,41 @@ public class MockMvcTesterIntegrationTests {
}
}
@Nested
class MultipartTests {
private final MockMultipartFile JSON_PART_FILE = new MockMultipartFile("json", "json", "application/json", """
{
"name": "test"
}""".getBytes(StandardCharsets.UTF_8));
@Test
void multipartWithPut() {
MockMultipartFile part = new MockMultipartFile("file", "content.txt", null, "value".getBytes(StandardCharsets.UTF_8));
assertThat(mvc.put().uri("/multipart-put").multipart().file(part).file(JSON_PART_FILE))
.hasStatusOk()
.hasViewName("index")
.model().contains(entry("name", "file"));
}
@Test
void multipartWithMissingPart() {
assertThat(mvc.put().uri("/multipart-put").multipart().file(JSON_PART_FILE))
.hasStatus(HttpStatus.BAD_REQUEST)
.failure().isInstanceOfSatisfying(MissingServletRequestPartException.class,
ex -> assertThat(ex.getRequestPartName()).isEqualTo("file"));
}
@Test
void multipartWithNamedPart() {
MockPart part = new MockPart("part", "content.txt", "value".getBytes(StandardCharsets.UTF_8));
assertThat(mvc.post().uri("/part").multipart().part(part).file(JSON_PART_FILE))
.hasStatusOk()
.hasViewName("index")
.model().contains(entry("part", "content.txt"), entry("name", "test"));
}
}
@Nested
class CookieTests {
@ -516,7 +561,7 @@ public class MockMvcTesterIntegrationTests {
@Configuration
@EnableWebMvc
@Import({ TestController.class, PersonController.class, AsyncController.class,
SessionController.class, ErrorController.class })
MultipartController.class, SessionController.class, ErrorController.class })
static class WebConfiguration {
}
@ -564,6 +609,22 @@ public class MockMvcTesterIntegrationTests {
}
}
@Controller
static class MultipartController {
@PostMapping("/part")
ModelAndView part(@RequestPart Part part, @RequestPart Map<String, String> json) {
Map<String, Object> model = new HashMap<>(json);
model.put(part.getName(), part.getSubmittedFileName());
return new ModelAndView("index", model);
}
@PutMapping("/multipart-put")
public ModelAndView multiPartViaHttpPut(@RequestParam MultipartFile file) {
return new ModelAndView("index", Map.of("name", file.getName()));
}
}
@Controller
@SessionAttributes("locale")
static class SessionController {