Fix Netty4ClientHttpRequestFactory POST/PUT requests

This commit ensures that POST/PUT requests sent by the Netty client have
a Content-Length header set.

Integration tests have been refactored to use mockwebserver instead of
Jetty and have been parameterized to run on all available supported
clients.

Issue: SPR-14860
This commit is contained in:
Brian Clozel 2016-12-06 23:18:04 +01:00
parent 2c2de82ffb
commit ec8391a7fb
9 changed files with 406 additions and 602 deletions

View File

@ -48,6 +48,7 @@ import org.springframework.util.concurrent.SettableListenableFuture;
*
* @author Arjen Poutsma
* @author Rossen Stoyanchev
* @author Brian Clozel
* @since 4.1.2
*/
class Netty4ClientHttpRequest extends AbstractAsyncClientHttpRequest implements ClientHttpRequest {
@ -130,10 +131,15 @@ class Netty4ClientHttpRequest extends AbstractAsyncClientHttpRequest implements
io.netty.handler.codec.http.HttpMethod nettyMethod =
io.netty.handler.codec.http.HttpMethod.valueOf(this.method.name());
String authority = this.uri.getRawAuthority();
String path = this.uri.toString().substring(this.uri.toString().indexOf(authority) + authority.length());
FullHttpRequest nettyRequest = new DefaultFullHttpRequest(
HttpVersion.HTTP_1_1, nettyMethod, this.uri.toString(), this.body.buffer());
HttpVersion.HTTP_1_1, nettyMethod, path, this.body.buffer());
nettyRequest.headers().set(HttpHeaders.HOST, this.uri.getHost());
if (this.body.buffer().readableBytes() > 0) {
nettyRequest.headers().set(HttpHeaders.CONTENT_LENGTH, this.body.buffer().readableBytes());
}
nettyRequest.headers().set(HttpHeaders.CONNECTION, "close");
for (Map.Entry<String, List<String>> entry : headers.entrySet()) {

View File

@ -39,7 +39,7 @@ import org.springframework.util.StreamUtils;
import org.springframework.util.concurrent.ListenableFuture;
import org.springframework.util.concurrent.ListenableFutureCallback;
public abstract class AbstractAsyncHttpRequestFactoryTestCase extends AbstractJettyServerTestCase {
public abstract class AbstractAsyncHttpRequestFactoryTestCase extends AbstractMockWebServerTestCase {
protected AsyncClientHttpRequestFactory factory;

View File

@ -39,7 +39,7 @@ import static org.junit.Assert.*;
/**
* @author Arjen Poutsma
*/
public abstract class AbstractHttpRequestFactoryTestCase extends AbstractJettyServerTestCase {
public abstract class AbstractHttpRequestFactoryTestCase extends AbstractMockWebServerTestCase {
protected ClientHttpRequestFactory factory;
@ -127,6 +127,8 @@ public abstract class AbstractHttpRequestFactoryTestCase extends AbstractJettySe
@Override
public void writeTo(OutputStream outputStream) throws IOException {
StreamUtils.copy(body, outputStream);
outputStream.flush();
outputStream.close();
}
});
}
@ -135,12 +137,7 @@ public abstract class AbstractHttpRequestFactoryTestCase extends AbstractJettySe
}
ClientHttpResponse response = request.execute();
try {
FileCopyUtils.copy(body, request.getBody());
}
finally {
response.close();
}
FileCopyUtils.copy(body, request.getBody());
}
@Test(expected = UnsupportedOperationException.class)

View File

