Add multipart support for MockMvcTester
File uploads with MockMvc require a separate MockHttpServletRequestBuilder implementation. This commit applies the same change to support AssertJ on this builder, but for the multipart version. Any request builder can now use `multipart()` to "down cast" to a dedicated multipart request builder that contains the settings configured thus far. Closes gh-33027
This commit is contained in:
parent
f2137c99e5
commit
d76f37c90b
|
|
@ -31,10 +31,12 @@ import org.springframework.http.converter.GenericHttpMessageConverter;
|
|||
import org.springframework.http.converter.HttpMessageConverter;
|
||||
import org.springframework.lang.Nullable;
|
||||
import org.springframework.mock.web.MockHttpServletRequest;
|
||||
import org.springframework.mock.web.MockMultipartHttpServletRequest;
|
||||
import org.springframework.test.web.servlet.MockMvc;
|
||||
import org.springframework.test.web.servlet.MvcResult;
|
||||
import org.springframework.test.web.servlet.RequestBuilder;
|
||||
import org.springframework.test.web.servlet.request.AbstractMockHttpServletRequestBuilder;
|
||||
import org.springframework.test.web.servlet.request.AbstractMockMultipartHttpServletRequestBuilder;
|
||||
import org.springframework.test.web.servlet.request.MockMvcRequestBuilders;
|
||||
import org.springframework.test.web.servlet.setup.DefaultMockMvcBuilder;
|
||||
import org.springframework.test.web.servlet.setup.MockMvcBuilders;
|
||||
|
|
@ -389,8 +391,42 @@ public final class MockMvcTester {
|
|||
public final class MockMvcRequestBuilder extends AbstractMockHttpServletRequestBuilder<MockMvcRequestBuilder>
|
||||
implements AssertProvider<MvcTestResultAssert> {
|
||||
|
||||
private final HttpMethod httpMethod;
|
||||
|
||||
private MockMvcRequestBuilder(HttpMethod httpMethod) {
|
||||
super(httpMethod);
|
||||
this.httpMethod = httpMethod;
|
||||
}
|
||||
|
||||
/**
|
||||
* Enable file upload support using multipart.
|
||||
* @return a {@link MockMultipartMvcRequestBuilder} with the settings
|
||||
* configured thus far
|
||||
*/
|
||||
public MockMultipartMvcRequestBuilder multipart() {
|
||||
return new MockMultipartMvcRequestBuilder(this);
|
||||
}
|
||||
|
||||
public MvcTestResult exchange() {
|
||||
return perform(this);
|
||||
}
|
||||
|
||||
@Override
|
||||
public MvcTestResultAssert assertThat() {
|
||||
return new MvcTestResultAssert(exchange(), MockMvcTester.this.jsonMessageConverter);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* A builder for {@link MockMultipartHttpServletRequest} that supports AssertJ.
|
||||
*/
|
||||
public final class MockMultipartMvcRequestBuilder
|
||||
extends AbstractMockMultipartHttpServletRequestBuilder<MockMultipartMvcRequestBuilder>
|
||||
implements AssertProvider<MvcTestResultAssert> {
|
||||
|
||||
private MockMultipartMvcRequestBuilder(MockMvcRequestBuilder currentBuilder) {
|
||||
super(currentBuilder.httpMethod);
|
||||
merge(currentBuilder);
|
||||
}
|
||||
|
||||
public MvcTestResult exchange() {
|
||||
|
|
|
|||
|
|
@ -0,0 +1,159 @@
|
|||
/*
|
||||
* Copyright 2002-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.test.web.servlet.request;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.io.InputStreamReader;
|
||||
import java.nio.charset.Charset;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.List;
|
||||
|
||||
import jakarta.servlet.ServletContext;
|
||||
import jakarta.servlet.http.Part;
|
||||
|
||||
import org.springframework.http.HttpMethod;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.lang.Nullable;
|
||||
import org.springframework.mock.web.MockHttpServletRequest;
|
||||
import org.springframework.mock.web.MockMultipartFile;
|
||||
import org.springframework.mock.web.MockMultipartHttpServletRequest;
|
||||
import org.springframework.util.Assert;
|
||||
import org.springframework.util.FileCopyUtils;
|
||||
import org.springframework.util.LinkedMultiValueMap;
|
||||
import org.springframework.util.MultiValueMap;
|
||||
|
||||
/**
|
||||
* Base builder for {@link MockMultipartHttpServletRequest}.
|
||||
*
|
||||
* @author Rossen Stoyanchev
|
||||
* @author Arjen Poutsma
|
||||
* @author Stephane Nicoll
|
||||
* @since 6.2
|
||||
* @param <B> a self reference to the builder type
|
||||
*/
|
||||
public abstract class AbstractMockMultipartHttpServletRequestBuilder<B extends AbstractMockMultipartHttpServletRequestBuilder<B>>
|
||||
extends AbstractMockHttpServletRequestBuilder<B> {
|
||||
|
||||
private final List<MockMultipartFile> files = new ArrayList<>();
|
||||
|
||||
private final MultiValueMap<String, Part> parts = new LinkedMultiValueMap<>();
|
||||
|
||||
|
||||
protected AbstractMockMultipartHttpServletRequestBuilder(HttpMethod httpMethod) {
|
||||
super(httpMethod);
|
||||
}
|
||||
|
||||
/**
|
||||
* Add a new {@link MockMultipartFile} with the given content.
|
||||
* @param name the name of the file
|
||||
* @param content the content of the file
|
||||
*/
|
||||
public B file(String name, byte[] content) {
|
||||
this.files.add(new MockMultipartFile(name, content));
|
||||
return self();
|
||||
}
|
||||
|
||||
/**
|
||||
* Add the given {@link MockMultipartFile}.
|
||||
* @param file the multipart file
|
||||
*/
|
||||
public B file(MockMultipartFile file) {
|
||||
this.files.add(file);
|
||||
return self();
|
||||
}
|
||||
|
||||
/**
|
||||
* Add {@link Part} components to the request.
|
||||
* @param parts one or more parts to add
|
||||
* @since 5.0
|
||||
*/
|
||||
public B part(Part... parts) {
|
||||
Assert.notEmpty(parts, "'parts' must not be empty");
|
||||
for (Part part : parts) {
|
||||
this.parts.add(part.getName(), part);
|
||||
}
|
||||
return self();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object merge(@Nullable Object parent) {
|
||||
if (parent == null) {
|
||||
return this;
|
||||
}
|
||||
if (parent instanceof AbstractMockHttpServletRequestBuilder<?>) {
|
||||
super.merge(parent);
|
||||
if (parent instanceof AbstractMockMultipartHttpServletRequestBuilder<?> parentBuilder) {
|
||||
this.files.addAll(parentBuilder.files);
|
||||
parentBuilder.parts.keySet().forEach(name ->
|
||||
this.parts.putIfAbsent(name, parentBuilder.parts.get(name)));
|
||||
}
|
||||
}
|
||||
else {
|
||||
throw new IllegalArgumentException("Cannot merge with [" + parent.getClass().getName() + "]");
|
||||
}
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a new {@link MockMultipartHttpServletRequest} based on the
|
||||
* supplied {@code ServletContext} and the {@code MockMultipartFiles}
|
||||
* added to this builder.
|
||||
*/
|
||||
@Override
|
||||
protected final MockHttpServletRequest createServletRequest(ServletContext servletContext) {
|
||||
MockMultipartHttpServletRequest request = new MockMultipartHttpServletRequest(servletContext);
|
||||
Charset defaultCharset = (request.getCharacterEncoding() != null ?
|
||||
Charset.forName(request.getCharacterEncoding()) : StandardCharsets.UTF_8);
|
||||
|
||||
this.files.forEach(request::addFile);
|
||||
this.parts.values().stream().flatMap(Collection::stream).forEach(part -> {
|
||||
request.addPart(part);
|
||||
try {
|
||||
String name = part.getName();
|
||||
String filename = part.getSubmittedFileName();
|
||||
InputStream is = part.getInputStream();
|
||||
if (filename != null) {
|
||||
request.addFile(new MockMultipartFile(name, filename, part.getContentType(), is));
|
||||
}
|
||||
else {
|
||||
InputStreamReader reader = new InputStreamReader(is, getCharsetOrDefault(part, defaultCharset));
|
||||
String value = FileCopyUtils.copyToString(reader);
|
||||
request.addParameter(part.getName(), value);
|
||||
}
|
||||
}
|
||||
catch (IOException ex) {
|
||||
throw new IllegalStateException("Failed to read content for part " + part.getName(), ex);
|
||||
}
|
||||
});
|
||||
|
||||
return request;
|
||||
}
|
||||
|
||||
private Charset getCharsetOrDefault(Part part, Charset defaultCharset) {
|
||||
if (part.getContentType() != null) {
|
||||
MediaType mediaType = MediaType.parseMediaType(part.getContentType());
|
||||
if (mediaType.getCharset() != null) {
|
||||
return mediaType.getCharset();
|
||||
}
|
||||
}
|
||||
return defaultCharset;
|
||||
}
|
||||
|
||||
}
|
||||
|
|
@ -16,42 +16,22 @@
|
|||
|
||||
package org.springframework.test.web.servlet.request;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.io.InputStreamReader;
|
||||
import java.net.URI;
|
||||
import java.nio.charset.Charset;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.List;
|
||||
|
||||
import jakarta.servlet.ServletContext;
|
||||
import jakarta.servlet.http.Part;
|
||||
|
||||
import org.springframework.http.HttpMethod;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.lang.Nullable;
|
||||
import org.springframework.mock.web.MockHttpServletRequest;
|
||||
import org.springframework.mock.web.MockMultipartFile;
|
||||
import org.springframework.mock.web.MockMultipartHttpServletRequest;
|
||||
import org.springframework.util.Assert;
|
||||
import org.springframework.util.FileCopyUtils;
|
||||
import org.springframework.util.LinkedMultiValueMap;
|
||||
import org.springframework.util.MultiValueMap;
|
||||
|
||||
/**
|
||||
* Default builder for {@link MockMultipartHttpServletRequest}.
|
||||
*
|
||||
* @author Rossen Stoyanchev
|
||||
* @author Arjen Poutsma
|
||||
* @author Stephane Nicoll
|
||||
* @since 3.2
|
||||
*/
|
||||
public class MockMultipartHttpServletRequestBuilder extends AbstractMockHttpServletRequestBuilder<MockMultipartHttpServletRequestBuilder> {
|
||||
|
||||
private final List<MockMultipartFile> files = new ArrayList<>();
|
||||
|
||||
private final MultiValueMap<String, Part> parts = new LinkedMultiValueMap<>();
|
||||
public class MockMultipartHttpServletRequestBuilder
|
||||
extends AbstractMockMultipartHttpServletRequestBuilder<MockMultipartHttpServletRequestBuilder> {
|
||||
|
||||
|
||||
/**
|
||||
|
|
@ -98,101 +78,4 @@ public class MockMultipartHttpServletRequestBuilder extends AbstractMockHttpServ
|
|||
super.contentType(MediaType.MULTIPART_FORM_DATA);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Add a new {@link MockMultipartFile} with the given content.
|
||||
* @param name the name of the file
|
||||
* @param content the content of the file
|
||||
*/
|
||||
public MockMultipartHttpServletRequestBuilder file(String name, byte[] content) {
|
||||
this.files.add(new MockMultipartFile(name, content));
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Add the given {@link MockMultipartFile}.
|
||||
* @param file the multipart file
|
||||
*/
|
||||
public MockMultipartHttpServletRequestBuilder file(MockMultipartFile file) {
|
||||
this.files.add(file);
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Add {@link Part} components to the request.
|
||||
* @param parts one or more parts to add
|
||||
* @since 5.0
|
||||
*/
|
||||
public MockMultipartHttpServletRequestBuilder part(Part... parts) {
|
||||
Assert.notEmpty(parts, "'parts' must not be empty");
|
||||
for (Part part : parts) {
|
||||
this.parts.add(part.getName(), part);
|
||||
}
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object merge(@Nullable Object parent) {
|
||||
if (parent == null) {
|
||||
return this;
|
||||
}
|
||||
if (parent instanceof AbstractMockHttpServletRequestBuilder) {
|
||||
super.merge(parent);
|
||||
if (parent instanceof MockMultipartHttpServletRequestBuilder parentBuilder) {
|
||||
this.files.addAll(parentBuilder.files);
|
||||
parentBuilder.parts.keySet().forEach(name ->
|
||||
this.parts.putIfAbsent(name, parentBuilder.parts.get(name)));
|
||||
}
|
||||
}
|
||||
else {
|
||||
throw new IllegalArgumentException("Cannot merge with [" + parent.getClass().getName() + "]");
|
||||
}
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a new {@link MockMultipartHttpServletRequest} based on the
|
||||
* supplied {@code ServletContext} and the {@code MockMultipartFiles}
|
||||
* added to this builder.
|
||||
*/
|
||||
@Override
|
||||
protected final MockHttpServletRequest createServletRequest(ServletContext servletContext) {
|
||||
MockMultipartHttpServletRequest request = new MockMultipartHttpServletRequest(servletContext);
|
||||
Charset defaultCharset = (request.getCharacterEncoding() != null ?
|
||||
Charset.forName(request.getCharacterEncoding()) : StandardCharsets.UTF_8);
|
||||
|
||||
this.files.forEach(request::addFile);
|
||||
this.parts.values().stream().flatMap(Collection::stream).forEach(part -> {
|
||||
request.addPart(part);
|
||||
try {
|
||||
String name = part.getName();
|
||||
String filename = part.getSubmittedFileName();
|
||||
InputStream is = part.getInputStream();
|
||||
if (filename != null) {
|
||||
request.addFile(new MockMultipartFile(name, filename, part.getContentType(), is));
|
||||
}
|
||||
else {
|
||||
InputStreamReader reader = new InputStreamReader(is, getCharsetOrDefault(part, defaultCharset));
|
||||
String value = FileCopyUtils.copyToString(reader);
|
||||
request.addParameter(part.getName(), value);
|
||||
}
|
||||
}
|
||||
catch (IOException ex) {
|
||||
throw new IllegalStateException("Failed to read content for part " + part.getName(), ex);
|
||||
}
|
||||
});
|
||||
|
||||
return request;
|
||||
}
|
||||
|
||||
private Charset getCharsetOrDefault(Part part, Charset defaultCharset) {
|
||||
if (part.getContentType() != null) {
|
||||
MediaType mediaType = MediaType.parseMediaType(part.getContentType());
|
||||
if (mediaType.getCharset() != null) {
|
||||
return mediaType.getCharset();
|
||||
}
|
||||
}
|
||||
return defaultCharset;
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ import java.io.ByteArrayOutputStream;
|
|||
import java.io.PrintStream;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
import java.util.Map;
|
||||
|
|
@ -31,6 +32,7 @@ import jakarta.servlet.ServletException;
|
|||
import jakarta.servlet.http.Cookie;
|
||||
import jakarta.servlet.http.HttpServletRequest;
|
||||
import jakarta.servlet.http.HttpServletResponse;
|
||||
import jakarta.servlet.http.Part;
|
||||
import jakarta.validation.Valid;
|
||||
import jakarta.validation.constraints.Size;
|
||||
import org.junit.jupiter.api.AfterEach;
|
||||
|
|
@ -44,6 +46,8 @@ import org.springframework.core.io.ClassPathResource;
|
|||
import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.http.HttpStatus;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.mock.web.MockMultipartFile;
|
||||
import org.springframework.mock.web.MockPart;
|
||||
import org.springframework.stereotype.Controller;
|
||||
import org.springframework.test.context.junit.jupiter.web.SpringJUnitWebConfig;
|
||||
import org.springframework.test.web.Person;
|
||||
|
|
@ -54,13 +58,19 @@ import org.springframework.web.bind.annotation.GetMapping;
|
|||
import org.springframework.web.bind.annotation.ModelAttribute;
|
||||
import org.springframework.web.bind.annotation.PathVariable;
|
||||
import org.springframework.web.bind.annotation.PostMapping;
|
||||
import org.springframework.web.bind.annotation.PutMapping;
|
||||
import org.springframework.web.bind.annotation.RequestMapping;
|
||||
import org.springframework.web.bind.annotation.RequestParam;
|
||||
import org.springframework.web.bind.annotation.RequestPart;
|
||||
import org.springframework.web.bind.annotation.RestController;
|
||||
import org.springframework.web.bind.annotation.SessionAttributes;
|
||||
import org.springframework.web.context.WebApplicationContext;
|
||||
import org.springframework.web.multipart.MultipartFile;
|
||||
import org.springframework.web.multipart.support.MissingServletRequestPartException;
|
||||
import org.springframework.web.server.ResponseStatusException;
|
||||
import org.springframework.web.servlet.DispatcherServlet;
|
||||
import org.springframework.web.servlet.HandlerInterceptor;
|
||||
import org.springframework.web.servlet.ModelAndView;
|
||||
import org.springframework.web.servlet.config.annotation.EnableWebMvc;
|
||||
import org.springframework.web.servlet.mvc.support.RedirectAttributes;
|
||||
|
||||
|
|
@ -113,6 +123,41 @@ public class MockMvcTesterIntegrationTests {
|
|||
}
|
||||
}
|
||||
|
||||
@Nested
|
||||
class MultipartTests {
|
||||
|
||||
private final MockMultipartFile JSON_PART_FILE = new MockMultipartFile("json", "json", "application/json", """
|
||||
{
|
||||
"name": "test"
|
||||
}""".getBytes(StandardCharsets.UTF_8));
|
||||
|
||||
@Test
|
||||
void multipartWithPut() {
|
||||
MockMultipartFile part = new MockMultipartFile("file", "content.txt", null, "value".getBytes(StandardCharsets.UTF_8));
|
||||
assertThat(mvc.put().uri("/multipart-put").multipart().file(part).file(JSON_PART_FILE))
|
||||
.hasStatusOk()
|
||||
.hasViewName("index")
|
||||
.model().contains(entry("name", "file"));
|
||||
}
|
||||
|
||||
@Test
|
||||
void multipartWithMissingPart() {
|
||||
assertThat(mvc.put().uri("/multipart-put").multipart().file(JSON_PART_FILE))
|
||||
.hasStatus(HttpStatus.BAD_REQUEST)
|
||||
.failure().isInstanceOfSatisfying(MissingServletRequestPartException.class,
|
||||
ex -> assertThat(ex.getRequestPartName()).isEqualTo("file"));
|
||||
}
|
||||
|
||||
@Test
|
||||
void multipartWithNamedPart() {
|
||||
MockPart part = new MockPart("part", "content.txt", "value".getBytes(StandardCharsets.UTF_8));
|
||||
assertThat(mvc.post().uri("/part").multipart().part(part).file(JSON_PART_FILE))
|
||||
.hasStatusOk()
|
||||
.hasViewName("index")
|
||||
.model().contains(entry("part", "content.txt"), entry("name", "test"));
|
||||
}
|
||||
}
|
||||
|
||||
@Nested
|
||||
class CookieTests {
|
||||
|
||||
|
|
@ -516,7 +561,7 @@ public class MockMvcTesterIntegrationTests {
|
|||
@Configuration
|
||||
@EnableWebMvc
|
||||
@Import({ TestController.class, PersonController.class, AsyncController.class,
|
||||
SessionController.class, ErrorController.class })
|
||||
MultipartController.class, SessionController.class, ErrorController.class })
|
||||
static class WebConfiguration {
|
||||
}
|
||||
|
||||
|
|
@ -564,6 +609,22 @@ public class MockMvcTesterIntegrationTests {
|
|||
}
|
||||
}
|
||||
|
||||
@Controller
|
||||
static class MultipartController {
|
||||
|
||||
@PostMapping("/part")
|
||||
ModelAndView part(@RequestPart Part part, @RequestPart Map<String, String> json) {
|
||||
Map<String, Object> model = new HashMap<>(json);
|
||||
model.put(part.getName(), part.getSubmittedFileName());
|
||||
return new ModelAndView("index", model);
|
||||
}
|
||||
|
||||
@PutMapping("/multipart-put")
|
||||
public ModelAndView multiPartViaHttpPut(@RequestParam MultipartFile file) {
|
||||
return new ModelAndView("index", Map.of("name", file.getName()));
|
||||
}
|
||||
}
|
||||
|
||||
@Controller
|
||||
@SessionAttributes("locale")
|
||||
static class SessionController {
|
||||
|
|
|
|||
Loading…
Reference in New Issue