diff --git a/org.springframework.web/src/main/java/org/springframework/web/http/client/AbstractClientHttpRequest.java b/org.springframework.web/src/main/java/org/springframework/web/http/client/AbstractClientHttpRequest.java new file mode 100644 index 00000000000..85c0326f273 --- /dev/null +++ b/org.springframework.web/src/main/java/org/springframework/web/http/client/AbstractClientHttpRequest.java @@ -0,0 +1,67 @@ +/* + * Copyright 2002-2009 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.http.client; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.OutputStream; + +import org.springframework.util.Assert; +import org.springframework.web.http.HttpHeaders; + +/** + * Abstract base for {@link ClientHttpRequest} that makes sure that headers and body are not written multiple times. + * + * @author Arjen Poutsma + * @since 3.0 + */ +public abstract class AbstractClientHttpRequest implements ClientHttpRequest { + + private boolean executed = false; + + private final HttpHeaders headers = new HttpHeaders(); + + private final ByteArrayOutputStream bufferedOutput = new ByteArrayOutputStream(); + + public final HttpHeaders getHeaders() { + checkExecuted(); + return headers; + } + + public final OutputStream getBody() throws IOException { + checkExecuted(); + return bufferedOutput; + } + + public final ClientHttpResponse execute() throws IOException { + checkExecuted(); + ClientHttpResponse result = executeInternal(headers, bufferedOutput.toByteArray()); + executed = true; + return result; + } + + /** + * Abstract template method that writes the given headers and content to the HTTP request. + */ + protected abstract ClientHttpResponse executeInternal(HttpHeaders headers, byte[] bufferedOutput) + throws IOException; + + private void checkExecuted() { + Assert.state(!executed, "ClientRequest already executed"); + } + +} diff --git a/org.springframework.web/src/main/java/org/springframework/web/http/client/SimpleClientHttpRequest.java b/org.springframework.web/src/main/java/org/springframework/web/http/client/SimpleClientHttpRequest.java index b0c950587da..70e45b0cbd1 100644 --- a/org.springframework.web/src/main/java/org/springframework/web/http/client/SimpleClientHttpRequest.java +++ b/org.springframework.web/src/main/java/org/springframework/web/http/client/SimpleClientHttpRequest.java @@ -17,11 +17,11 @@ package org.springframework.web.http.client; import java.io.IOException; -import java.io.OutputStream; import java.net.HttpURLConnection; import java.util.List; import java.util.Map; +import org.springframework.util.FileCopyUtils; import org.springframework.web.http.HttpHeaders; import org.springframework.web.http.HttpMethod; @@ -33,14 +33,10 @@ import org.springframework.web.http.HttpMethod; * @see SimpleClientHttpRequestFactory#createRequest(java.net.URI, HttpMethod) * @since 3.0 */ -final class SimpleClientHttpRequest implements ClientHttpRequest { +final class SimpleClientHttpRequest extends AbstractClientHttpRequest { private final HttpURLConnection connection; - private final HttpHeaders headers = new HttpHeaders(); - - private boolean headersWritten = false; - SimpleClientHttpRequest(HttpURLConnection connection) { this.connection = connection; } @@ -49,30 +45,17 @@ final class SimpleClientHttpRequest implements ClientHttpRequest { return HttpMethod.valueOf(connection.getRequestMethod()); } - public HttpHeaders getHeaders() { - return headers; - } - - public OutputStream getBody() throws IOException { - writeHeaders(); - return connection.getOutputStream(); - } - - public ClientHttpResponse execute() throws IOException { - writeHeaders(); + @Override + protected ClientHttpResponse executeInternal(HttpHeaders headers, byte[] bufferedOutput) throws IOException { + for (Map.Entry> entry : headers.entrySet()) { + String headerName = entry.getKey(); + for (String headerValue : entry.getValue()) { + connection.addRequestProperty(headerName, headerValue); + } + } connection.connect(); + FileCopyUtils.copy(bufferedOutput, connection.getOutputStream()); return new SimpleClientHttpResponse(connection); } - private void writeHeaders() { - if (!headersWritten) { - for (Map.Entry> entry : headers.entrySet()) { - String headerName = entry.getKey(); - for (String headerValue : entry.getValue()) { - connection.addRequestProperty(headerName, headerValue); - } - } - headersWritten = true; - } - } } diff --git a/org.springframework.web/src/main/java/org/springframework/web/http/client/commons/CommonsClientHttpRequest.java b/org.springframework.web/src/main/java/org/springframework/web/http/client/commons/CommonsClientHttpRequest.java index 43e6a1d6d82..f569b20f522 100644 --- a/org.springframework.web/src/main/java/org/springframework/web/http/client/commons/CommonsClientHttpRequest.java +++ b/org.springframework.web/src/main/java/org/springframework/web/http/client/commons/CommonsClientHttpRequest.java @@ -16,9 +16,7 @@ package org.springframework.web.http.client.commons; -import java.io.ByteArrayOutputStream; import java.io.IOException; -import java.io.OutputStream; import java.util.List; import java.util.Map; @@ -28,10 +26,9 @@ import org.apache.commons.httpclient.methods.ByteArrayRequestEntity; import org.apache.commons.httpclient.methods.EntityEnclosingMethod; import org.apache.commons.httpclient.methods.RequestEntity; -import org.springframework.util.Assert; import org.springframework.web.http.HttpHeaders; import org.springframework.web.http.HttpMethod; -import org.springframework.web.http.client.ClientHttpRequest; +import org.springframework.web.http.client.AbstractClientHttpRequest; import org.springframework.web.http.client.ClientHttpResponse; /** @@ -41,59 +38,36 @@ import org.springframework.web.http.client.ClientHttpResponse; * @author Arjen Poutsma * @see CommonsClientHttpRequestFactory#createRequest(java.net.URI, HttpMethod) */ -final class CommonsClientHttpRequest implements ClientHttpRequest { +final class CommonsClientHttpRequest extends AbstractClientHttpRequest { private final HttpClient httpClient; private final HttpMethodBase httpMethod; - private final HttpHeaders headers = new HttpHeaders(); - - private boolean headersWritten = false; - - private ByteArrayOutputStream bufferedOutput; - CommonsClientHttpRequest(HttpClient httpClient, HttpMethodBase httpMethod) { this.httpClient = httpClient; this.httpMethod = httpMethod; } - public HttpHeaders getHeaders() { - return headers; - } - - public OutputStream getBody() throws IOException { - writeHeaders(); - Assert.isInstanceOf(EntityEnclosingMethod.class, httpMethod); - this.bufferedOutput = new ByteArrayOutputStream(); - return bufferedOutput; - } - public HttpMethod getMethod() { return HttpMethod.valueOf(httpMethod.getName()); } - public ClientHttpResponse execute() throws IOException { - writeHeaders(); + @Override + public ClientHttpResponse executeInternal(HttpHeaders headers, byte[] output) throws IOException { + for (Map.Entry> entry : headers.entrySet()) { + String headerName = entry.getKey(); + for (String headerValue : entry.getValue()) { + httpMethod.addRequestHeader(headerName, headerValue); + } + } if (httpMethod instanceof EntityEnclosingMethod) { EntityEnclosingMethod entityEnclosingMethod = (EntityEnclosingMethod) httpMethod; - RequestEntity requestEntity = new ByteArrayRequestEntity(bufferedOutput.toByteArray()); + RequestEntity requestEntity = new ByteArrayRequestEntity(output); entityEnclosingMethod.setRequestEntity(requestEntity); } httpClient.executeMethod(httpMethod); return new CommonsClientHttpResponse(httpMethod); } - private void writeHeaders() { - if (!headersWritten) { - for (Map.Entry> entry : headers.entrySet()) { - String headerName = entry.getKey(); - for (String headerValue : entry.getValue()) { - httpMethod.addRequestHeader(headerName, headerValue); - } - } - headersWritten = true; - } - } - -} \ No newline at end of file +} diff --git a/org.springframework.web/src/test/java/org/springframework/web/http/client/AbstractHttpRequestFactoryTestCase.java b/org.springframework.web/src/test/java/org/springframework/web/http/client/AbstractHttpRequestFactoryTestCase.java index 002c86ff8d2..a72c99dbcda 100644 --- a/org.springframework.web/src/test/java/org/springframework/web/http/client/AbstractHttpRequestFactoryTestCase.java +++ b/org.springframework.web/src/test/java/org/springframework/web/http/client/AbstractHttpRequestFactoryTestCase.java @@ -99,6 +99,25 @@ public abstract class AbstractHttpRequestFactoryTestCase { assertTrue("Invalid body", Arrays.equals(body, result)); } + @Test(expected = IllegalStateException.class) + public void multipleWrites() throws Exception { + ClientHttpRequest request = factory.createRequest(new URI("http://localhost:8889/echo"), HttpMethod.POST); + byte[] body = "Hello World".getBytes("UTF-8"); + FileCopyUtils.copy(body, request.getBody()); + request.execute(); + FileCopyUtils.copy(body, request.getBody()); + } + + @Test(expected = IllegalStateException.class) + public void headersAfterExecute() throws Exception { + ClientHttpRequest request = factory.createRequest(new URI("http://localhost:8889/echo"), HttpMethod.POST); + request.getHeaders().add("MyHeader", "value"); + byte[] body = "Hello World".getBytes("UTF-8"); + FileCopyUtils.copy(body, request.getBody()); + request.execute(); + request.getHeaders().add("MyHeader", "value"); + } + /** * Servlet that returns and error message for a given status code. */