Attempt fallback Part resolution even without StandardMultipartHttpServletRequest

Closes gh-25829
This commit is contained in:
Juergen Hoeller 2020-10-12 18:25:55 +02:00
parent ec9de943ee
commit 69c330d905
2 changed files with 60 additions and 26 deletions

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2017 the original author or authors.
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -22,11 +22,13 @@ import java.io.InputStream;
import java.nio.charset.Charset;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.Part;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.lang.Nullable;
import org.springframework.web.multipart.MultipartException;
import org.springframework.web.multipart.MultipartFile;
import org.springframework.web.multipart.MultipartHttpServletRequest;
@ -46,58 +48,75 @@ public class RequestPartServletServerHttpRequest extends ServletServerHttpReques
private final MultipartHttpServletRequest multipartRequest;
private final String partName;
private final String requestPartName;
private final HttpHeaders headers;
private final HttpHeaders multipartHeaders;
/**
* Create a new {@code RequestPartServletServerHttpRequest} instance.
* @param request the current servlet request
* @param partName the name of the part to adapt to the {@link ServerHttpRequest} contract
* @param requestPartName the name of the part to adapt to the {@link ServerHttpRequest} contract
* @throws MissingServletRequestPartException if the request part cannot be found
* @throws MultipartException if MultipartHttpServletRequest cannot be initialized
*/
public RequestPartServletServerHttpRequest(HttpServletRequest request, String partName)
public RequestPartServletServerHttpRequest(HttpServletRequest request, String requestPartName)
throws MissingServletRequestPartException {
super(request);
this.multipartRequest = MultipartResolutionDelegate.asMultipartHttpServletRequest(request);
this.partName = partName;
this.requestPartName = requestPartName;
HttpHeaders headers = this.multipartRequest.getMultipartHeaders(this.partName);
if (headers == null) {
throw new MissingServletRequestPartException(partName);
HttpHeaders multipartHeaders = this.multipartRequest.getMultipartHeaders(this.requestPartName);
if (multipartHeaders == null) {
throw new MissingServletRequestPartException(requestPartName);
}
this.headers = headers;
this.multipartHeaders = multipartHeaders;
}
@Override
public HttpHeaders getHeaders() {
return this.headers;
return this.multipartHeaders;
}
@Override
public InputStream getBody() throws IOException {
// Prefer Servlet Part resolution to cover file as well as parameter streams
if (this.multipartRequest instanceof StandardMultipartHttpServletRequest) {
try {
return this.multipartRequest.getPart(this.partName).getInputStream();
}
catch (Exception ex) {
throw new MultipartException("Could not parse multipart servlet request", ex);
Part part = retrieveServletPart();
if (part != null) {
return part.getInputStream();
}
}
else {
MultipartFile file = this.multipartRequest.getFile(this.partName);
if (file != null) {
return file.getInputStream();
}
else {
String paramValue = this.multipartRequest.getParameter(this.partName);
return new ByteArrayInputStream(paramValue.getBytes(determineCharset()));
}
// Spring-style distinction between MultipartFile and String parameters
MultipartFile file = this.multipartRequest.getFile(this.requestPartName);
if (file != null) {
return file.getInputStream();
}
String paramValue = this.multipartRequest.getParameter(this.requestPartName);
if (paramValue != null) {
return new ByteArrayInputStream(paramValue.getBytes(determineCharset()));
}
// Fallback: Servlet Part resolution even if not indicated
Part part = retrieveServletPart();
if (part != null) {
return part.getInputStream();
}
throw new IllegalStateException("No body available for request part '" + this.requestPartName + "'");
}
@Nullable
private Part retrieveServletPart() {
try {
return this.multipartRequest.getPart(this.requestPartName);
}
catch (Exception ex) {
throw new MultipartException("Failed to retrieve request part '" + this.requestPartName + "'", ex);
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -32,11 +32,13 @@ import org.springframework.util.FileCopyUtils;
import org.springframework.web.multipart.MultipartFile;
import org.springframework.web.testfixture.servlet.MockMultipartFile;
import org.springframework.web.testfixture.servlet.MockMultipartHttpServletRequest;
import org.springframework.web.testfixture.servlet.MockPart;
import static org.assertj.core.api.Assertions.assertThat;
/**
* @author Rossen Stoyanchev
* @author Juergen Hoeller
*/
public class RequestPartServletServerHttpRequestTests {
@ -137,4 +139,17 @@ public class RequestPartServletServerHttpRequestTests {
assertThat(result).isEqualTo(bytes);
}
@Test
public void getBodyViaRequestPart() throws Exception {
byte[] bytes = "content".getBytes("UTF-8");
MockPart mockPart = new MockPart("part", bytes);
mockPart.getHeaders().setContentType(MediaType.APPLICATION_JSON);
mockRequest.addPart(mockPart);
this.mockRequest.addPart(mockPart);
ServerHttpRequest request = new RequestPartServletServerHttpRequest(this.mockRequest, "part");
byte[] result = FileCopyUtils.copyToByteArray(request.getBody());
assertThat(result).isEqualTo(bytes);
}
}