diff --git a/spring-test/src/main/java/org/springframework/test/web/servlet/request/MockMultipartHttpServletRequestBuilder.java b/spring-test/src/main/java/org/springframework/test/web/servlet/request/MockMultipartHttpServletRequestBuilder.java index c0f15fcc81c..73a0a3bd29d 100644 --- a/spring-test/src/main/java/org/springframework/test/web/servlet/request/MockMultipartHttpServletRequestBuilder.java +++ b/spring-test/src/main/java/org/springframework/test/web/servlet/request/MockMultipartHttpServletRequestBuilder.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,6 +17,7 @@ package org.springframework.test.web.servlet.request; import java.io.IOException; +import java.io.InputStream; import java.io.InputStreamReader; import java.net.URI; import java.nio.charset.Charset; @@ -38,7 +39,6 @@ import org.springframework.util.Assert; import org.springframework.util.FileCopyUtils; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; -import org.springframework.web.multipart.MultipartFile; /** * Default builder for {@link MockMultipartHttpServletRequest}. @@ -144,47 +144,40 @@ public class MockMultipartHttpServletRequestBuilder extends MockHttpServletReque @Override protected final MockHttpServletRequest createServletRequest(ServletContext servletContext) { MockMultipartHttpServletRequest request = new MockMultipartHttpServletRequest(servletContext); + Charset defaultCharset = (request.getCharacterEncoding() != null ? + Charset.forName(request.getCharacterEncoding()) : StandardCharsets.UTF_8); + this.files.forEach(request::addFile); this.parts.values().stream().flatMap(Collection::stream).forEach(part -> { request.addPart(part); try { - MultipartFile file = asMultipartFile(part); - if (file != null) { - request.addFile(file); - return; + String name = part.getName(); + String filename = part.getSubmittedFileName(); + InputStream is = part.getInputStream(); + if (filename != null) { + request.addFile(new MockMultipartFile(name, filename, part.getContentType(), is)); } - String value = toParameterValue(part); - if (value != null) { - request.addParameter(part.getName(), toParameterValue(part)); + else { + InputStreamReader reader = new InputStreamReader(is, getCharsetOrDefault(part, defaultCharset)); + String value = FileCopyUtils.copyToString(reader); + request.addParameter(part.getName(), value); } } catch (IOException ex) { throw new IllegalStateException("Failed to read content for part " + part.getName(), ex); } }); + return request; } - @Nullable - private MultipartFile asMultipartFile(Part part) throws IOException { - String name = part.getName(); - String filename = part.getSubmittedFileName(); - if (filename != null) { - return new MockMultipartFile(name, filename, part.getContentType(), part.getInputStream()); + private Charset getCharsetOrDefault(Part part, Charset defaultCharset) { + if (part.getContentType() != null) { + MediaType mediaType = MediaType.parseMediaType(part.getContentType()); + if (mediaType.getCharset() != null) { + return mediaType.getCharset(); + } } - return null; + return defaultCharset; } - - @Nullable - private String toParameterValue(Part part) throws IOException { - String rawType = part.getContentType(); - MediaType mediaType = (rawType != null ? MediaType.parseMediaType(rawType) : MediaType.TEXT_PLAIN); - if (!mediaType.isCompatibleWith(MediaType.TEXT_PLAIN)) { - return null; - } - Charset charset = (mediaType.getCharset() != null ? mediaType.getCharset() : StandardCharsets.UTF_8); - InputStreamReader reader = new InputStreamReader(part.getInputStream(), charset); - return FileCopyUtils.copyToString(reader); - } - } diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/request/MockMultipartHttpServletRequestBuilderTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/request/MockMultipartHttpServletRequestBuilderTests.java index b1414d2bf4c..22263b90959 100644 --- a/spring-test/src/test/java/org/springframework/test/web/servlet/request/MockMultipartHttpServletRequestBuilderTests.java +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/request/MockMultipartHttpServletRequestBuilderTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -50,7 +50,7 @@ public class MockMultipartHttpServletRequestBuilderTests { assertThat(mockRequest.getParts()).extracting(Part::getName).containsExactly("name"); } - @Test // gh-26261 + @Test // gh-26261, gh-26400 void addFileWithoutFilename() throws Exception { MockPart jsonPart = new MockPart("data", "{\"node\":\"node\"}".getBytes(UTF_8)); jsonPart.getHeaders().setContentType(MediaType.APPLICATION_JSON); @@ -62,7 +62,8 @@ public class MockMultipartHttpServletRequestBuilderTests { .buildRequest(new MockServletContext()); assertThat(mockRequest.getFileMap()).containsOnlyKeys("file"); - assertThat(mockRequest.getParameterMap()).isEmpty(); + assertThat(mockRequest.getParameterMap()).hasSize(1); + assertThat(mockRequest.getParameter("data")).isEqualTo("{\"node\":\"node\"}"); assertThat(mockRequest.getParts()).extracting(Part::getName).containsExactly("data"); }