Attempt fallback Part resolution even without StandardMultipartHttpServletRequest

Closes gh-25829
This commit is contained in:
Juergen Hoeller 2020-10-12 18:25:55 +02:00
parent ec9de943ee
commit 69c330d905
2 changed files with 60 additions and 26 deletions

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2017 the original author or authors. * Copyright 2002-2020 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -22,11 +22,13 @@ import java.io.InputStream;
import java.nio.charset.Charset; import java.nio.charset.Charset;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.Part;
import org.springframework.http.HttpHeaders; import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType; import org.springframework.http.MediaType;
import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServletServerHttpRequest; import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.lang.Nullable;
import org.springframework.web.multipart.MultipartException; import org.springframework.web.multipart.MultipartException;
import org.springframework.web.multipart.MultipartFile; import org.springframework.web.multipart.MultipartFile;
import org.springframework.web.multipart.MultipartHttpServletRequest; import org.springframework.web.multipart.MultipartHttpServletRequest;
@ -46,58 +48,75 @@ public class RequestPartServletServerHttpRequest extends ServletServerHttpReques
private final MultipartHttpServletRequest multipartRequest; private final MultipartHttpServletRequest multipartRequest;
private final String partName; private final String requestPartName;
private final HttpHeaders headers; private final HttpHeaders multipartHeaders;
/** /**
* Create a new {@code RequestPartServletServerHttpRequest} instance. * Create a new {@code RequestPartServletServerHttpRequest} instance.
* @param request the current servlet request * @param request the current servlet request
* @param partName the name of the part to adapt to the {@link ServerHttpRequest} contract * @param requestPartName the name of the part to adapt to the {@link ServerHttpRequest} contract
* @throws MissingServletRequestPartException if the request part cannot be found * @throws MissingServletRequestPartException if the request part cannot be found
* @throws MultipartException if MultipartHttpServletRequest cannot be initialized * @throws MultipartException if MultipartHttpServletRequest cannot be initialized
*/ */
public RequestPartServletServerHttpRequest(HttpServletRequest request, String partName) public RequestPartServletServerHttpRequest(HttpServletRequest request, String requestPartName)
throws MissingServletRequestPartException { throws MissingServletRequestPartException {
super(request); super(request);
this.multipartRequest = MultipartResolutionDelegate.asMultipartHttpServletRequest(request); this.multipartRequest = MultipartResolutionDelegate.asMultipartHttpServletRequest(request);
this.partName = partName; this.requestPartName = requestPartName;
HttpHeaders headers = this.multipartRequest.getMultipartHeaders(this.partName); HttpHeaders multipartHeaders = this.multipartRequest.getMultipartHeaders(this.requestPartName);
if (headers == null) { if (multipartHeaders == null) {
throw new MissingServletRequestPartException(partName); throw new MissingServletRequestPartException(requestPartName);
} }
this.headers = headers; this.multipartHeaders = multipartHeaders;
} }
@Override @Override
public HttpHeaders getHeaders() { public HttpHeaders getHeaders() {
return this.headers; return this.multipartHeaders;
} }
@Override @Override
public InputStream getBody() throws IOException { public InputStream getBody() throws IOException {
// Prefer Servlet Part resolution to cover file as well as parameter streams
if (this.multipartRequest instanceof StandardMultipartHttpServletRequest) { if (this.multipartRequest instanceof StandardMultipartHttpServletRequest) {
try { Part part = retrieveServletPart();
return this.multipartRequest.getPart(this.partName).getInputStream(); if (part != null) {
} return part.getInputStream();
catch (Exception ex) {
throw new MultipartException("Could not parse multipart servlet request", ex);
} }
} }
else {
MultipartFile file = this.multipartRequest.getFile(this.partName); // Spring-style distinction between MultipartFile and String parameters
if (file != null) { MultipartFile file = this.multipartRequest.getFile(this.requestPartName);
return file.getInputStream(); if (file != null) {
} return file.getInputStream();
else { }
String paramValue = this.multipartRequest.getParameter(this.partName); String paramValue = this.multipartRequest.getParameter(this.requestPartName);
return new ByteArrayInputStream(paramValue.getBytes(determineCharset())); if (paramValue != null) {
} return new ByteArrayInputStream(paramValue.getBytes(determineCharset()));
}
// Fallback: Servlet Part resolution even if not indicated
Part part = retrieveServletPart();
if (part != null) {
return part.getInputStream();
}
throw new IllegalStateException("No body available for request part '" + this.requestPartName + "'");
}
@Nullable
private Part retrieveServletPart() {
try {
return this.multipartRequest.getPart(this.requestPartName);
}
catch (Exception ex) {
throw new MultipartException("Failed to retrieve request part '" + this.requestPartName + "'", ex);
} }
} }

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2019 the original author or authors. * Copyright 2002-2020 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -32,11 +32,13 @@ import org.springframework.util.FileCopyUtils;
import org.springframework.web.multipart.MultipartFile; import org.springframework.web.multipart.MultipartFile;
import org.springframework.web.testfixture.servlet.MockMultipartFile; import org.springframework.web.testfixture.servlet.MockMultipartFile;
import org.springframework.web.testfixture.servlet.MockMultipartHttpServletRequest; import org.springframework.web.testfixture.servlet.MockMultipartHttpServletRequest;
import org.springframework.web.testfixture.servlet.MockPart;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
/** /**
* @author Rossen Stoyanchev * @author Rossen Stoyanchev
* @author Juergen Hoeller
*/ */
public class RequestPartServletServerHttpRequestTests { public class RequestPartServletServerHttpRequestTests {
@ -137,4 +139,17 @@ public class RequestPartServletServerHttpRequestTests {
assertThat(result).isEqualTo(bytes); assertThat(result).isEqualTo(bytes);
} }
@Test
public void getBodyViaRequestPart() throws Exception {
byte[] bytes = "content".getBytes("UTF-8");
MockPart mockPart = new MockPart("part", bytes);
mockPart.getHeaders().setContentType(MediaType.APPLICATION_JSON);
mockRequest.addPart(mockPart);
this.mockRequest.addPart(mockPart);
ServerHttpRequest request = new RequestPartServletServerHttpRequest(this.mockRequest, "part");
byte[] result = FileCopyUtils.copyToByteArray(request.getBody());
assertThat(result).isEqualTo(bytes);
}
} }