@ -1,202 +0,0 @@
/*
* Copyright 2002-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.http.client;
import java.io.IOException;
import java.io.InputStream;
import java.util.Enumeration;
import java.util.Map;
import javax.servlet.GenericServlet;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.eclipse.jetty.server.Connector;
import org.eclipse.jetty.server.NetworkConnector;
import org.eclipse.jetty.server.Server;
import org.eclipse.jetty.servlet.ServletContextHandler;
import org.eclipse.jetty.servlet.ServletHolder;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.springframework.util.StreamUtils;
import static org.junit.Assert.*;
/**
* @author Arjen Poutsma
* @author Sam Brannen
*/
public abstract class AbstractJettyServerTestCase {
private static Server jettyServer;
protected static String baseUrl;
@BeforeClass
public static void startJettyServer() throws Exception {
// Let server pick its own random, available port.
jettyServer = new Server(0);
ServletContextHandler handler = new ServletContextHandler();
handler.setContextPath("/");
handler.addServlet(new ServletHolder(new EchoServlet()), "/echo");
handler.addServlet(new ServletHolder(new ParameterServlet()), "/params");
handler.addServlet(new ServletHolder(new StatusServlet(200)), "/status/ok");
handler.addServlet(new ServletHolder(new StatusServlet(404)), "/status/notfound");
handler.addServlet(new ServletHolder(new MethodServlet("DELETE")), "/methods/delete");
handler.addServlet(new ServletHolder(new MethodServlet("GET")), "/methods/get");
handler.addServlet(new ServletHolder(new MethodServlet("HEAD")), "/methods/head");
handler.addServlet(new ServletHolder(new MethodServlet("OPTIONS")), "/methods/options");
handler.addServlet(new ServletHolder(new PostServlet()), "/methods/post");
handler.addServlet(new ServletHolder(new MethodServlet("PUT")), "/methods/put");
handler.addServlet(new ServletHolder(new MethodServlet("PATCH")), "/methods/patch");
jettyServer.setHandler(handler);
jettyServer.start();
Connector[] connectors = jettyServer.getConnectors();
NetworkConnector connector = (NetworkConnector) connectors[0];
baseUrl = "http://localhost:" + connector.getLocalPort();
}
@AfterClass
public static void stopJettyServer() throws Exception {
if (jettyServer != null) {
jettyServer.stop();
}
}
/**
* Servlet that sets a given status code.
*/
@SuppressWarnings("serial")
private static class StatusServlet extends GenericServlet {
private final int sc;
private StatusServlet(int sc) {
this.sc = sc;
}
@Override
public void service(ServletRequest request, ServletResponse response) throws
ServletException, IOException {
((HttpServletResponse) response).setStatus(sc);
}
}
@SuppressWarnings("serial")
private static class MethodServlet extends GenericServlet {
private final String method;
private MethodServlet(String method) {
this.method = method;
}
@Override
public void service(ServletRequest req, ServletResponse res) throws ServletException, IOException {
HttpServletRequest httpReq = (HttpServletRequest) req;
assertEquals("Invalid HTTP method", method, httpReq.getMethod());
res.setContentLength(0);
((HttpServletResponse) res).setStatus(200);
}
}
@SuppressWarnings("serial")
private static class PostServlet extends MethodServlet {
private PostServlet() {
super("POST");
}
@Override
public void service(ServletRequest req, ServletResponse res) throws ServletException, IOException {
super.service(req, res);
long contentLength = req.getContentLength();
if (contentLength != -1) {
InputStream in = req.getInputStream();
long byteCount = 0;
byte[] buffer = new byte[4096];
int bytesRead;
while ((bytesRead = in.read(buffer)) != -1) {
byteCount += bytesRead;
}
assertEquals("Invalid content-length", contentLength, byteCount);
}
}
}
@SuppressWarnings("serial")
private static class EchoServlet extends HttpServlet {
@Override
protected void service(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
echo(req, resp);
}
private void echo(HttpServletRequest request, HttpServletResponse response) throws IOException {
response.setStatus(HttpServletResponse.SC_OK);
response.setContentType(request.getContentType());
response.setContentLength(request.getContentLength());
for (Enumeration<String> e1 = request.getHeaderNames(); e1.hasMoreElements();) {
String headerName = e1.nextElement();
for (Enumeration<String> e2 = request.getHeaders(headerName); e2.hasMoreElements();) {
String headerValue = e2.nextElement();
response.addHeader(headerName, headerValue);
}
}
StreamUtils.copy(request.getInputStream(), response.getOutputStream());
}
}
@SuppressWarnings("serial")
private static class ParameterServlet extends HttpServlet {
@Override
protected void service(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
Map<String, String[]> parameters = req.getParameterMap();
assertEquals(2, parameters.size());
String[] values = parameters.get("param1");
assertEquals(1, values.length);
assertEquals("value", values[0]);
values = parameters.get("param2");
assertEquals(2, values.length);
assertEquals("value1", values[0]);
assertEquals("value2", values[1]);
resp.setStatus(200);
resp.setContentLength(0);
}
}
}

View File

