diff --git a/spring-web/src/main/java/org/springframework/web/multipart/support/RequestPartServletServerHttpRequest.java b/spring-web/src/main/java/org/springframework/web/multipart/support/RequestPartServletServerHttpRequest.java index 901c7a6328a..8e998b91cd9 100644 --- a/spring-web/src/main/java/org/springframework/web/multipart/support/RequestPartServletServerHttpRequest.java +++ b/spring-web/src/main/java/org/springframework/web/multipart/support/RequestPartServletServerHttpRequest.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2017 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,11 +22,13 @@ import java.io.InputStream; import java.nio.charset.Charset; import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.Part; import org.springframework.http.HttpHeaders; import org.springframework.http.MediaType; import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServletServerHttpRequest; +import org.springframework.lang.Nullable; import org.springframework.web.multipart.MultipartException; import org.springframework.web.multipart.MultipartFile; import org.springframework.web.multipart.MultipartHttpServletRequest; @@ -46,58 +48,75 @@ public class RequestPartServletServerHttpRequest extends ServletServerHttpReques private final MultipartHttpServletRequest multipartRequest; - private final String partName; + private final String requestPartName; - private final HttpHeaders headers; + private final HttpHeaders multipartHeaders; /** * Create a new {@code RequestPartServletServerHttpRequest} instance. * @param request the current servlet request - * @param partName the name of the part to adapt to the {@link ServerHttpRequest} contract + * @param requestPartName the name of the part to adapt to the {@link ServerHttpRequest} contract * @throws MissingServletRequestPartException if the request part cannot be found * @throws MultipartException if MultipartHttpServletRequest cannot be initialized */ - public RequestPartServletServerHttpRequest(HttpServletRequest request, String partName) + public RequestPartServletServerHttpRequest(HttpServletRequest request, String requestPartName) throws MissingServletRequestPartException { super(request); this.multipartRequest = MultipartResolutionDelegate.asMultipartHttpServletRequest(request); - this.partName = partName; + this.requestPartName = requestPartName; - HttpHeaders headers = this.multipartRequest.getMultipartHeaders(this.partName); - if (headers == null) { - throw new MissingServletRequestPartException(partName); + HttpHeaders multipartHeaders = this.multipartRequest.getMultipartHeaders(this.requestPartName); + if (multipartHeaders == null) { + throw new MissingServletRequestPartException(requestPartName); } - this.headers = headers; + this.multipartHeaders = multipartHeaders; } @Override public HttpHeaders getHeaders() { - return this.headers; + return this.multipartHeaders; } @Override public InputStream getBody() throws IOException { + // Prefer Servlet Part resolution to cover file as well as parameter streams if (this.multipartRequest instanceof StandardMultipartHttpServletRequest) { - try { - return this.multipartRequest.getPart(this.partName).getInputStream(); - } - catch (Exception ex) { - throw new MultipartException("Could not parse multipart servlet request", ex); + Part part = retrieveServletPart(); + if (part != null) { + return part.getInputStream(); } } - else { - MultipartFile file = this.multipartRequest.getFile(this.partName); - if (file != null) { - return file.getInputStream(); - } - else { - String paramValue = this.multipartRequest.getParameter(this.partName); - return new ByteArrayInputStream(paramValue.getBytes(determineCharset())); - } + + // Spring-style distinction between MultipartFile and String parameters + MultipartFile file = this.multipartRequest.getFile(this.requestPartName); + if (file != null) { + return file.getInputStream(); + } + String paramValue = this.multipartRequest.getParameter(this.requestPartName); + if (paramValue != null) { + return new ByteArrayInputStream(paramValue.getBytes(determineCharset())); + } + + // Fallback: Servlet Part resolution even if not indicated + Part part = retrieveServletPart(); + if (part != null) { + return part.getInputStream(); + } + + throw new IllegalStateException("No body available for request part '" + this.requestPartName + "'"); + } + + @Nullable + private Part retrieveServletPart() { + try { + return this.multipartRequest.getPart(this.requestPartName); + } + catch (Exception ex) { + throw new MultipartException("Failed to retrieve request part '" + this.requestPartName + "'", ex); } } diff --git a/spring-web/src/test/java/org/springframework/web/multipart/support/RequestPartServletServerHttpRequestTests.java b/spring-web/src/test/java/org/springframework/web/multipart/support/RequestPartServletServerHttpRequestTests.java index 9c47afeb679..05076d1e011 100644 --- a/spring-web/src/test/java/org/springframework/web/multipart/support/RequestPartServletServerHttpRequestTests.java +++ b/spring-web/src/test/java/org/springframework/web/multipart/support/RequestPartServletServerHttpRequestTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -32,11 +32,13 @@ import org.springframework.util.FileCopyUtils; import org.springframework.web.multipart.MultipartFile; import org.springframework.web.testfixture.servlet.MockMultipartFile; import org.springframework.web.testfixture.servlet.MockMultipartHttpServletRequest; +import org.springframework.web.testfixture.servlet.MockPart; import static org.assertj.core.api.Assertions.assertThat; /** * @author Rossen Stoyanchev + * @author Juergen Hoeller */ public class RequestPartServletServerHttpRequestTests { @@ -137,4 +139,17 @@ public class RequestPartServletServerHttpRequestTests { assertThat(result).isEqualTo(bytes); } + @Test + public void getBodyViaRequestPart() throws Exception { + byte[] bytes = "content".getBytes("UTF-8"); + MockPart mockPart = new MockPart("part", bytes); + mockPart.getHeaders().setContentType(MediaType.APPLICATION_JSON); + mockRequest.addPart(mockPart); + this.mockRequest.addPart(mockPart); + ServerHttpRequest request = new RequestPartServletServerHttpRequest(this.mockRequest, "part"); + + byte[] result = FileCopyUtils.copyToByteArray(request.getBody()); + assertThat(result).isEqualTo(bytes); + } + }