diff --git a/spring-web/src/main/java/org/springframework/http/client/JdkClientClientHttpRequest.java b/spring-web/src/main/java/org/springframework/http/client/JdkClientClientHttpRequest.java deleted file mode 100644 index 89b3b42968..0000000000 --- a/spring-web/src/main/java/org/springframework/http/client/JdkClientClientHttpRequest.java +++ /dev/null @@ -1,92 +0,0 @@ -/* - * Copyright 2023-2023 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.http.client; - -import java.io.IOException; -import java.io.InputStream; -import java.net.URI; -import java.net.http.HttpClient; -import java.net.http.HttpRequest; -import java.net.http.HttpResponse; -import java.util.List; - -import org.springframework.http.HttpHeaders; -import org.springframework.http.HttpMethod; - -/** - * {@link ClientHttpRequest} implementation based on the Java {@code HttpClient}. - * - * @author Marten Deinum - * @since 6.1 - */ -public class JdkClientClientHttpRequest extends AbstractBufferingClientHttpRequest { - - /* - * The JDK HttpRequest doesn't allow all headers to be set. The named headers are taken from the default - * implementation for HttpRequest. - */ - private static final List DISALLOWED_HEADERS = - List.of("connection", "content-length", "expect", "host", "upgrade"); - - private final HttpClient client; - private final URI uri; - private final HttpMethod method; - public JdkClientClientHttpRequest(HttpClient client, URI uri, HttpMethod method) { - this.client = client; - this.uri = uri; - this.method = method; - } - - @Override - public HttpMethod getMethod() { - return this.method; - } - - @Override - public URI getURI() { - return this.uri; - } - - @Override - protected ClientHttpResponse executeInternal(HttpHeaders headers, byte[] content) throws IOException { - - HttpRequest.Builder builder = HttpRequest.newBuilder(this.uri) - .method(getMethod().name(), HttpRequest.BodyPublishers.ofByteArray(content)); - - addHeaders(headers, builder); - HttpRequest request = builder.build(); - HttpResponse response; - try { - response = this.client.send(request, HttpResponse.BodyHandlers.ofInputStream()); - } catch (InterruptedException ex) - { - Thread.currentThread().interrupt(); - throw new IllegalStateException("Request interupted.", ex); - } - return new JdkClientClientHttpResponse(response); - } - - private static void addHeaders(HttpHeaders headers, HttpRequest.Builder builder) { - headers.forEach((headerName, headerValues) -> { - if (!DISALLOWED_HEADERS.contains(headerName.toLowerCase())) { - for (String headerValue : headerValues) { - builder.header(headerName, headerValue); - } - } - }); - } -} diff --git a/spring-web/src/main/java/org/springframework/http/client/JdkClientClientHttpRequestFactory.java b/spring-web/src/main/java/org/springframework/http/client/JdkClientClientHttpRequestFactory.java deleted file mode 100644 index 02c2269d80..0000000000 --- a/spring-web/src/main/java/org/springframework/http/client/JdkClientClientHttpRequestFactory.java +++ /dev/null @@ -1,55 +0,0 @@ -/* - * Copyright 2023-2023 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.http.client; - -import java.io.IOException; -import java.net.URI; -import java.net.http.HttpClient; - -import org.springframework.http.HttpMethod; - - -/** - * {@link ClientHttpRequestFactory} implementation that uses a - * HttpClient to create requests. - * - * @author Marten Deinum - * @since 6.1 - */ -public class JdkClientClientHttpRequestFactory implements ClientHttpRequestFactory { - - private HttpClient client; - - private final boolean defaultClient; - - - public JdkClientClientHttpRequestFactory() { - this.client = HttpClient.newHttpClient(); - this.defaultClient = true; - } - - public JdkClientClientHttpRequestFactory(HttpClient client) { - this.client = client; - this.defaultClient = false; - } - - @Override - public ClientHttpRequest createRequest(URI uri, HttpMethod httpMethod) throws IOException { - return new JdkClientClientHttpRequest(this.client, uri, httpMethod); - } - -} diff --git a/spring-web/src/main/java/org/springframework/http/client/JdkClientHttpRequest.java b/spring-web/src/main/java/org/springframework/http/client/JdkClientHttpRequest.java new file mode 100644 index 0000000000..29c1882e59 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/client/JdkClientHttpRequest.java @@ -0,0 +1,134 @@ +/* + * Copyright 2023-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client; + +import java.io.IOException; +import java.io.InputStream; +import java.io.UncheckedIOException; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.nio.ByteBuffer; +import java.util.List; +import java.util.concurrent.Executor; +import java.util.concurrent.Flow; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.lang.Nullable; +import org.springframework.util.StreamUtils; + +/** + * {@link ClientHttpRequest} implementation based the Java {@link HttpClient}. + * Created via the {@link JdkClientHttpRequestFactory}. + * + * @author Marten Deinum + * @author Arjen Poutsma + * @since 6.1 + */ +class JdkClientHttpRequest extends AbstractStreamingClientHttpRequest { + + /* + * The JDK HttpRequest doesn't allow all headers to be set. The named headers are taken from the default + * implementation for HttpRequest. + */ + private static final List DISALLOWED_HEADERS = + List.of("connection", "content-length", "expect", "host", "upgrade"); + + private final HttpClient httpClient; + + private final HttpMethod method; + + private final URI uri; + + private final Executor executor; + + + public JdkClientHttpRequest(HttpClient httpClient, URI uri, HttpMethod method, Executor executor) { + this.httpClient = httpClient; + this.uri = uri; + this.method = method; + this.executor = executor; + } + + @Override + public HttpMethod getMethod() { + return this.method; + } + + @Override + public URI getURI() { + return this.uri; + } + + + @Override + protected ClientHttpResponse executeInternal(HttpHeaders headers, @Nullable Body body) throws IOException { + try { + HttpRequest request = buildRequest(headers, body); + HttpResponse response = this.httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream()); + return new JdkClientHttpResponse(response); + } + catch (UncheckedIOException ex) { + throw ex.getCause(); + } + catch (InterruptedException ex) { + Thread.currentThread().interrupt(); + throw new IOException("Could not send request: " + ex.getMessage(), ex); + } + } + + + private HttpRequest buildRequest(HttpHeaders headers, @Nullable Body body) { + HttpRequest.Builder builder = HttpRequest.newBuilder() + .uri(this.uri); + + headers.forEach((headerName, headerValues) -> { + if (!headerName.equalsIgnoreCase(HttpHeaders.CONTENT_LENGTH)) { + if (!DISALLOWED_HEADERS.contains(headerName.toLowerCase())) { + for (String headerValue : headerValues) { + builder.header(headerName, headerValue); + } + } + } + }); + + builder.method(this.method.name(), bodyPublisher(headers, body)); + return builder.build(); + } + + private HttpRequest.BodyPublisher bodyPublisher(HttpHeaders headers, @Nullable Body body) { + if (body != null) { + Flow.Publisher outputStreamPublisher = OutputStreamPublisher.create( + outputStream -> body.writeTo(StreamUtils.nonClosing(outputStream)), + this.executor); + + long contentLength = headers.getContentLength(); + if (contentLength != -1) { + return HttpRequest.BodyPublishers.fromPublisher(outputStreamPublisher, contentLength); + } + else { + return HttpRequest.BodyPublishers.fromPublisher(outputStreamPublisher); + } + } + else { + return HttpRequest.BodyPublishers.noBody(); + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/client/JdkClientHttpRequestFactory.java b/spring-web/src/main/java/org/springframework/http/client/JdkClientHttpRequestFactory.java new file mode 100644 index 0000000000..853a6728ce --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/client/JdkClientHttpRequestFactory.java @@ -0,0 +1,82 @@ +/* + * Copyright 2023-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client; + +import java.io.IOException; +import java.net.URI; +import java.net.http.HttpClient; +import java.util.concurrent.Executor; + +import org.springframework.core.task.SimpleAsyncTaskExecutor; +import org.springframework.http.HttpMethod; +import org.springframework.util.Assert; + + +/** + * {@link ClientHttpRequestFactory} implementation based on the Java + * {@link HttpClient}. + * + * @author Marten Deinum + * @author Arjen Poutsma + * @since 6.1 + */ +public class JdkClientHttpRequestFactory implements ClientHttpRequestFactory { + + private final HttpClient httpClient; + + private final Executor executor; + + + /** + * Create a new instance of the {@code JdkClientHttpRequestFactory} + * with a default {@link HttpClient}. + */ + public JdkClientHttpRequestFactory() { + this(HttpClient.newHttpClient()); + } + + /** + * Create a new instance of the {@code JdkClientHttpRequestFactory} based on + * the given {@link HttpClient}. + * @param httpClient the client to base on + */ + public JdkClientHttpRequestFactory(HttpClient httpClient) { + Assert.notNull(httpClient, "HttpClient is required"); + this.httpClient = httpClient; + this.executor = httpClient.executor().orElseGet(SimpleAsyncTaskExecutor::new); + } + + /** + * Create a new instance of the {@code JdkClientHttpRequestFactory} based on + * the given {@link HttpClient} and {@link Executor}. + * @param httpClient the client to base on + * @param executor the executor to use for blocking write operations + */ + public JdkClientHttpRequestFactory(HttpClient httpClient, Executor executor) { + Assert.notNull(httpClient, "HttpClient is required"); + Assert.notNull(executor, "Executor must not be null"); + this.httpClient = httpClient; + this.executor = executor; + } + + + @Override + public ClientHttpRequest createRequest(URI uri, HttpMethod httpMethod) throws IOException { + return new JdkClientHttpRequest(this.httpClient, uri, httpMethod, this.executor); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/client/JdkClientClientHttpResponse.java b/spring-web/src/main/java/org/springframework/http/client/JdkClientHttpResponse.java similarity index 50% rename from spring-web/src/main/java/org/springframework/http/client/JdkClientClientHttpResponse.java rename to spring-web/src/main/java/org/springframework/http/client/JdkClientHttpResponse.java index 770f79fb16..5e6a2a84e4 100644 --- a/spring-web/src/main/java/org/springframework/http/client/JdkClientClientHttpResponse.java +++ b/spring-web/src/main/java/org/springframework/http/client/JdkClientHttpResponse.java @@ -18,81 +18,89 @@ package org.springframework.http.client; import java.io.IOException; import java.io.InputStream; +import java.net.http.HttpClient; import java.net.http.HttpResponse; +import java.util.List; +import java.util.Locale; +import java.util.Map; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpStatus; import org.springframework.http.HttpStatusCode; -import org.springframework.lang.Nullable; +import org.springframework.util.CollectionUtils; +import org.springframework.util.LinkedCaseInsensitiveMap; +import org.springframework.util.MultiValueMap; import org.springframework.util.StreamUtils; /** - * {@link ClientHttpResponse} implementation based on the Java {@code HttpClient}. + * {@link ClientHttpResponse} implementation based on the Java {@link HttpClient}. * * @author Marten Deinum + * @author Arjen Poutsma * @since 6.1 */ -public class JdkClientClientHttpResponse implements ClientHttpResponse { +class JdkClientHttpResponse implements ClientHttpResponse { private final HttpResponse response; - @Nullable - private volatile HttpHeaders headers; - public JdkClientClientHttpResponse(HttpResponse response) { + private final HttpHeaders headers; + + private final InputStream body; + + + public JdkClientHttpResponse(HttpResponse response) { this.response = response; + this.headers = adaptHeaders(response); + InputStream inputStream = response.body(); + this.body = (inputStream != null) ? inputStream : InputStream.nullInputStream(); } + private static HttpHeaders adaptHeaders(HttpResponse response) { + Map> rawHeaders = response.headers().map(); + Map> map = new LinkedCaseInsensitiveMap<>(rawHeaders.size(), Locale.ENGLISH); + MultiValueMap multiValueMap = CollectionUtils.toMultiValueMap(map); + multiValueMap.putAll(rawHeaders); + return HttpHeaders.readOnlyHttpHeaders(multiValueMap); + } + + @Override - public HttpStatusCode getStatusCode() throws IOException { + public HttpStatusCode getStatusCode() { return HttpStatusCode.valueOf(this.response.statusCode()); } - @Override - @Deprecated - public int getRawStatusCode() { - return this.response.statusCode(); - } - @Override public String getStatusText() { - HttpStatus status = HttpStatus.resolve(this.response.statusCode()); - return (status != null) ? status.getReasonPhrase() : ""; - } - - @Override - public InputStream getBody() throws IOException { - InputStream body = this.response.body(); - return (body != null ? body : InputStream.nullInputStream()); + // HttpResponse does not expose status text + if (getStatusCode() instanceof HttpStatus status) { + return status.getReasonPhrase(); + } + else { + return ""; + } } @Override public HttpHeaders getHeaders() { - HttpHeaders headers = this.headers; - if (headers == null) { - headers = new HttpHeaders(); - for (String headerName : this.response.headers().map().keySet()) { - for (String headerValue : this.response.headers().allValues(headerName)) { - headers.add(headerName, headerValue); - } - } - this.headers = headers; - } - return headers; + return this.headers; + } + + @Override + public InputStream getBody() throws IOException { + return this.body; } @Override public void close() { - InputStream body = this.response.body(); try { try { - StreamUtils.drain(body); + StreamUtils.drain(this.body); } finally { - body.close(); + this.body.close(); } } - catch (IOException ex) { - // Ignore exception on close... + catch (IOException ignored) { } } } diff --git a/spring-web/src/main/java/org/springframework/http/client/OutputStreamPublisher.java b/spring-web/src/main/java/org/springframework/http/client/OutputStreamPublisher.java new file mode 100644 index 0000000000..73652a632e --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/client/OutputStreamPublisher.java @@ -0,0 +1,400 @@ +/* + * Copyright 2002-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client; + +import java.io.BufferedOutputStream; +import java.io.IOException; +import java.io.OutputStream; +import java.nio.ByteBuffer; +import java.util.Objects; +import java.util.concurrent.Executor; +import java.util.concurrent.Flow; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; +import java.util.concurrent.locks.LockSupport; + +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * Bridges between {@link OutputStream} and + * {@link Flow.Publisher Flow.Publisher<ByteBuffer>}. + * + * @author Oleh Dokuka + * @author Arjen Poutsma + * @since 6.1 + * @see #create(OutputStreamHandler, Executor) + */ +final class OutputStreamPublisher implements Flow.Publisher { + + private final OutputStreamHandler outputStreamHandler; + + private final Executor executor; + + + private OutputStreamPublisher(OutputStreamHandler outputStreamHandler, Executor executor) { + this.outputStreamHandler = outputStreamHandler; + this.executor = executor; + } + + + /** + * Creates a new {@code Publisher} based on bytes written to a + * {@code OutputStream}. + *
    + *
  • The parameter {@code outputStreamHandler} is invoked once per + * subscription of the returned {@code Publisher}, when the first + * {@code ByteBuffer} is + * {@linkplain Flow.Subscription#request(long) requested}.
  • + *
  • Each {@link OutputStream#write(byte[], int, int) OutputStream.write()} + * invocation that {@code outputStreamHandler} makes will result in a + * {@linkplain Flow.Subscriber#onNext(Object) published} {@code ByteBuffer} + * if there is {@linkplain Flow.Subscription#request(long) demand}.
  • + *
  • If there is no demand, {@code OutputStream.write()} will block + * until there is.
  • + *
  • If the subscription is {@linkplain Flow.Subscription#cancel() cancelled}, + * {@code OutputStream.write()} will throw a {@code IOException}.
  • + *
  • {@linkplain OutputStream#close() Closing} the {@code OutputStream} + * will result in a {@linkplain Flow.Subscriber#onComplete() complete} signal.
  • + *
  • Any {@code IOException}s thrown from {@code outputStreamHandler} will + * be dispatched to the {@linkplain Flow.Subscriber#onError(Throwable) Subscriber}. + *
+ * @param outputStreamHandler invoked when the first buffer is requested + * @param executor used to invoke the {@code outputStreamHandler} + * @return a {@code Publisher} based on bytes written by + * {@code outputStreamHandler} + */ + public static Flow.Publisher create(OutputStreamHandler outputStreamHandler, Executor executor) { + Assert.notNull(outputStreamHandler, "OutputStreamHandler must not be null"); + Assert.notNull(executor, "Executor must not be null"); + + return new OutputStreamPublisher(outputStreamHandler, executor); + } + + + @Override + public void subscribe(Flow.Subscriber subscriber) { + Objects.requireNonNull(subscriber, "Subscriber must not be null"); + + OutputStreamSubscription subscription = new OutputStreamSubscription(subscriber, this.outputStreamHandler); + subscriber.onSubscribe(subscription); + this.executor.execute(subscription::invokeHandler); + } + + + /** + * Defines the contract for handling the {@code OutputStream} provided by + * the {@code OutputStreamPublisher}. + */ + @FunctionalInterface + public interface OutputStreamHandler { + + /** + * Use the given stream for writing. + *
    + *
  • If the linked subscription has + * {@linkplain Flow.Subscription#request(long) demand}, any + * {@linkplain OutputStream#write(byte[], int, int) written} bytes + * will be {@linkplain Flow.Subscriber#onNext(Object) published} to the + * {@link Flow.Subscriber Subscriber}.
  • + *
  • If there is no demand, any + * {@link OutputStream#write(byte[], int, int) write()} invocations will + * block until there is demand.
  • + *
  • If the linked subscription is + * {@linkplain Flow.Subscription#cancel() cancelled}, + * {@link OutputStream#write(byte[], int, int) write()} invocations will + * result in a {@code IOException}.
  • + *
+ * @param outputStream the stream to write to + * @throws IOException any thrown I/O errors will be dispatched to the + * {@linkplain Flow.Subscriber#onError(Throwable) Subscriber} + */ + void handle(OutputStream outputStream) throws IOException; + + } + + + private static final class OutputStreamSubscription extends OutputStream implements Flow.Subscription { + + static final Object READY = new Object(); + + private final Flow.Subscriber actual; + + private final OutputStreamHandler outputStreamHandler; + + private final AtomicLong requested = new AtomicLong(); + + private final AtomicReference parkedThreadAtomic = new AtomicReference<>(); + + @Nullable + private volatile Throwable error; + + private long produced; + + + public OutputStreamSubscription(Flow.Subscriber actual, + OutputStreamHandler outputStreamHandler) { + this.actual = actual; + this.outputStreamHandler = outputStreamHandler; + } + + @Override + public void write(int b) throws IOException { + checkDemandAndAwaitIfNeeded(); + + ByteBuffer byteBuffer = ByteBuffer.allocate(1); + byteBuffer.put((byte) b); + byteBuffer.flip(); + + this.actual.onNext(byteBuffer); + + this.produced++; + } + + @Override + public void write(byte[] b) throws IOException { + write(b, 0, b.length); + } + + @Override + public void write(byte[] b, int off, int len) throws IOException { + checkDemandAndAwaitIfNeeded(); + + ByteBuffer byteBuffer = ByteBuffer.allocate(len); + byteBuffer.put(b, off, len); + byteBuffer.flip(); + + this.actual.onNext(byteBuffer); + + this.produced++; + } + + private void checkDemandAndAwaitIfNeeded() throws IOException { + long r = this.requested.get(); + + if (isTerminated(r) || isCancelled(r)) { + throw new IOException("Subscription has been terminated"); + } + + long p = this.produced; + if (p == r) { + if (p > 0) { + r = tryProduce(p); + this.produced = 0; + } + + while (true) { + if (isTerminated(r) || isCancelled(r)) { + throw new IOException("Subscription has been terminated"); + } + + if (r != 0) { + return; + } + + await(); + + r = this.requested.get(); + } + } + } + + private void invokeHandler() { + // assume sync write within try-with-resource block + + // use BufferedOutputStream, so that written bytes are buffered + // before publishing as byte buffer + try (OutputStream outputStream = new BufferedOutputStream(this)) { + this.outputStreamHandler.handle(outputStream); + } + catch (IOException ex) { + long previousState = tryTerminate(); + if (isCancelled(previousState)) { + return; + } + + if (isTerminated(previousState)) { + // failure due to illegal requestN + this.actual.onError(this.error); + return; + } + + this.actual.onError(ex); + return; + } + + long previousState = tryTerminate(); + if (isCancelled(previousState)) { + return; + } + + if (isTerminated(previousState)) { + // failure due to illegal requestN + this.actual.onError(this.error); + return; + } + + this.actual.onComplete(); + } + + + @Override + public void request(long n) { + if (n <= 0) { + this.error = new IllegalArgumentException("request should be a positive number"); + long previousState = tryTerminate(); + + if (isTerminated(previousState) || isCancelled(previousState)) { + return; + } + + if (previousState > 0) { + // error should eventually be observed and propagated + return; + } + + // resume parked thread, so it can observe error and propagate it + resume(); + return; + } + + if (addCap(n) == 0) { + // resume parked thread so it can continue the work + resume(); + } + } + + @Override + public void cancel() { + long previousState = tryCancel(); + if (isCancelled(previousState) || previousState > 0) { + return; + } + + // resume parked thread, so it can be unblocked and close all the resources + resume(); + } + + private void await() { + Thread toUnpark = Thread.currentThread(); + + while (true) { + Object current = this.parkedThreadAtomic.get(); + if (current == READY) { + break; + } + + if (current != null && current != toUnpark) { + throw new IllegalStateException("Only one (Virtual)Thread can await!"); + } + + if (this.parkedThreadAtomic.compareAndSet(null, toUnpark)) { + LockSupport.park(); + // we don't just break here because park() can wake up spuriously + // if we got a proper resume, get() == READY and the loop will quit above + } + } + // clear the resume indicator so that the next await call will park without a resume() + this.parkedThreadAtomic.lazySet(null); + } + + private void resume() { + if (this.parkedThreadAtomic.get() != READY) { + Object old = this.parkedThreadAtomic.getAndSet(READY); + if (old != READY) { + LockSupport.unpark((Thread)old); + } + } + } + + private long tryCancel() { + while (true) { + long r = this.requested.get(); + + if (isCancelled(r)) { + return r; + } + + if (this.requested.compareAndSet(r, Long.MIN_VALUE)) { + return r; + } + } + } + + private long tryTerminate() { + while (true) { + long r = this.requested.get(); + + if (isCancelled(r) || isTerminated(r)) { + return r; + } + + if (this.requested.compareAndSet(r, Long.MIN_VALUE | Long.MAX_VALUE)) { + return r; + } + } + } + + private long tryProduce(long n) { + while (true) { + long current = this.requested.get(); + if (isTerminated(current) || isCancelled(current)) { + return current; + } + if (current == Long.MAX_VALUE) { + return Long.MAX_VALUE; + } + long update = current - n; + if (update < 0L) { + update = 0L; + } + if (this.requested.compareAndSet(current, update)) { + return update; + } + } + } + + private long addCap(long n) { + while (true) { + long r = this.requested.get(); + if (isTerminated(r) || isCancelled(r) || r == Long.MAX_VALUE) { + return r; + } + long u = addCap(r, n); + if (this.requested.compareAndSet(r, u)) { + return r; + } + } + } + + private static boolean isTerminated(long state) { + return state == (Long.MIN_VALUE | Long.MAX_VALUE); + } + + private static boolean isCancelled(long state) { + return state == Long.MIN_VALUE; + } + + private static long addCap(long a, long b) { + long res = a + b; + if (res < 0L) { + return Long.MAX_VALUE; + } + return res; + } + } +} diff --git a/spring-web/src/test/java/org/springframework/http/client/JdkClientClientHttpRequestFactoryTests.java b/spring-web/src/test/java/org/springframework/http/client/JdkClientHttpRequestFactoryTests.java similarity index 87% rename from spring-web/src/test/java/org/springframework/http/client/JdkClientClientHttpRequestFactoryTests.java rename to spring-web/src/test/java/org/springframework/http/client/JdkClientHttpRequestFactoryTests.java index c4ee18a8ad..f882e2bfa4 100644 --- a/spring-web/src/test/java/org/springframework/http/client/JdkClientClientHttpRequestFactoryTests.java +++ b/spring-web/src/test/java/org/springframework/http/client/JdkClientHttpRequestFactoryTests.java @@ -23,11 +23,11 @@ import org.springframework.http.HttpMethod; /** * @author Marten Deinum */ -public class JdkClientClientHttpRequestFactoryTests extends AbstractHttpRequestFactoryTests { +public class JdkClientHttpRequestFactoryTests extends AbstractHttpRequestFactoryTests { @Override protected ClientHttpRequestFactory createRequestFactory() { - return new JdkClientClientHttpRequestFactory(); + return new JdkClientHttpRequestFactory(); } @Override diff --git a/spring-web/src/test/java/org/springframework/http/client/OutputStreamPublisherTests.java b/spring-web/src/test/java/org/springframework/http/client/OutputStreamPublisherTests.java new file mode 100644 index 0000000000..fc792ec55c --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/client/OutputStreamPublisherTests.java @@ -0,0 +1,182 @@ +/* + * Copyright 2002-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client; + +import java.io.OutputStreamWriter; +import java.io.Writer; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Executor; +import java.util.concurrent.Executors; +import java.util.concurrent.Flow; + +import org.junit.jupiter.api.Test; +import org.reactivestreams.FlowAdapters; +import reactor.core.publisher.Flux; +import reactor.test.StepVerifier; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIOException; + +/** + * @author Arjen Poutsma + * @author Oleh Dokuka + */ +class OutputStreamPublisherTests { + + private final Executor executor = Executors.newSingleThreadExecutor(); + + @Test + void basic() { + Flow.Publisher flowPublisher = OutputStreamPublisher.create(outputStream -> { + try (Writer writer = new OutputStreamWriter(outputStream, StandardCharsets.UTF_8)) { + writer.write("foo"); + writer.write("bar"); + writer.write("baz"); + } + }, this.executor); + Flux flux = toString(flowPublisher); + + StepVerifier.create(flux) + .assertNext(s -> assertThat(s).isEqualTo("foobarbaz")) + .verifyComplete(); + } + + @Test + void flush() { + Flow.Publisher flowPublisher = OutputStreamPublisher.create(outputStream -> { + try (Writer writer = new OutputStreamWriter(outputStream, StandardCharsets.UTF_8)) { + writer.write("foo"); + writer.flush(); + writer.write("bar"); + writer.flush(); + writer.write("baz"); + writer.flush(); + } + }, this.executor); + Flux flux = toString(flowPublisher); + + StepVerifier.create(flux) + .assertNext(s -> assertThat(s).isEqualTo("foo")) + .assertNext(s -> assertThat(s).isEqualTo("bar")) + .assertNext(s -> assertThat(s).isEqualTo("baz")) + .verifyComplete(); + } + + @Test + void cancel() throws InterruptedException { + CountDownLatch latch = new CountDownLatch(1); + + Flow.Publisher flowPublisher = OutputStreamPublisher.create(outputStream -> { + try (Writer writer = new OutputStreamWriter(outputStream, StandardCharsets.UTF_8)) { + assertThatIOException() + .isThrownBy(() -> { + writer.write("foo"); + writer.flush(); + writer.write("bar"); + writer.flush(); + }) + .withMessage("Subscription has been terminated"); + latch.countDown(); + } + }, this.executor); + Flux flux = toString(flowPublisher); + + StepVerifier.create(flux, 1) + .assertNext(s -> assertThat(s).isEqualTo("foo")) + .thenCancel() + .verify(); + + latch.await(); + } + + @Test + void closed() throws InterruptedException { + CountDownLatch latch = new CountDownLatch(1); + + Flow.Publisher flowPublisher = OutputStreamPublisher.create(outputStream -> { + Writer writer = new OutputStreamWriter(outputStream, StandardCharsets.UTF_8); + writer.write("foo"); + writer.close(); + assertThatIOException().isThrownBy(() -> writer.write("bar")) + .withMessage("Stream closed"); + latch.countDown(); + }, this.executor); + Flux flux = toString(flowPublisher); + + StepVerifier.create(flux) + .assertNext(s -> assertThat(s).isEqualTo("foo")) + .verifyComplete(); + + latch.await(); + } + + @Test + void negativeRequestN() throws InterruptedException { + CountDownLatch latch = new CountDownLatch(1); + + Flow.Publisher flowPublisher = OutputStreamPublisher.create(outputStream -> { + try(Writer writer = new OutputStreamWriter(outputStream, StandardCharsets.UTF_8)) { + writer.write("foo"); + writer.flush(); + writer.write("foo"); + writer.flush(); + } + finally { + latch.countDown(); + } + }, this.executor); + Flow.Subscription[] subscriptions = new Flow.Subscription[1]; + Flux flux = toString(a-> flowPublisher.subscribe(new Flow.Subscriber<>() { + @Override + public void onSubscribe(Flow.Subscription subscription) { + subscriptions[0] = subscription; + a.onSubscribe(subscription); + } + + @Override + public void onNext(ByteBuffer item) { + a.onNext(item); + } + + @Override + public void onError(Throwable throwable) { + a.onError(throwable); + } + + @Override + public void onComplete() { + a.onComplete(); + } + })); + + StepVerifier.create(flux, 1) + .assertNext(s -> assertThat(s).isEqualTo("foo")) + .then(() -> subscriptions[0].request(-1)) + .expectErrorMessage("request should be a positive number") + .verify(); + + latch.await(); + } + + private static Flux toString(Flow.Publisher flowPublisher) { + return Flux.from(FlowAdapters.toPublisher(flowPublisher)) + .map(bb -> StandardCharsets.UTF_8.decode(bb).toString()); + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/client/RestTemplateIntegrationTests.java b/spring-web/src/test/java/org/springframework/web/client/RestTemplateIntegrationTests.java index e7dabe1d15..88b498af01 100644 --- a/spring-web/src/test/java/org/springframework/web/client/RestTemplateIntegrationTests.java +++ b/spring-web/src/test/java/org/springframework/web/client/RestTemplateIntegrationTests.java @@ -48,6 +48,7 @@ import org.springframework.http.RequestEntity; import org.springframework.http.ResponseEntity; import org.springframework.http.client.ClientHttpRequestFactory; import org.springframework.http.client.HttpComponentsClientHttpRequestFactory; +import org.springframework.http.client.JdkClientHttpRequestFactory; import org.springframework.http.client.JettyClientHttpRequestFactory; import org.springframework.http.client.OkHttp3ClientHttpRequestFactory; import org.springframework.http.client.SimpleClientHttpRequestFactory; @@ -91,10 +92,11 @@ class RestTemplateIntegrationTests extends AbstractMockWebServerTests { static Stream> clientHttpRequestFactories() { return Stream.of( - named("JDK", new SimpleClientHttpRequestFactory()), + named("JDK HttpURLConnection", new SimpleClientHttpRequestFactory()), named("HttpComponents", new HttpComponentsClientHttpRequestFactory()), named("OkHttp", new OkHttp3ClientHttpRequestFactory()), - named("Jetty", new JettyClientHttpRequestFactory()) + named("Jetty", new JettyClientHttpRequestFactory()), + named("JDK HttpClient", new JdkClientHttpRequestFactory()) ); } @@ -225,7 +227,7 @@ class RestTemplateIntegrationTests extends AbstractMockWebServerTests { @ParameterizedRestTemplateTest void patchForObject(ClientHttpRequestFactory clientHttpRequestFactory) throws Exception { assumeFalse(clientHttpRequestFactory instanceof SimpleClientHttpRequestFactory, - "JDK client does not support the PATCH method"); + "HttpURLConnection does not support the PATCH method"); setUpClient(clientHttpRequestFactory); @@ -254,6 +256,7 @@ class RestTemplateIntegrationTests extends AbstractMockWebServerTests { template.execute(baseUrl + "/status/badrequest", HttpMethod.GET, null, null)) .satisfies(ex -> { assertThat(ex.getStatusCode()).isEqualTo(HttpStatus.BAD_REQUEST); + assumeFalse(clientHttpRequestFactory instanceof JdkClientHttpRequestFactory, "JDK HttpClient does not expose status text"); assertThat(ex.getMessage()).isEqualTo("400 Client Error: [no body]"); }); }