@ -0,0 +1,95 @@
package org.springframework.http.client;
import java.util.Collections;
import okhttp3.mockwebserver.Dispatcher;
import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer;
import okhttp3.mockwebserver.RecordedRequest;
import org.hamcrest.Matchers;
import org.junit.After;
import org.junit.Before;
import org.springframework.http.MediaType;
import org.springframework.util.StringUtils;
import static org.hamcrest.MatcherAssert.assertThat;
/**
* @author Brian Clozel
*/
public class AbstractMockWebServerTestCase {
private MockWebServer server;
protected int port;
protected String baseUrl;
protected static final MediaType textContentType =
new MediaType("text", "plain", Collections.singletonMap("charset", "UTF-8"));
@Before
public void setUp() throws Exception {
this.server = new MockWebServer();
this.server.setDispatcher(new TestDispatcher());
this.server.start();
this.port = this.server.getPort();
this.baseUrl = "http://localhost:" + this.port;
}
@After
public void tearDown() throws Exception {
this.server.shutdown();
}
protected class TestDispatcher extends Dispatcher {
@Override
public MockResponse dispatch(RecordedRequest request) throws InterruptedException {
try {
if (request.getPath().equals("/echo")) {
MockResponse response = new MockResponse()
.setHeaders(request.getHeaders())
.setHeader("Content-Length", request.getBody().size())
.setResponseCode(200)
.setBody(request.getBody());
request.getBody().flush();
return response;
}
else if(request.getPath().equals("/status/ok")) {
return new MockResponse();
}
else if(request.getPath().equals("/status/notfound")) {
return new MockResponse().setResponseCode(404);
}
else if(request.getPath().startsWith("/params")) {
assertThat(request.getPath(), Matchers.containsString("param1=value"));
assertThat(request.getPath(), Matchers.containsString("param2=value1&param2=value2"));
return new MockResponse();
}
else if(request.getPath().equals("/methods/post")) {
assertThat(request.getMethod(), Matchers.is("POST"));
String transferEncoding = request.getHeader("Transfer-Encoding");
if(StringUtils.hasLength(transferEncoding)) {
assertThat(transferEncoding, Matchers.is("chunked"));
}
else {
long contentLength = Long.parseLong(request.getHeader("Content-Length"));
assertThat("Invalid content-length",
request.getBody().size(), Matchers.is(contentLength));
}
return new MockResponse().setResponseCode(200);
}
else if(request.getPath().startsWith("/methods/")) {
String expectedMethod = request.getPath().replace("/methods/","").toUpperCase();
assertThat(request.getMethod(), Matchers.is(expectedMethod));
return new MockResponse();
}
return new MockResponse().setResponseCode(404);
}
catch (Throwable exc) {
return new MockResponse().setResponseCode(500).setBody(exc.toString());
}
}
}
}

View File

