Attempt fallback Part resolution even without StandardMultipartHttpServletRequest
Closes gh-25829
This commit is contained in:
		
							parent
							
								
									ec9de943ee
								
							
						
					
					
						commit
						69c330d905
					
				|  | @ -1,5 +1,5 @@ | |||
| /* | ||||
|  * Copyright 2002-2017 the original author or authors. | ||||
|  * Copyright 2002-2020 the original author or authors. | ||||
|  * | ||||
|  * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
|  * you may not use this file except in compliance with the License. | ||||
|  | @ -22,11 +22,13 @@ import java.io.InputStream; | |||
| import java.nio.charset.Charset; | ||||
| 
 | ||||
| import javax.servlet.http.HttpServletRequest; | ||||
| import javax.servlet.http.Part; | ||||
| 
 | ||||
| import org.springframework.http.HttpHeaders; | ||||
| import org.springframework.http.MediaType; | ||||
| import org.springframework.http.server.ServerHttpRequest; | ||||
| import org.springframework.http.server.ServletServerHttpRequest; | ||||
| import org.springframework.lang.Nullable; | ||||
| import org.springframework.web.multipart.MultipartException; | ||||
| import org.springframework.web.multipart.MultipartFile; | ||||
| import org.springframework.web.multipart.MultipartHttpServletRequest; | ||||
|  | @ -46,58 +48,75 @@ public class RequestPartServletServerHttpRequest extends ServletServerHttpReques | |||
| 
 | ||||
| 	private final MultipartHttpServletRequest multipartRequest; | ||||
| 
 | ||||
| 	private final String partName; | ||||
| 	private final String requestPartName; | ||||
| 
 | ||||
| 	private final HttpHeaders headers; | ||||
| 	private final HttpHeaders multipartHeaders; | ||||
| 
 | ||||
| 
 | ||||
| 	/** | ||||
| 	 * Create a new {@code RequestPartServletServerHttpRequest} instance. | ||||
| 	 * @param request the current servlet request | ||||
| 	 * @param partName the name of the part to adapt to the {@link ServerHttpRequest} contract | ||||
| 	 * @param requestPartName the name of the part to adapt to the {@link ServerHttpRequest} contract | ||||
| 	 * @throws MissingServletRequestPartException if the request part cannot be found | ||||
| 	 * @throws MultipartException if MultipartHttpServletRequest cannot be initialized | ||||
| 	 */ | ||||
| 	public RequestPartServletServerHttpRequest(HttpServletRequest request, String partName) | ||||
| 	public RequestPartServletServerHttpRequest(HttpServletRequest request, String requestPartName) | ||||
| 			throws MissingServletRequestPartException { | ||||
| 
 | ||||
| 		super(request); | ||||
| 
 | ||||
| 		this.multipartRequest = MultipartResolutionDelegate.asMultipartHttpServletRequest(request); | ||||
| 		this.partName = partName; | ||||
| 		this.requestPartName = requestPartName; | ||||
| 
 | ||||
| 		HttpHeaders headers = this.multipartRequest.getMultipartHeaders(this.partName); | ||||
| 		if (headers == null) { | ||||
| 			throw new MissingServletRequestPartException(partName); | ||||
| 		HttpHeaders multipartHeaders = this.multipartRequest.getMultipartHeaders(this.requestPartName); | ||||
| 		if (multipartHeaders == null) { | ||||
| 			throw new MissingServletRequestPartException(requestPartName); | ||||
| 		} | ||||
| 		this.headers = headers; | ||||
| 		this.multipartHeaders = multipartHeaders; | ||||
| 	} | ||||
| 
 | ||||
| 
 | ||||
| 	@Override | ||||
| 	public HttpHeaders getHeaders() { | ||||
| 		return this.headers; | ||||
| 		return this.multipartHeaders; | ||||
| 	} | ||||
| 
 | ||||
| 	@Override | ||||
| 	public InputStream getBody() throws IOException { | ||||
| 		// Prefer Servlet Part resolution to cover file as well as parameter streams | ||||
| 		if (this.multipartRequest instanceof StandardMultipartHttpServletRequest) { | ||||
| 			try { | ||||
| 				return this.multipartRequest.getPart(this.partName).getInputStream(); | ||||
| 			} | ||||
| 			catch (Exception ex) { | ||||
| 				throw new MultipartException("Could not parse multipart servlet request", ex); | ||||
| 			Part part = retrieveServletPart(); | ||||
| 			if (part != null) { | ||||
| 				return part.getInputStream(); | ||||
| 			} | ||||
| 		} | ||||
| 		else { | ||||
| 			MultipartFile file = this.multipartRequest.getFile(this.partName); | ||||
| 			if (file != null) { | ||||
| 				return file.getInputStream(); | ||||
| 			} | ||||
| 			else { | ||||
| 				String paramValue = this.multipartRequest.getParameter(this.partName); | ||||
| 				return new ByteArrayInputStream(paramValue.getBytes(determineCharset())); | ||||
| 			} | ||||
| 
 | ||||
| 		// Spring-style distinction between MultipartFile and String parameters | ||||
| 		MultipartFile file = this.multipartRequest.getFile(this.requestPartName); | ||||
| 		if (file != null) { | ||||
| 			return file.getInputStream(); | ||||
| 		} | ||||
| 		String paramValue = this.multipartRequest.getParameter(this.requestPartName); | ||||
| 		if (paramValue != null) { | ||||
| 			return new ByteArrayInputStream(paramValue.getBytes(determineCharset())); | ||||
| 		} | ||||
| 
 | ||||
| 		// Fallback: Servlet Part resolution even if not indicated | ||||
| 		Part part = retrieveServletPart(); | ||||
| 		if (part != null) { | ||||
| 			return part.getInputStream(); | ||||
| 		} | ||||
| 
 | ||||
| 		throw new IllegalStateException("No body available for request part '" + this.requestPartName + "'"); | ||||
| 	} | ||||
| 
 | ||||
| 	@Nullable | ||||
| 	private Part retrieveServletPart() { | ||||
| 		try { | ||||
| 			return this.multipartRequest.getPart(this.requestPartName); | ||||
| 		} | ||||
| 		catch (Exception ex) { | ||||
| 			throw new MultipartException("Failed to retrieve request part '" + this.requestPartName + "'", ex); | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
|  |  | |||
|  | @ -1,5 +1,5 @@ | |||
| /* | ||||
|  * Copyright 2002-2019 the original author or authors. | ||||
|  * Copyright 2002-2020 the original author or authors. | ||||
|  * | ||||
|  * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
|  * you may not use this file except in compliance with the License. | ||||
|  | @ -32,11 +32,13 @@ import org.springframework.util.FileCopyUtils; | |||
| import org.springframework.web.multipart.MultipartFile; | ||||
| import org.springframework.web.testfixture.servlet.MockMultipartFile; | ||||
| import org.springframework.web.testfixture.servlet.MockMultipartHttpServletRequest; | ||||
| import org.springframework.web.testfixture.servlet.MockPart; | ||||
| 
 | ||||
| import static org.assertj.core.api.Assertions.assertThat; | ||||
| 
 | ||||
| /** | ||||
|  * @author Rossen Stoyanchev | ||||
|  * @author Juergen Hoeller | ||||
|  */ | ||||
| public class RequestPartServletServerHttpRequestTests { | ||||
| 
 | ||||
|  | @ -137,4 +139,17 @@ public class RequestPartServletServerHttpRequestTests { | |||
| 		assertThat(result).isEqualTo(bytes); | ||||
| 	} | ||||
| 
 | ||||
| 	@Test | ||||
| 	public void getBodyViaRequestPart() throws Exception { | ||||
| 		byte[] bytes = "content".getBytes("UTF-8"); | ||||
| 		MockPart mockPart = new MockPart("part", bytes); | ||||
| 		mockPart.getHeaders().setContentType(MediaType.APPLICATION_JSON); | ||||
| 		mockRequest.addPart(mockPart); | ||||
| 		this.mockRequest.addPart(mockPart); | ||||
| 		ServerHttpRequest request = new RequestPartServletServerHttpRequest(this.mockRequest, "part"); | ||||
| 
 | ||||
| 		byte[] result = FileCopyUtils.copyToByteArray(request.getBody()); | ||||
| 		assertThat(result).isEqualTo(bytes); | ||||
| 	} | ||||
| 
 | ||||
| } | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue