From ec8391a7fbc8ba67d1a04c9b93408230345fec36 Mon Sep 17 00:00:00 2001 From: Brian Clozel Date: Tue, 6 Dec 2016 23:18:04 +0100 Subject: [PATCH] Fix Netty4ClientHttpRequestFactory POST/PUT requests This commit ensures that POST/PUT requests sent by the Netty client have a Content-Length header set. Integration tests have been refactored to use mockwebserver instead of Jetty and have been parameterized to run on all available supported clients. Issue: SPR-14860 --- .../http/client/Netty4ClientHttpRequest.java | 8 +- ...stractAsyncHttpRequestFactoryTestCase.java | 2 +- .../AbstractHttpRequestFactoryTestCase.java | 11 +- .../client/AbstractJettyServerTestCase.java | 202 ---------- .../client/AbstractMockWebServerTestCase.java | 95 +++++ .../client/AbstractJettyServerTestCase.java | 371 ------------------ .../client/AbstractMockWebServerTestCase.java | 254 ++++++++++++ .../AsyncRestTemplateIntegrationTests.java | 29 +- .../client/RestTemplateIntegrationTests.java | 36 +- 9 files changed, 406 insertions(+), 602 deletions(-) delete mode 100644 spring-web/src/test/java/org/springframework/http/client/AbstractJettyServerTestCase.java create mode 100644 spring-web/src/test/java/org/springframework/http/client/AbstractMockWebServerTestCase.java delete mode 100644 spring-web/src/test/java/org/springframework/web/client/AbstractJettyServerTestCase.java create mode 100644 spring-web/src/test/java/org/springframework/web/client/AbstractMockWebServerTestCase.java diff --git a/spring-web/src/main/java/org/springframework/http/client/Netty4ClientHttpRequest.java b/spring-web/src/main/java/org/springframework/http/client/Netty4ClientHttpRequest.java index 376092237d0..24e773aadaf 100644 --- a/spring-web/src/main/java/org/springframework/http/client/Netty4ClientHttpRequest.java +++ b/spring-web/src/main/java/org/springframework/http/client/Netty4ClientHttpRequest.java @@ -48,6 +48,7 @@ import org.springframework.util.concurrent.SettableListenableFuture; * * @author Arjen Poutsma * @author Rossen Stoyanchev + * @author Brian Clozel * @since 4.1.2 */ class Netty4ClientHttpRequest extends AbstractAsyncClientHttpRequest implements ClientHttpRequest { @@ -130,10 +131,15 @@ class Netty4ClientHttpRequest extends AbstractAsyncClientHttpRequest implements io.netty.handler.codec.http.HttpMethod nettyMethod = io.netty.handler.codec.http.HttpMethod.valueOf(this.method.name()); + String authority = this.uri.getRawAuthority(); + String path = this.uri.toString().substring(this.uri.toString().indexOf(authority) + authority.length()); FullHttpRequest nettyRequest = new DefaultFullHttpRequest( - HttpVersion.HTTP_1_1, nettyMethod, this.uri.toString(), this.body.buffer()); + HttpVersion.HTTP_1_1, nettyMethod, path, this.body.buffer()); nettyRequest.headers().set(HttpHeaders.HOST, this.uri.getHost()); + if (this.body.buffer().readableBytes() > 0) { + nettyRequest.headers().set(HttpHeaders.CONTENT_LENGTH, this.body.buffer().readableBytes()); + } nettyRequest.headers().set(HttpHeaders.CONNECTION, "close"); for (Map.Entry> entry : headers.entrySet()) { diff --git a/spring-web/src/test/java/org/springframework/http/client/AbstractAsyncHttpRequestFactoryTestCase.java b/spring-web/src/test/java/org/springframework/http/client/AbstractAsyncHttpRequestFactoryTestCase.java index a40a42d89a8..14f17f27add 100644 --- a/spring-web/src/test/java/org/springframework/http/client/AbstractAsyncHttpRequestFactoryTestCase.java +++ b/spring-web/src/test/java/org/springframework/http/client/AbstractAsyncHttpRequestFactoryTestCase.java @@ -39,7 +39,7 @@ import org.springframework.util.StreamUtils; import org.springframework.util.concurrent.ListenableFuture; import org.springframework.util.concurrent.ListenableFutureCallback; -public abstract class AbstractAsyncHttpRequestFactoryTestCase extends AbstractJettyServerTestCase { +public abstract class AbstractAsyncHttpRequestFactoryTestCase extends AbstractMockWebServerTestCase { protected AsyncClientHttpRequestFactory factory; diff --git a/spring-web/src/test/java/org/springframework/http/client/AbstractHttpRequestFactoryTestCase.java b/spring-web/src/test/java/org/springframework/http/client/AbstractHttpRequestFactoryTestCase.java index 267bc6b9d22..c81f183ee0d 100644 --- a/spring-web/src/test/java/org/springframework/http/client/AbstractHttpRequestFactoryTestCase.java +++ b/spring-web/src/test/java/org/springframework/http/client/AbstractHttpRequestFactoryTestCase.java @@ -39,7 +39,7 @@ import static org.junit.Assert.*; /** * @author Arjen Poutsma */ -public abstract class AbstractHttpRequestFactoryTestCase extends AbstractJettyServerTestCase { +public abstract class AbstractHttpRequestFactoryTestCase extends AbstractMockWebServerTestCase { protected ClientHttpRequestFactory factory; @@ -127,6 +127,8 @@ public abstract class AbstractHttpRequestFactoryTestCase extends AbstractJettySe @Override public void writeTo(OutputStream outputStream) throws IOException { StreamUtils.copy(body, outputStream); + outputStream.flush(); + outputStream.close(); } }); } @@ -135,12 +137,7 @@ public abstract class AbstractHttpRequestFactoryTestCase extends AbstractJettySe } ClientHttpResponse response = request.execute(); - try { - FileCopyUtils.copy(body, request.getBody()); - } - finally { - response.close(); - } + FileCopyUtils.copy(body, request.getBody()); } @Test(expected = UnsupportedOperationException.class) diff --git a/spring-web/src/test/java/org/springframework/http/client/AbstractJettyServerTestCase.java b/spring-web/src/test/java/org/springframework/http/client/AbstractJettyServerTestCase.java deleted file mode 100644 index 24ec7688276..00000000000 --- a/spring-web/src/test/java/org/springframework/http/client/AbstractJettyServerTestCase.java +++ /dev/null @@ -1,202 +0,0 @@ -/* - * Copyright 2002-2015 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.http.client; - -import java.io.IOException; -import java.io.InputStream; -import java.util.Enumeration; -import java.util.Map; - -import javax.servlet.GenericServlet; -import javax.servlet.ServletException; -import javax.servlet.ServletRequest; -import javax.servlet.ServletResponse; -import javax.servlet.http.HttpServlet; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; - -import org.eclipse.jetty.server.Connector; -import org.eclipse.jetty.server.NetworkConnector; -import org.eclipse.jetty.server.Server; -import org.eclipse.jetty.servlet.ServletContextHandler; -import org.eclipse.jetty.servlet.ServletHolder; - -import org.junit.AfterClass; -import org.junit.BeforeClass; - -import org.springframework.util.StreamUtils; - -import static org.junit.Assert.*; - -/** - * @author Arjen Poutsma - * @author Sam Brannen - */ -public abstract class AbstractJettyServerTestCase { - - private static Server jettyServer; - - protected static String baseUrl; - - @BeforeClass - public static void startJettyServer() throws Exception { - - // Let server pick its own random, available port. - jettyServer = new Server(0); - - ServletContextHandler handler = new ServletContextHandler(); - handler.setContextPath("/"); - - handler.addServlet(new ServletHolder(new EchoServlet()), "/echo"); - handler.addServlet(new ServletHolder(new ParameterServlet()), "/params"); - handler.addServlet(new ServletHolder(new StatusServlet(200)), "/status/ok"); - handler.addServlet(new ServletHolder(new StatusServlet(404)), "/status/notfound"); - handler.addServlet(new ServletHolder(new MethodServlet("DELETE")), "/methods/delete"); - handler.addServlet(new ServletHolder(new MethodServlet("GET")), "/methods/get"); - handler.addServlet(new ServletHolder(new MethodServlet("HEAD")), "/methods/head"); - handler.addServlet(new ServletHolder(new MethodServlet("OPTIONS")), "/methods/options"); - handler.addServlet(new ServletHolder(new PostServlet()), "/methods/post"); - handler.addServlet(new ServletHolder(new MethodServlet("PUT")), "/methods/put"); - handler.addServlet(new ServletHolder(new MethodServlet("PATCH")), "/methods/patch"); - - jettyServer.setHandler(handler); - jettyServer.start(); - - Connector[] connectors = jettyServer.getConnectors(); - NetworkConnector connector = (NetworkConnector) connectors[0]; - baseUrl = "http://localhost:" + connector.getLocalPort(); - } - - @AfterClass - public static void stopJettyServer() throws Exception { - if (jettyServer != null) { - jettyServer.stop(); - } - } - - - /** - * Servlet that sets a given status code. - */ - @SuppressWarnings("serial") - private static class StatusServlet extends GenericServlet { - - private final int sc; - - private StatusServlet(int sc) { - this.sc = sc; - } - - @Override - public void service(ServletRequest request, ServletResponse response) throws - ServletException, IOException { - ((HttpServletResponse) response).setStatus(sc); - } - } - - - @SuppressWarnings("serial") - private static class MethodServlet extends GenericServlet { - - private final String method; - - private MethodServlet(String method) { - this.method = method; - } - - @Override - public void service(ServletRequest req, ServletResponse res) throws ServletException, IOException { - HttpServletRequest httpReq = (HttpServletRequest) req; - assertEquals("Invalid HTTP method", method, httpReq.getMethod()); - res.setContentLength(0); - ((HttpServletResponse) res).setStatus(200); - } - } - - - @SuppressWarnings("serial") - private static class PostServlet extends MethodServlet { - - private PostServlet() { - super("POST"); - } - - @Override - public void service(ServletRequest req, ServletResponse res) throws ServletException, IOException { - super.service(req, res); - long contentLength = req.getContentLength(); - if (contentLength != -1) { - InputStream in = req.getInputStream(); - long byteCount = 0; - byte[] buffer = new byte[4096]; - int bytesRead; - while ((bytesRead = in.read(buffer)) != -1) { - byteCount += bytesRead; - } - assertEquals("Invalid content-length", contentLength, byteCount); - } - } - } - - - @SuppressWarnings("serial") - private static class EchoServlet extends HttpServlet { - - @Override - protected void service(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException { - echo(req, resp); - } - - private void echo(HttpServletRequest request, HttpServletResponse response) throws IOException { - response.setStatus(HttpServletResponse.SC_OK); - response.setContentType(request.getContentType()); - response.setContentLength(request.getContentLength()); - for (Enumeration e1 = request.getHeaderNames(); e1.hasMoreElements();) { - String headerName = e1.nextElement(); - for (Enumeration e2 = request.getHeaders(headerName); e2.hasMoreElements();) { - String headerValue = e2.nextElement(); - response.addHeader(headerName, headerValue); - } - } - StreamUtils.copy(request.getInputStream(), response.getOutputStream()); - } - } - - - @SuppressWarnings("serial") - private static class ParameterServlet extends HttpServlet { - - @Override - protected void service(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException { - Map parameters = req.getParameterMap(); - assertEquals(2, parameters.size()); - - String[] values = parameters.get("param1"); - assertEquals(1, values.length); - assertEquals("value", values[0]); - - values = parameters.get("param2"); - assertEquals(2, values.length); - assertEquals("value1", values[0]); - assertEquals("value2", values[1]); - - resp.setStatus(200); - resp.setContentLength(0); - } - } - -} diff --git a/spring-web/src/test/java/org/springframework/http/client/AbstractMockWebServerTestCase.java b/spring-web/src/test/java/org/springframework/http/client/AbstractMockWebServerTestCase.java new file mode 100644 index 00000000000..423f66ed6c0 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/client/AbstractMockWebServerTestCase.java @@ -0,0 +1,95 @@ +package org.springframework.http.client; + +import java.util.Collections; + +import okhttp3.mockwebserver.Dispatcher; +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import okhttp3.mockwebserver.RecordedRequest; +import org.hamcrest.Matchers; +import org.junit.After; +import org.junit.Before; + +import org.springframework.http.MediaType; +import org.springframework.util.StringUtils; + +import static org.hamcrest.MatcherAssert.assertThat; + +/** + * @author Brian Clozel + */ +public class AbstractMockWebServerTestCase { + + private MockWebServer server; + + protected int port; + + protected String baseUrl; + + protected static final MediaType textContentType = + new MediaType("text", "plain", Collections.singletonMap("charset", "UTF-8")); + + @Before + public void setUp() throws Exception { + this.server = new MockWebServer(); + this.server.setDispatcher(new TestDispatcher()); + this.server.start(); + this.port = this.server.getPort(); + this.baseUrl = "http://localhost:" + this.port; + } + + @After + public void tearDown() throws Exception { + this.server.shutdown(); + } + + protected class TestDispatcher extends Dispatcher { + @Override + public MockResponse dispatch(RecordedRequest request) throws InterruptedException { + try { + if (request.getPath().equals("/echo")) { + MockResponse response = new MockResponse() + .setHeaders(request.getHeaders()) + .setHeader("Content-Length", request.getBody().size()) + .setResponseCode(200) + .setBody(request.getBody()); + request.getBody().flush(); + return response; + } + else if(request.getPath().equals("/status/ok")) { + return new MockResponse(); + } + else if(request.getPath().equals("/status/notfound")) { + return new MockResponse().setResponseCode(404); + } + else if(request.getPath().startsWith("/params")) { + assertThat(request.getPath(), Matchers.containsString("param1=value")); + assertThat(request.getPath(), Matchers.containsString("param2=value1¶m2=value2")); + return new MockResponse(); + } + else if(request.getPath().equals("/methods/post")) { + assertThat(request.getMethod(), Matchers.is("POST")); + String transferEncoding = request.getHeader("Transfer-Encoding"); + if(StringUtils.hasLength(transferEncoding)) { + assertThat(transferEncoding, Matchers.is("chunked")); + } + else { + long contentLength = Long.parseLong(request.getHeader("Content-Length")); + assertThat("Invalid content-length", + request.getBody().size(), Matchers.is(contentLength)); + } + return new MockResponse().setResponseCode(200); + } + else if(request.getPath().startsWith("/methods/")) { + String expectedMethod = request.getPath().replace("/methods/","").toUpperCase(); + assertThat(request.getMethod(), Matchers.is(expectedMethod)); + return new MockResponse(); + } + return new MockResponse().setResponseCode(404); + } + catch (Throwable exc) { + return new MockResponse().setResponseCode(500).setBody(exc.toString()); + } + } + } +} diff --git a/spring-web/src/test/java/org/springframework/web/client/AbstractJettyServerTestCase.java b/spring-web/src/test/java/org/springframework/web/client/AbstractJettyServerTestCase.java deleted file mode 100644 index 78c356e9b12..00000000000 --- a/spring-web/src/test/java/org/springframework/web/client/AbstractJettyServerTestCase.java +++ /dev/null @@ -1,371 +0,0 @@ -/* - * Copyright 2002-2016 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.web.client; - -import java.io.IOException; -import java.nio.charset.StandardCharsets; -import java.util.Collections; -import java.util.List; -import java.util.Map; -import javax.servlet.GenericServlet; -import javax.servlet.ServletException; -import javax.servlet.ServletRequest; -import javax.servlet.ServletResponse; -import javax.servlet.http.HttpServlet; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; - -import org.apache.commons.fileupload.FileItem; -import org.apache.commons.fileupload.FileItemFactory; -import org.apache.commons.fileupload.FileUploadException; -import org.apache.commons.fileupload.disk.DiskFileItemFactory; -import org.apache.commons.fileupload.servlet.ServletFileUpload; -import org.eclipse.jetty.server.Connector; -import org.eclipse.jetty.server.NetworkConnector; -import org.eclipse.jetty.server.Server; -import org.eclipse.jetty.servlet.ServletContextHandler; -import org.eclipse.jetty.servlet.ServletHolder; -import org.junit.AfterClass; -import org.junit.BeforeClass; - -import org.springframework.http.MediaType; -import org.springframework.util.FileCopyUtils; - -import static org.junit.Assert.*; - -/** - * @author Arjen Poutsma - * @author Sam Brannen - */ -public class AbstractJettyServerTestCase { - - protected static final String helloWorld = "H\u00e9llo W\u00f6rld"; - - protected static final MediaType textContentType = - new MediaType("text", "plain", Collections.singletonMap("charset", "UTF-8")); - - protected static final MediaType jsonContentType = - new MediaType("application", "json", Collections.singletonMap("charset", "UTF-8")); - - - private static Server jettyServer; - - protected static int port; - - protected static String baseUrl; - - - @BeforeClass - public static void startJettyServer() throws Exception { - // Let server pick its own random, available port. - jettyServer = new Server(0); - - ServletContextHandler handler = new ServletContextHandler(); - byte[] bytes = helloWorld.getBytes(StandardCharsets.UTF_8); - handler.addServlet(new ServletHolder(new GetServlet(bytes, textContentType)), "/get"); - handler.addServlet(new ServletHolder(new GetServlet(new byte[0], textContentType)), "/get/nothing"); - handler.addServlet(new ServletHolder(new GetServlet(bytes, null)), "/get/nocontenttype"); - handler.addServlet( - new ServletHolder(new PostServlet(helloWorld, "/post/1", bytes, textContentType)), - "/post"); - handler.addServlet( - new ServletHolder(new JsonPostServlet("/jsonpost/1", jsonContentType)), - "/jsonpost"); - handler.addServlet(new ServletHolder(new StatusCodeServlet(204)), "/status/nocontent"); - handler.addServlet(new ServletHolder(new StatusCodeServlet(304)), "/status/notmodified"); - handler.addServlet(new ServletHolder(new ErrorServlet(404)), "/status/notfound"); - handler.addServlet(new ServletHolder(new ErrorServlet(500)), "/status/server"); - handler.addServlet(new ServletHolder(new UriServlet()), "/uri/*"); - handler.addServlet(new ServletHolder(new MultipartServlet()), "/multipart"); - handler.addServlet(new ServletHolder(new FormServlet()), "/form"); - handler.addServlet(new ServletHolder(new DeleteServlet()), "/delete"); - handler.addServlet(new ServletHolder(new PatchServlet(helloWorld, bytes, textContentType)), - "/patch"); - handler.addServlet( - new ServletHolder(new PutServlet(helloWorld, bytes, textContentType)), - "/put"); - - jettyServer.setHandler(handler); - jettyServer.start(); - - Connector[] connectors = jettyServer.getConnectors(); - NetworkConnector connector = (NetworkConnector) connectors[0]; - port = connector.getLocalPort(); - baseUrl = "http://localhost:" + port; - } - - @AfterClass - public static void stopJettyServer() throws Exception { - if (jettyServer != null) { - jettyServer.stop(); - } - } - - - /** Servlet that sets the given status code. */ - @SuppressWarnings("serial") - private static class StatusCodeServlet extends GenericServlet { - - private final int sc; - - public StatusCodeServlet(int sc) { - this.sc = sc; - } - - @Override - public void service(ServletRequest request, ServletResponse response) throws IOException { - ((HttpServletResponse) response).setStatus(sc); - } - } - - - /** Servlet that returns an error message for a given status code. */ - @SuppressWarnings("serial") - private static class ErrorServlet extends GenericServlet { - - private final int sc; - - public ErrorServlet(int sc) { - this.sc = sc; - } - - @Override - public void service(ServletRequest request, ServletResponse response) throws IOException { - ((HttpServletResponse) response).sendError(sc); - } - } - - - @SuppressWarnings("serial") - private static class GetServlet extends HttpServlet { - - private final byte[] buf; - - private final MediaType contentType; - - public GetServlet(byte[] buf, MediaType contentType) { - this.buf = buf; - this.contentType = contentType; - } - - @Override - protected void doGet(HttpServletRequest request, HttpServletResponse response) throws IOException { - if (contentType != null) { - response.setContentType(contentType.toString()); - } - response.setContentLength(buf.length); - FileCopyUtils.copy(buf, response.getOutputStream()); - } - } - - - @SuppressWarnings("serial") - private static class PostServlet extends HttpServlet { - - private final String content; - - private final String location; - - private final byte[] buf; - - private final MediaType contentType; - - public PostServlet(String content, String location, byte[] buf, MediaType contentType) { - this.content = content; - this.location = location; - this.buf = buf; - this.contentType = contentType; - } - - @Override - protected void doPost(HttpServletRequest request, HttpServletResponse response) throws IOException { - assertTrue("Invalid request content-length", request.getContentLength() > 0); - assertNotNull("No content-type", request.getContentType()); - String body = FileCopyUtils.copyToString(request.getReader()); - assertEquals("Invalid request body", content, body); - response.setStatus(HttpServletResponse.SC_CREATED); - response.setHeader("Location", baseUrl + location); - response.setContentLength(buf.length); - response.setContentType(contentType.toString()); - FileCopyUtils.copy(buf, response.getOutputStream()); - } - } - - - @SuppressWarnings("serial") - private static class JsonPostServlet extends HttpServlet { - - private final String location; - - private final MediaType contentType; - - public JsonPostServlet(String location, MediaType contentType) { - this.location = location; - this.contentType = contentType; - } - - @Override - protected void doPost(HttpServletRequest request, HttpServletResponse response) throws IOException { - assertTrue("Invalid request content-length", request.getContentLength() > 0); - assertNotNull("No content-type", request.getContentType()); - String body = FileCopyUtils.copyToString(request.getReader()); - response.setStatus(HttpServletResponse.SC_CREATED); - response.setHeader("Location", baseUrl +location); - response.setContentType(contentType.toString()); - byte[] bytes = body.getBytes("utf-8"); - response.setContentLength(bytes.length);; - FileCopyUtils.copy(bytes, response.getOutputStream()); - } - } - - - @SuppressWarnings("serial") - private static class PutServlet extends HttpServlet { - - private final String s; - - public PutServlet(String s, byte[] buf, MediaType contentType) { - this.s = s; - } - - @Override - protected void doPut(HttpServletRequest request, HttpServletResponse response) throws IOException { - assertTrue("Invalid request content-length", request.getContentLength() > 0); - assertNotNull("No content-type", request.getContentType()); - String body = FileCopyUtils.copyToString(request.getReader()); - assertEquals("Invalid request body", s, body); - response.setStatus(HttpServletResponse.SC_ACCEPTED); - } - } - - - @SuppressWarnings("serial") - private static class UriServlet extends HttpServlet { - - @Override - protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws IOException { - resp.setContentType("text/plain"); - resp.setCharacterEncoding("utf-8"); - resp.getWriter().write(req.getRequestURI()); - } - } - - - @SuppressWarnings("serial") - private static class MultipartServlet extends HttpServlet { - - @Override - protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException { - assertTrue(ServletFileUpload.isMultipartContent(req)); - FileItemFactory factory = new DiskFileItemFactory(); - ServletFileUpload upload = new ServletFileUpload(factory); - try { - List items = upload.parseRequest(req); - assertEquals(4, items.size()); - FileItem item = items.get(0); - assertTrue(item.isFormField()); - assertEquals("name 1", item.getFieldName()); - assertEquals("value 1", item.getString()); - - item = items.get(1); - assertTrue(item.isFormField()); - assertEquals("name 2", item.getFieldName()); - assertEquals("value 2+1", item.getString()); - - item = items.get(2); - assertTrue(item.isFormField()); - assertEquals("name 2", item.getFieldName()); - assertEquals("value 2+2", item.getString()); - - item = items.get(3); - assertFalse(item.isFormField()); - assertEquals("logo", item.getFieldName()); - assertEquals("logo.jpg", item.getName()); - assertEquals("image/jpeg", item.getContentType()); - } - catch (FileUploadException ex) { - throw new ServletException(ex); - } - - } - } - - - @SuppressWarnings("serial") - private static class FormServlet extends HttpServlet { - - @Override - protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws IOException { - assertEquals(MediaType.APPLICATION_FORM_URLENCODED_VALUE, req.getContentType()); - - Map parameters = req.getParameterMap(); - assertEquals(2, parameters.size()); - - String[] values = parameters.get("name 1"); - assertEquals(1, values.length); - assertEquals("value 1", values[0]); - - values = parameters.get("name 2"); - assertEquals(2, values.length); - assertEquals("value 2+1", values[0]); - assertEquals("value 2+2", values[1]); - } - } - - - @SuppressWarnings("serial") - private static class DeleteServlet extends HttpServlet { - - @Override - protected void doDelete(HttpServletRequest req, HttpServletResponse resp) throws IOException { - resp.setStatus(200); - } - } - - @SuppressWarnings("serial") - private static class PatchServlet extends GenericServlet { - - private final String content; - - private final byte[] buf; - - private final MediaType contentType; - - public PatchServlet(String content, byte[] buf, MediaType contentType) { - this.content = content; - this.buf = buf; - this.contentType = contentType; - } - - @Override - public void service(ServletRequest req, ServletResponse res) - throws ServletException, IOException { - HttpServletRequest request = (HttpServletRequest) req; - HttpServletResponse response = (HttpServletResponse) res; - assertEquals("PATCH", request.getMethod()); - assertTrue("Invalid request content-length", request.getContentLength() > 0); - assertNotNull("No content-type", request.getContentType()); - String body = FileCopyUtils.copyToString(request.getReader()); - assertEquals("Invalid request body", content, body); - response.setStatus(HttpServletResponse.SC_CREATED); - response.setContentLength(buf.length); - response.setContentType(contentType.toString()); - FileCopyUtils.copy(buf, response.getOutputStream()); - } - } - -} diff --git a/spring-web/src/test/java/org/springframework/web/client/AbstractMockWebServerTestCase.java b/spring-web/src/test/java/org/springframework/web/client/AbstractMockWebServerTestCase.java new file mode 100644 index 00000000000..18493412e14 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/client/AbstractMockWebServerTestCase.java @@ -0,0 +1,254 @@ +package org.springframework.web.client; + +import java.io.EOFException; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.util.Collections; + +import okhttp3.mockwebserver.Dispatcher; +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import okhttp3.mockwebserver.RecordedRequest; +import okio.Buffer; +import org.hamcrest.Matchers; +import org.junit.After; +import org.junit.Before; + +import org.springframework.http.MediaType; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; + +/** + * @author Brian Clozel + */ +public class AbstractMockWebServerTestCase { + + protected static final String helloWorld = "H\u00e9llo W\u00f6rld"; + + private MockWebServer server; + + protected int port; + + protected String baseUrl; + + protected static final MediaType textContentType = + new MediaType("text", "plain", Collections.singletonMap("charset", "UTF-8")); + + @Before + public void setUp() throws Exception { + this.server = new MockWebServer(); + this.server.setDispatcher(new TestDispatcher()); + this.server.start(); + this.port = this.server.getPort(); + this.baseUrl = "http://localhost:" + this.port; + } + + @After + public void tearDown() throws Exception { + this.server.shutdown(); + } + + protected class TestDispatcher extends Dispatcher { + @Override + public MockResponse dispatch(RecordedRequest request) throws InterruptedException { + try { + byte[] helloWorldBytes = helloWorld.getBytes(StandardCharsets.UTF_8); + + if (request.getPath().equals("/get")) { + return getRequest(request, helloWorldBytes, textContentType.toString()); + } + else if (request.getPath().equals("/get/nothing")) { + return getRequest(request, new byte[0], textContentType.toString()); + } + else if (request.getPath().equals("/get/nocontenttype")) { + return getRequest(request, helloWorldBytes, null); + } + else if (request.getPath().equals("/post")) { + return postRequest(request, helloWorld, "/post/1", textContentType.toString(), helloWorldBytes); + } + else if (request.getPath().equals("/jsonpost")) { + return jsonPostRequest(request, "/jsonpost/1", "application/json; charset=utf-8"); + } + else if (request.getPath().equals("/status/nocontent")) { + return new MockResponse().setResponseCode(204); + } + else if (request.getPath().equals("/status/notmodified")) { + return new MockResponse().setResponseCode(304); + } + else if (request.getPath().equals("/status/notfound")) { + return new MockResponse().setResponseCode(404); + } + else if (request.getPath().equals("/status/server")) { + return new MockResponse().setResponseCode(500); + } + else if (request.getPath().contains("/uri/")) { + return new MockResponse().setBody(request.getPath()).setHeader("Content-Type", "text/plain"); + } + else if (request.getPath().equals("/multipart")) { + return multipartRequest(request); + } + else if (request.getPath().equals("/form")) { + return formRequest(request); + } + else if (request.getPath().equals("/delete")) { + return new MockResponse().setResponseCode(200); + } + else if (request.getPath().equals("/patch")) { + return patchRequest(request, helloWorld, textContentType.toString(), helloWorldBytes); + } + else if (request.getPath().equals("/put")) { + return putRequest(request, helloWorld); + } + return new MockResponse().setResponseCode(404); + } + catch (Throwable exc) { + return new MockResponse().setResponseCode(500).setBody(exc.toString()); + } + } + } + + + private MockResponse getRequest(RecordedRequest request, byte[] body, String contentType) { + if(request.getMethod().equals("OPTIONS")) { + return new MockResponse().setResponseCode(200).setHeader("Allow", "GET, OPTIONS, HEAD, TRACE"); + } + Buffer buf = new Buffer(); + buf.write(body); + MockResponse response = new MockResponse() + .setHeader("Content-Length", body.length) + .setBody(buf) + .setResponseCode(200); + if (contentType != null) { + response = response.setHeader("Content-Type", contentType); + } + return response; + } + + private MockResponse postRequest(RecordedRequest request, String expectedRequestContent, + String location, String contentType, byte[] responseBody) { + + assertTrue("Invalid request content-length", + Integer.parseInt(request.getHeader("Content-Length")) > 0); + String requestContentType = request.getHeader("Content-Type"); + assertNotNull("No content-type", requestContentType); + Charset charset = StandardCharsets.ISO_8859_1; + if(requestContentType.indexOf("charset=") > -1) { + String charsetName = requestContentType.split("charset=")[1]; + charset = Charset.forName(charsetName); + } + assertEquals("Invalid request body", expectedRequestContent, request.getBody().readString(charset)); + Buffer buf = new Buffer(); + buf.write(responseBody); + return new MockResponse() + .setHeader("Location", baseUrl + location) + .setHeader("Content-Type", contentType) + .setHeader("Content-Length", responseBody.length) + .setBody(buf) + .setResponseCode(201); + } + + private MockResponse jsonPostRequest(RecordedRequest request, String location, String contentType) { + + assertTrue("Invalid request content-length", + Integer.parseInt(request.getHeader("Content-Length")) > 0); + assertNotNull("No content-type", request.getHeader("Content-Type")); + return new MockResponse() + .setHeader("Location", baseUrl + location) + .setHeader("Content-Type", contentType) + .setHeader("Content-Length", request.getBody().size()) + .setBody(request.getBody()) + .setResponseCode(201); + } + + private MockResponse multipartRequest(RecordedRequest request) { + String contentType = request.getHeader("Content-Type"); + assertTrue(contentType.startsWith("multipart/form-data")); + String boundary = contentType.split("boundary=")[1]; + Buffer body = request.getBody(); + try { + assertPart(body, "form-data", boundary, "name 1", "text/plain", "value 1"); + assertPart(body, "form-data", boundary, "name 2", "text/plain", "value 2+1"); + assertPart(body, "form-data", boundary, "name 2", "text/plain", "value 2+2"); + assertFilePart(body, "form-data", boundary, "logo", "logo.jpg", "image/jpeg"); + } + catch (EOFException e) { + throw new RuntimeException(e); + } + return new MockResponse().setResponseCode(200); + } + + private void assertPart(Buffer buffer, String disposition, String boundary, String name, + String contentType, String value) throws EOFException { + + assertTrue(buffer.readUtf8Line().contains("--" + boundary)); + String line = buffer.readUtf8Line(); + assertTrue(line.contains("Content-Disposition: "+ disposition)); + assertTrue(line.contains("name=\""+ name + "\"")); + assertTrue(buffer.readUtf8Line().startsWith("Content-Type: "+contentType)); + assertTrue(buffer.readUtf8Line().equals("Content-Length: " + value.length())); + assertTrue(buffer.readUtf8Line().equals("")); + assertTrue(buffer.readUtf8Line().equals(value)); + } + + private void assertFilePart(Buffer buffer, String disposition, String boundary, String name, + String filename, String contentType) throws EOFException { + + assertTrue(buffer.readUtf8Line().contains("--" + boundary)); + String line = buffer.readUtf8Line(); + assertTrue(line.contains("Content-Disposition: "+ disposition)); + assertTrue(line.contains("name=\""+ name + "\"")); + assertTrue(line.contains("filename=\""+ filename + "\"")); + assertTrue(buffer.readUtf8Line().startsWith("Content-Type: "+contentType)); + assertTrue(buffer.readUtf8Line().startsWith("Content-Length: ")); + assertTrue(buffer.readUtf8Line().equals("")); + assertNotNull(buffer.readUtf8Line()); + } + + private MockResponse formRequest(RecordedRequest request) { + assertEquals("application/x-www-form-urlencoded", request.getHeader("Content-Type")); + String body = request.getBody().readUtf8(); + assertThat(body, Matchers.containsString("name+1=value+1")); + assertThat(body, Matchers.containsString("name+2=value+2%2B1")); + assertThat(body, Matchers.containsString("name+2=value+2%2B2")); + return new MockResponse().setResponseCode(200); + } + + private MockResponse patchRequest(RecordedRequest request, String expectedRequestContent, + String contentType, byte[] responseBody) { + assertEquals("PATCH", request.getMethod()); + assertTrue("Invalid request content-length", + Integer.parseInt(request.getHeader("Content-Length")) > 0); + String requestContentType = request.getHeader("Content-Type"); + assertNotNull("No content-type", requestContentType); + Charset charset = StandardCharsets.ISO_8859_1; + if(requestContentType.indexOf("charset=") > -1) { + String charsetName = requestContentType.split("charset=")[1]; + charset = Charset.forName(charsetName); + } + assertEquals("Invalid request body", expectedRequestContent, request.getBody().readString(charset)); + Buffer buf = new Buffer(); + buf.write(responseBody); + return new MockResponse().setResponseCode(201) + .setHeader("Content-Length", responseBody.length) + .setHeader("Content-Type", contentType) + .setBody(buf); + } + + private MockResponse putRequest(RecordedRequest request, String expectedRequestContent) { + assertTrue("Invalid request content-length", + Integer.parseInt(request.getHeader("Content-Length")) > 0); + String requestContentType = request.getHeader("Content-Type"); + assertNotNull("No content-type", requestContentType); + Charset charset = StandardCharsets.ISO_8859_1; + if(requestContentType.indexOf("charset=") > -1) { + String charsetName = requestContentType.split("charset=")[1]; + charset = Charset.forName(charsetName); + } + assertEquals("Invalid request body", expectedRequestContent, request.getBody().readString(charset)); + return new MockResponse().setResponseCode(202); + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/client/AsyncRestTemplateIntegrationTests.java b/spring-web/src/test/java/org/springframework/web/client/AsyncRestTemplateIntegrationTests.java index f33e59bf583..c508c7690e0 100644 --- a/spring-web/src/test/java/org/springframework/web/client/AsyncRestTemplateIntegrationTests.java +++ b/spring-web/src/test/java/org/springframework/web/client/AsyncRestTemplateIntegrationTests.java @@ -18,7 +18,6 @@ package org.springframework.web.client; import java.io.IOException; import java.net.URI; -import java.net.URISyntaxException; import java.nio.charset.StandardCharsets; import java.util.Collections; import java.util.EnumSet; @@ -44,19 +43,25 @@ import org.springframework.http.client.AsyncClientHttpRequestExecution; import org.springframework.http.client.AsyncClientHttpRequestInterceptor; import org.springframework.http.client.ClientHttpResponse; import org.springframework.http.client.HttpComponentsAsyncClientHttpRequestFactory; -import org.springframework.http.client.support.HttpRequestWrapper; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; import org.springframework.util.concurrent.ListenableFuture; import org.springframework.util.concurrent.ListenableFutureCallback; -import static org.junit.Assert.*; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; /** * @author Arjen Poutsma * @author Sebastien Deleuze */ -public class AsyncRestTemplateIntegrationTests extends AbstractJettyServerTestCase { +public class AsyncRestTemplateIntegrationTests extends AbstractMockWebServerTestCase { private final AsyncRestTemplate template = new AsyncRestTemplate( new HttpComponentsAsyncClientHttpRequestFactory()); @@ -588,7 +593,7 @@ public class AsyncRestTemplateIntegrationTests extends AbstractJettyServerTestCa public void getAndInterceptResponse() throws Exception { RequestInterceptor interceptor = new RequestInterceptor(); template.setInterceptors(Collections.singletonList(interceptor)); - ListenableFuture> future = template.getForEntity("/get", String.class); + ListenableFuture> future = template.getForEntity(baseUrl + "/get", String.class); interceptor.latch.await(5, TimeUnit.SECONDS); assertNotNull(interceptor.response); @@ -601,7 +606,7 @@ public class AsyncRestTemplateIntegrationTests extends AbstractJettyServerTestCa public void getAndInterceptError() throws Exception { RequestInterceptor interceptor = new RequestInterceptor(); template.setInterceptors(Collections.singletonList(interceptor)); - template.getForEntity("/status/notfound", String.class); + template.getForEntity(baseUrl + "/status/notfound", String.class); interceptor.latch.await(5, TimeUnit.SECONDS); assertNotNull(interceptor.response); @@ -627,18 +632,6 @@ public class AsyncRestTemplateIntegrationTests extends AbstractJettyServerTestCa public ListenableFuture intercept(HttpRequest request, byte[] body, AsyncClientHttpRequestExecution execution) throws IOException { - request = new HttpRequestWrapper(request) { - @Override - public URI getURI() { - try { - return new URI(baseUrl + super.getURI().toString()); - } - catch (URISyntaxException ex) { - throw new IllegalStateException(ex); - } - } - }; - ListenableFuture future = execution.executeAsync(request, body); future.addCallback( resp -> { diff --git a/spring-web/src/test/java/org/springframework/web/client/RestTemplateIntegrationTests.java b/spring-web/src/test/java/org/springframework/web/client/RestTemplateIntegrationTests.java index 140469398d2..93015c18292 100644 --- a/spring-web/src/test/java/org/springframework/web/client/RestTemplateIntegrationTests.java +++ b/spring-web/src/test/java/org/springframework/web/client/RestTemplateIntegrationTests.java @@ -21,6 +21,7 @@ import java.net.URI; import java.net.URISyntaxException; import java.nio.charset.StandardCharsets; import java.util.ArrayList; +import java.util.Arrays; import java.util.EnumSet; import java.util.List; import java.util.Set; @@ -28,7 +29,12 @@ import java.util.Set; import com.fasterxml.jackson.annotation.JsonTypeInfo; import com.fasterxml.jackson.annotation.JsonTypeName; import com.fasterxml.jackson.annotation.JsonView; +import org.hamcrest.Matchers; +import org.junit.Assume; +import org.junit.Before; import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; import org.springframework.core.ParameterizedTypeReference; import org.springframework.core.io.ClassPathResource; @@ -40,7 +46,11 @@ import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; import org.springframework.http.RequestEntity; import org.springframework.http.ResponseEntity; +import org.springframework.http.client.ClientHttpRequestFactory; import org.springframework.http.client.HttpComponentsClientHttpRequestFactory; +import org.springframework.http.client.Netty4ClientHttpRequestFactory; +import org.springframework.http.client.OkHttp3ClientHttpRequestFactory; +import org.springframework.http.client.SimpleClientHttpRequestFactory; import org.springframework.http.converter.json.MappingJacksonValue; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; @@ -55,11 +65,30 @@ import static org.junit.Assert.fail; /** * @author Arjen Poutsma + * @author Brian Clozel */ -public class RestTemplateIntegrationTests extends AbstractJettyServerTestCase { +@RunWith(Parameterized.class) +public class RestTemplateIntegrationTests extends AbstractMockWebServerTestCase { - private final RestTemplate template = new RestTemplate(new HttpComponentsClientHttpRequestFactory()); + private RestTemplate template; + @Parameterized.Parameter + public ClientHttpRequestFactory clientHttpRequestFactory; + + @Parameterized.Parameters + public static Iterable data() { + return Arrays.asList( + new HttpComponentsClientHttpRequestFactory(), + new Netty4ClientHttpRequestFactory(), + new OkHttp3ClientHttpRequestFactory(), + new SimpleClientHttpRequestFactory() + ); + } + + @Before + public void setUpClient() { + this.template = new RestTemplate(this.clientHttpRequestFactory); + } @Test public void getString() { @@ -131,6 +160,9 @@ public class RestTemplateIntegrationTests extends AbstractJettyServerTestCase { @Test public void patchForObject() throws URISyntaxException { + // JDK client does not support the PATCH method + Assume.assumeThat(this.clientHttpRequestFactory, + Matchers.not(Matchers.instanceOf(SimpleClientHttpRequestFactory.class))); String s = template.patchForObject(baseUrl + "/{method}", helloWorld, String.class, "patch"); assertEquals("Invalid content", helloWorld, s); }