@ -1,371 +0,0 @@
/*
* Copyright 2002-2016 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.client;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import javax.servlet.GenericServlet;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.apache.commons.fileupload.FileItem;
import org.apache.commons.fileupload.FileItemFactory;
import org.apache.commons.fileupload.FileUploadException;
import org.apache.commons.fileupload.disk.DiskFileItemFactory;
import org.apache.commons.fileupload.servlet.ServletFileUpload;
import org.eclipse.jetty.server.Connector;
import org.eclipse.jetty.server.NetworkConnector;
import org.eclipse.jetty.server.Server;
import org.eclipse.jetty.servlet.ServletContextHandler;
import org.eclipse.jetty.servlet.ServletHolder;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.springframework.http.MediaType;
import org.springframework.util.FileCopyUtils;
import static org.junit.Assert.*;
/**
* @author Arjen Poutsma
* @author Sam Brannen
*/
public class AbstractJettyServerTestCase {
protected static final String helloWorld = "H\u00e9llo W\u00f6rld";
protected static final MediaType textContentType =
new MediaType("text", "plain", Collections.singletonMap("charset", "UTF-8"));
protected static final MediaType jsonContentType =
new MediaType("application", "json", Collections.singletonMap("charset", "UTF-8"));
private static Server jettyServer;
protected static int port;
protected static String baseUrl;
@BeforeClass
public static void startJettyServer() throws Exception {
// Let server pick its own random, available port.
jettyServer = new Server(0);
ServletContextHandler handler = new ServletContextHandler();
byte[] bytes = helloWorld.getBytes(StandardCharsets.UTF_8);
handler.addServlet(new ServletHolder(new GetServlet(bytes, textContentType)), "/get");
handler.addServlet(new ServletHolder(new GetServlet(new byte[0], textContentType)), "/get/nothing");
handler.addServlet(new ServletHolder(new GetServlet(bytes, null)), "/get/nocontenttype");
handler.addServlet(
new ServletHolder(new PostServlet(helloWorld, "/post/1", bytes, textContentType)),
"/post");
handler.addServlet(
new ServletHolder(new JsonPostServlet("/jsonpost/1", jsonContentType)),
"/jsonpost");
handler.addServlet(new ServletHolder(new StatusCodeServlet(204)), "/status/nocontent");
handler.addServlet(new ServletHolder(new StatusCodeServlet(304)), "/status/notmodified");
handler.addServlet(new ServletHolder(new ErrorServlet(404)), "/status/notfound");
handler.addServlet(new ServletHolder(new ErrorServlet(500)), "/status/server");
handler.addServlet(new ServletHolder(new UriServlet()), "/uri/*");
handler.addServlet(new ServletHolder(new MultipartServlet()), "/multipart");
handler.addServlet(new ServletHolder(new FormServlet()), "/form");
handler.addServlet(new ServletHolder(new DeleteServlet()), "/delete");
handler.addServlet(new ServletHolder(new PatchServlet(helloWorld, bytes, textContentType)),
"/patch");
handler.addServlet(
new ServletHolder(new PutServlet(helloWorld, bytes, textContentType)),
"/put");
jettyServer.setHandler(handler);
jettyServer.start();
Connector[] connectors = jettyServer.getConnectors();
NetworkConnector connector = (NetworkConnector) connectors[0];
port = connector.getLocalPort();
baseUrl = "http://localhost:" + port;
}
@AfterClass
public static void stopJettyServer() throws Exception {
if (jettyServer != null) {
jettyServer.stop();
}
}
/** Servlet that sets the given status code. */
@SuppressWarnings("serial")
private static class StatusCodeServlet extends GenericServlet {
private final int sc;
public StatusCodeServlet(int sc) {
this.sc = sc;
}
@Override
public void service(ServletRequest request, ServletResponse response) throws IOException {
((HttpServletResponse) response).setStatus(sc);
}
}
/** Servlet that returns an error message for a given status code. */
@SuppressWarnings("serial")
private static class ErrorServlet extends GenericServlet {
private final int sc;
public ErrorServlet(int sc) {
this.sc = sc;
}
@Override
public void service(ServletRequest request, ServletResponse response) throws IOException {
((HttpServletResponse) response).sendError(sc);
}
}
@SuppressWarnings("serial")
private static class GetServlet extends HttpServlet {
private final byte[] buf;
private final MediaType contentType;
public GetServlet(byte[] buf, MediaType contentType) {
this.buf = buf;
this.contentType = contentType;
}
@Override
protected void doGet(HttpServletRequest request, HttpServletResponse response) throws IOException {
if (contentType != null) {
response.setContentType(contentType.toString());
}
response.setContentLength(buf.length);
FileCopyUtils.copy(buf, response.getOutputStream());
}
}
@SuppressWarnings("serial")
private static class PostServlet extends HttpServlet {
private final String content;
private final String location;
private final byte[] buf;
private final MediaType contentType;
public PostServlet(String content, String location, byte[] buf, MediaType contentType) {
this.content = content;
this.location = location;
this.buf = buf;
this.contentType = contentType;
}
@Override
protected void doPost(HttpServletRequest request, HttpServletResponse response) throws IOException {
assertTrue("Invalid request content-length", request.getContentLength() > 0);
assertNotNull("No content-type", request.getContentType());
String body = FileCopyUtils.copyToString(request.getReader());
assertEquals("Invalid request body", content, body);
response.setStatus(HttpServletResponse.SC_CREATED);
response.setHeader("Location", baseUrl + location);
response.setContentLength(buf.length);
response.setContentType(contentType.toString());
FileCopyUtils.copy(buf, response.getOutputStream());
}
}
@SuppressWarnings("serial")
private static class JsonPostServlet extends HttpServlet {
private final String location;
private final MediaType contentType;
public JsonPostServlet(String location, MediaType contentType) {
this.location = location;
this.contentType = contentType;
}
@Override
protected void doPost(HttpServletRequest request, HttpServletResponse response) throws IOException {
assertTrue("Invalid request content-length", request.getContentLength() > 0);
assertNotNull("No content-type", request.getContentType());
String body = FileCopyUtils.copyToString(request.getReader());
response.setStatus(HttpServletResponse.SC_CREATED);
response.setHeader("Location", baseUrl +location);
response.setContentType(contentType.toString());
byte[] bytes = body.getBytes("utf-8");
response.setContentLength(bytes.length);;
FileCopyUtils.copy(bytes, response.getOutputStream());
}
}
@SuppressWarnings("serial")
private static class PutServlet extends HttpServlet {
private final String s;
public PutServlet(String s, byte[] buf, MediaType contentType) {
this.s = s;
}
@Override
protected void doPut(HttpServletRequest request, HttpServletResponse response) throws IOException {
assertTrue("Invalid request content-length", request.getContentLength() > 0);
assertNotNull("No content-type", request.getContentType());
String body = FileCopyUtils.copyToString(request.getReader());
assertEquals("Invalid request body", s, body);
response.setStatus(HttpServletResponse.SC_ACCEPTED);
}
}
@SuppressWarnings("serial")
private static class UriServlet extends HttpServlet {
@Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws IOException {
resp.setContentType("text/plain");
resp.setCharacterEncoding("utf-8");
resp.getWriter().write(req.getRequestURI());
}
}
@SuppressWarnings("serial")
private static class MultipartServlet extends HttpServlet {
@Override
protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
assertTrue(ServletFileUpload.isMultipartContent(req));
FileItemFactory factory = new DiskFileItemFactory();
ServletFileUpload upload = new ServletFileUpload(factory);
try {
List<FileItem> items = upload.parseRequest(req);
assertEquals(4, items.size());
FileItem item = items.get(0);
assertTrue(item.isFormField());
assertEquals("name 1", item.getFieldName());
assertEquals("value 1", item.getString());
item = items.get(1);
assertTrue(item.isFormField());
assertEquals("name 2", item.getFieldName());
assertEquals("value 2+1", item.getString());
item = items.get(2);
assertTrue(item.isFormField());
assertEquals("name 2", item.getFieldName());
assertEquals("value 2+2", item.getString());
item = items.get(3);
assertFalse(item.isFormField());
assertEquals("logo", item.getFieldName());
assertEquals("logo.jpg", item.getName());
assertEquals("image/jpeg", item.getContentType());
}
catch (FileUploadException ex) {
throw new ServletException(ex);
}
}
}
@SuppressWarnings("serial")
private static class FormServlet extends HttpServlet {
@Override
protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws IOException {
assertEquals(MediaType.APPLICATION_FORM_URLENCODED_VALUE, req.getContentType());
Map<String, String[]> parameters = req.getParameterMap();
assertEquals(2, parameters.size());
String[] values = parameters.get("name 1");
assertEquals(1, values.length);
assertEquals("value 1", values[0]);
values = parameters.get("name 2");
assertEquals(2, values.length);
assertEquals("value 2+1", values[0]);
assertEquals("value 2+2", values[1]);
}
}
@SuppressWarnings("serial")
private static class DeleteServlet extends HttpServlet {
@Override
protected void doDelete(HttpServletRequest req, HttpServletResponse resp) throws IOException {
resp.setStatus(200);
}
}
@SuppressWarnings("serial")
private static class PatchServlet extends GenericServlet {
private final String content;
private final byte[] buf;
private final MediaType contentType;
public PatchServlet(String content, byte[] buf, MediaType contentType) {
this.content = content;
this.buf = buf;
this.contentType = contentType;
}
@Override
public void service(ServletRequest req, ServletResponse res)
throws ServletException, IOException {
HttpServletRequest request = (HttpServletRequest) req;
HttpServletResponse response = (HttpServletResponse) res;
assertEquals("PATCH", request.getMethod());
assertTrue("Invalid request content-length", request.getContentLength() > 0);
assertNotNull("No content-type", request.getContentType());
String body = FileCopyUtils.copyToString(request.getReader());
assertEquals("Invalid request body", content, body);
response.setStatus(HttpServletResponse.SC_CREATED);
response.setContentLength(buf.length);
response.setContentType(contentType.toString());
FileCopyUtils.copy(buf, response.getOutputStream());
}
}
}

View File

@ -0,0 +1,254 @@
package org.springframework.web.client;
import java.io.EOFException;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import okhttp3.mockwebserver.Dispatcher;
import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer;
import okhttp3.mockwebserver.RecordedRequest;
import okio.Buffer;
import org.hamcrest.Matchers;
import org.junit.After;
import org.junit.Before;
import org.springframework.http.MediaType;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
/**
* @author Brian Clozel
*/
public class AbstractMockWebServerTestCase {
protected static final String helloWorld = "H\u00e9llo W\u00f6rld";
private MockWebServer server;
protected int port;
protected String baseUrl;
protected static final MediaType textContentType =
new MediaType("text", "plain", Collections.singletonMap("charset", "UTF-8"));
@Before
public void setUp() throws Exception {
this.server = new MockWebServer();
this.server.setDispatcher(new TestDispatcher());
this.server.start();
this.port = this.server.getPort();
this.baseUrl = "http://localhost:" + this.port;
}
@After
public void tearDown() throws Exception {
this.server.shutdown();
}
protected class TestDispatcher extends Dispatcher {
@Override
public MockResponse dispatch(RecordedRequest request) throws InterruptedException {
try {
byte[] helloWorldBytes = helloWorld.getBytes(StandardCharsets.UTF_8);
if (request.getPath().equals("/get")) {
return getRequest(request, helloWorldBytes, textContentType.toString());
}
else if (request.getPath().equals("/get/nothing")) {
return getRequest(request, new byte[0], textContentType.toString());
}
else if (request.getPath().equals("/get/nocontenttype")) {
return getRequest(request, helloWorldBytes, null);
}
else if (request.getPath().equals("/post")) {
return postRequest(request, helloWorld, "/post/1", textContentType.toString(), helloWorldBytes);
}
else if (request.getPath().equals("/jsonpost")) {
return jsonPostRequest(request, "/jsonpost/1", "application/json; charset=utf-8");
}
else if (request.getPath().equals("/status/nocontent")) {
return new MockResponse().setResponseCode(204);
}
else if (request.getPath().equals("/status/notmodified")) {
return new MockResponse().setResponseCode(304);
}
else if (request.getPath().equals("/status/notfound")) {
return new MockResponse().setResponseCode(404);
}
else if (request.getPath().equals("/status/server")) {
return new MockResponse().setResponseCode(500);
}
else if (request.getPath().contains("/uri/")) {
return new MockResponse().setBody(request.getPath()).setHeader("Content-Type", "text/plain");
}
else if (request.getPath().equals("/multipart")) {
return multipartRequest(request);
}
else if (request.getPath().equals("/form")) {
return formRequest(request);
}
else if (request.getPath().equals("/delete")) {
return new MockResponse().setResponseCode(200);
}
else if (request.getPath().equals("/patch")) {
return patchRequest(request, helloWorld, textContentType.toString(), helloWorldBytes);
}
else if (request.getPath().equals("/put")) {
return putRequest(request, helloWorld);
}
return new MockResponse().setResponseCode(404);
}
catch (Throwable exc) {
return new MockResponse().setResponseCode(500).setBody(exc.toString());
}
}
}
private MockResponse getRequest(RecordedRequest request, byte[] body, String contentType) {
if(request.getMethod().equals("OPTIONS")) {
return new MockResponse().setResponseCode(200).setHeader("Allow", "GET, OPTIONS, HEAD, TRACE");
}
Buffer buf = new Buffer();
buf.write(body);
MockResponse response = new MockResponse()
.setHeader("Content-Length", body.length)
.setBody(buf)
.setResponseCode(200);
if (contentType != null) {
response = response.setHeader("Content-Type", contentType);
}
return response;
}
private MockResponse postRequest(RecordedRequest request, String expectedRequestContent,
String location, String contentType, byte[] responseBody) {
assertTrue("Invalid request content-length",
Integer.parseInt(request.getHeader("Content-Length")) > 0);
String requestContentType = request.getHeader("Content-Type");
assertNotNull("No content-type", requestContentType);
Charset charset = StandardCharsets.ISO_8859_1;
if(requestContentType.indexOf("charset=") > -1) {
String charsetName = requestContentType.split("charset=")[1];
charset = Charset.forName(charsetName);
}
assertEquals("Invalid request body", expectedRequestContent, request.getBody().readString(charset));
Buffer buf = new Buffer();
buf.write(responseBody);
return new MockResponse()
.setHeader("Location", baseUrl + location)
.setHeader("Content-Type", contentType)
.setHeader("Content-Length", responseBody.length)
.setBody(buf)
.setResponseCode(201);
}
private MockResponse jsonPostRequest(RecordedRequest request, String location, String contentType) {
assertTrue("Invalid request content-length",
Integer.parseInt(request.getHeader("Content-Length")) > 0);
assertNotNull("No content-type", request.getHeader("Content-Type"));
return new MockResponse()
.setHeader("Location", baseUrl + location)
.setHeader("Content-Type", contentType)
.setHeader("Content-Length", request.getBody().size())
.setBody(request.getBody())
.setResponseCode(201);
}
private MockResponse multipartRequest(RecordedRequest request) {
String contentType = request.getHeader("Content-Type");
assertTrue(contentType.startsWith("multipart/form-data"));
String boundary = contentType.split("boundary=")[1];
Buffer body = request.getBody();
try {
assertPart(body, "form-data", boundary, "name 1", "text/plain", "value 1");
assertPart(body, "form-data", boundary, "name 2", "text/plain", "value 2+1");
assertPart(body, "form-data", boundary, "name 2", "text/plain", "value 2+2");
assertFilePart(body, "form-data", boundary, "logo", "logo.jpg", "image/jpeg");
}
catch (EOFException e) {
throw new RuntimeException(e);
}
return new MockResponse().setResponseCode(200);
}
private void assertPart(Buffer buffer, String disposition, String boundary, String name,
String contentType, String value) throws EOFException {
assertTrue(buffer.readUtf8Line().contains("--" + boundary));
String line = buffer.readUtf8Line();
assertTrue(line.contains("Content-Disposition: "+ disposition));
assertTrue(line.contains("name=\""+ name + "\""));
assertTrue(buffer.readUtf8Line().startsWith("Content-Type: "+contentType));
assertTrue(buffer.readUtf8Line().equals("Content-Length: " + value.length()));
assertTrue(buffer.readUtf8Line().equals(""));
assertTrue(buffer.readUtf8Line().equals(value));
}
private void assertFilePart(Buffer buffer, String disposition, String boundary, String name,
String filename, String contentType) throws EOFException {
assertTrue(buffer.readUtf8Line().contains("--" + boundary));
String line = buffer.readUtf8Line();
assertTrue(line.contains("Content-Disposition: "+ disposition));
assertTrue(line.contains("name=\""+ name + "\""));
assertTrue(line.contains("filename=\""+ filename + "\""));
assertTrue(buffer.readUtf8Line().startsWith("Content-Type: "+contentType));
assertTrue(buffer.readUtf8Line().startsWith("Content-Length: "));
assertTrue(buffer.readUtf8Line().equals(""));
assertNotNull(buffer.readUtf8Line());
}
private MockResponse formRequest(RecordedRequest request) {
assertEquals("application/x-www-form-urlencoded", request.getHeader("Content-Type"));
String body = request.getBody().readUtf8();
assertThat(body, Matchers.containsString("name+1=value+1"));
assertThat(body, Matchers.containsString("name+2=value+2%2B1"));
assertThat(body, Matchers.containsString("name+2=value+2%2B2"));
return new MockResponse().setResponseCode(200);
}
private MockResponse patchRequest(RecordedRequest request, String expectedRequestContent,
String contentType, byte[] responseBody) {
assertEquals("PATCH", request.getMethod());
assertTrue("Invalid request content-length",
Integer.parseInt(request.getHeader("Content-Length")) > 0);
String requestContentType = request.getHeader("Content-Type");
assertNotNull("No content-type", requestContentType);
Charset charset = StandardCharsets.ISO_8859_1;
if(requestContentType.indexOf("charset=") > -1) {
String charsetName = requestContentType.split("charset=")[1];
charset = Charset.forName(charsetName);
}
assertEquals("Invalid request body", expectedRequestContent, request.getBody().readString(charset));
Buffer buf = new Buffer();
buf.write(responseBody);
return new MockResponse().setResponseCode(201)
.setHeader("Content-Length", responseBody.length)
.setHeader("Content-Type", contentType)
.setBody(buf);
}
private MockResponse putRequest(RecordedRequest request, String expectedRequestContent) {
assertTrue("Invalid request content-length",
Integer.parseInt(request.getHeader("Content-Length")) > 0);
String requestContentType = request.getHeader("Content-Type");
assertNotNull("No content-type", requestContentType);
Charset charset = StandardCharsets.ISO_8859_1;
if(requestContentType.indexOf("charset=") > -1) {
String charsetName = requestContentType.split("charset=")[1];
charset = Charset.forName(charsetName);
}
assertEquals("Invalid request body", expectedRequestContent, request.getBody().readString(charset));
return new MockResponse().setResponseCode(202);
}
}

View File

@ -18,7 +18,6 @@ package org.springframework.web.client;
import java.io.IOException;
import java.net.URI;
import java.net.URISyntaxException;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.EnumSet;
@ -44,19 +43,25 @@ import org.springframework.http.client.AsyncClientHttpRequestExecution;
import org.springframework.http.client.AsyncClientHttpRequestInterceptor;
import org.springframework.http.client.ClientHttpResponse;
import org.springframework.http.client.HttpComponentsAsyncClientHttpRequestFactory;
import org.springframework.http.client.support.HttpRequestWrapper;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.util.concurrent.ListenableFuture;
import org.springframework.util.concurrent.ListenableFutureCallback;
import static org.junit.Assert.*;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
/**
* @author Arjen Poutsma
* @author Sebastien Deleuze
*/
public class AsyncRestTemplateIntegrationTests extends AbstractJettyServerTestCase {
public class AsyncRestTemplateIntegrationTests extends AbstractMockWebServerTestCase {
private final AsyncRestTemplate template = new AsyncRestTemplate(
new HttpComponentsAsyncClientHttpRequestFactory());
@ -588,7 +593,7 @@ public class AsyncRestTemplateIntegrationTests extends AbstractJettyServerTestCa
public void getAndInterceptResponse() throws Exception {
RequestInterceptor interceptor = new RequestInterceptor();
template.setInterceptors(Collections.singletonList(interceptor));
ListenableFuture<ResponseEntity<String>> future = template.getForEntity("/get", String.class);
ListenableFuture<ResponseEntity<String>> future = template.getForEntity(baseUrl + "/get", String.class);
interceptor.latch.await(5, TimeUnit.SECONDS);
assertNotNull(interceptor.response);
@ -601,7 +606,7 @@ public class AsyncRestTemplateIntegrationTests extends AbstractJettyServerTestCa
public void getAndInterceptError() throws Exception {
RequestInterceptor interceptor = new RequestInterceptor();
template.setInterceptors(Collections.singletonList(interceptor));
template.getForEntity("/status/notfound", String.class);
template.getForEntity(baseUrl + "/status/notfound", String.class);
interceptor.latch.await(5, TimeUnit.SECONDS);
assertNotNull(interceptor.response);
@ -627,18 +632,6 @@ public class AsyncRestTemplateIntegrationTests extends AbstractJettyServerTestCa
public ListenableFuture<ClientHttpResponse> intercept(HttpRequest request, byte[] body,
AsyncClientHttpRequestExecution execution) throws IOException {
request = new HttpRequestWrapper(request) {
@Override
public URI getURI() {
try {
return new URI(baseUrl + super.getURI().toString());
}
catch (URISyntaxException ex) {
throw new IllegalStateException(ex);
}
}
};
ListenableFuture<ClientHttpResponse> future = execution.executeAsync(request, body);
future.addCallback(
resp -> {

View File

@ -21,6 +21,7 @@ import java.net.URI;
import java.net.URISyntaxException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.EnumSet;
import java.util.List;
import java.util.Set;
@ -28,7 +29,12 @@ import java.util.Set;
import com.fasterxml.jackson.annotation.JsonTypeInfo;
import com.fasterxml.jackson.annotation.JsonTypeName;
import com.fasterxml.jackson.annotation.JsonView;
import org.hamcrest.Matchers;
import org.junit.Assume;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.core.io.ClassPathResource;
@ -40,7 +46,11 @@ import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.RequestEntity;
import org.springframework.http.ResponseEntity;
import org.springframework.http.client.ClientHttpRequestFactory;
import org.springframework.http.client.HttpComponentsClientHttpRequestFactory;
import org.springframework.http.client.Netty4ClientHttpRequestFactory;
import org.springframework.http.client.OkHttp3ClientHttpRequestFactory;
import org.springframework.http.client.SimpleClientHttpRequestFactory;
import org.springframework.http.converter.json.MappingJacksonValue;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
@ -55,11 +65,30 @@ import static org.junit.Assert.fail;
/**
* @author Arjen Poutsma
* @author Brian Clozel
*/
public class RestTemplateIntegrationTests extends AbstractJettyServerTestCase {
@RunWith(Parameterized.class)
public class RestTemplateIntegrationTests extends AbstractMockWebServerTestCase {
private final RestTemplate template = new RestTemplate(new HttpComponentsClientHttpRequestFactory());
private RestTemplate template;
@Parameterized.Parameter
public ClientHttpRequestFactory clientHttpRequestFactory;
@Parameterized.Parameters
public static Iterable<? extends ClientHttpRequestFactory> data() {
return Arrays.asList(
new HttpComponentsClientHttpRequestFactory(),
new Netty4ClientHttpRequestFactory(),
new OkHttp3ClientHttpRequestFactory(),
new SimpleClientHttpRequestFactory()
);
}
@Before
public void setUpClient() {
this.template = new RestTemplate(this.clientHttpRequestFactory);
}
@Test
public void getString() {
@ -131,6 +160,9 @@ public class RestTemplateIntegrationTests extends AbstractJettyServerTestCase {
@Test
public void patchForObject() throws URISyntaxException {
// JDK client does not support the PATCH method
Assume.assumeThat(this.clientHttpRequestFactory,
Matchers.not(Matchers.instanceOf(SimpleClientHttpRequestFactory.class)));
String s = template.patchForObject(baseUrl + "/{method}", helloWorld, String.class, "patch");
assertEquals("Invalid content", helloWorld, s);
}