Handle CancellationException in JdkClientHttpRequest

Handle CancellationException in order to throw an HttpTimeoutException
when the timeout handler caused the cancellation.

See gh-34721

Signed-off-by: giampaolo <giampaorr@gmail.com>

fix: use timeoutHandler with a flag isTimeout

    Closes gh-33973

    Signed-off-by: giampaolo <giampaorr@gmail.com>
This commit is contained in:
giampaolo 2025-04-05 14:19:06 +02:00 committed by rstoyanchev
parent 5df9fd4eff
commit 7a55ce48a9
2 changed files with 136 additions and 2 deletions

View File

@ -37,6 +37,7 @@ import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executor;
import java.util.concurrent.Flow;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
@ -97,12 +98,13 @@ class JdkClientHttpRequest extends AbstractStreamingClientHttpRequest {
@SuppressWarnings("NullAway")
protected ClientHttpResponse executeInternal(HttpHeaders headers, @Nullable Body body) throws IOException {
CompletableFuture<HttpResponse<InputStream>> responseFuture = null;
TimeoutHandler timeoutHandler = null;
try {
HttpRequest request = buildRequest(headers, body);
responseFuture = this.httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofInputStream());
if (this.timeout != null) {
TimeoutHandler timeoutHandler = new TimeoutHandler(responseFuture, this.timeout);
timeoutHandler = new TimeoutHandler(responseFuture, this.timeout);
HttpResponse<InputStream> response = responseFuture.get();
InputStream inputStream = timeoutHandler.wrapInputStream(response);
return new JdkClientHttpResponse(response, inputStream);
@ -121,7 +123,10 @@ class JdkClientHttpRequest extends AbstractStreamingClientHttpRequest {
Throwable cause = ex.getCause();
if (cause instanceof CancellationException) {
throw new HttpTimeoutException("Request timed out");
if (timeoutHandler != null && timeoutHandler.isTimeout()) {
throw new HttpTimeoutException("Request timed out");
}
throw new IOException("Request was cancelled");
}
if (cause instanceof UncheckedIOException uioEx) {
throw uioEx.getCause();
@ -136,6 +141,12 @@ class JdkClientHttpRequest extends AbstractStreamingClientHttpRequest {
throw new IOException(cause.getMessage(), cause);
}
}
catch (CancellationException ex) {
if (timeoutHandler != null && timeoutHandler.isTimeout()) {
throw new HttpTimeoutException("Request timed out");
}
throw new IOException("Request was cancelled");
}
}
private HttpRequest buildRequest(HttpHeaders headers, @Nullable Body body) {
@ -233,6 +244,7 @@ class JdkClientHttpRequest extends AbstractStreamingClientHttpRequest {
private static final class TimeoutHandler {
private final CompletableFuture<Void> timeoutFuture;
private final AtomicBoolean isTimeout = new AtomicBoolean(false);
private TimeoutHandler(CompletableFuture<HttpResponse<InputStream>> future, Duration timeout) {
@ -241,6 +253,7 @@ class JdkClientHttpRequest extends AbstractStreamingClientHttpRequest {
this.timeoutFuture.thenRun(() -> {
if (future.cancel(true) || future.isCompletedExceptionally() || !future.isDone()) {
isTimeout.set(true);
return;
}
try {
@ -268,6 +281,10 @@ class JdkClientHttpRequest extends AbstractStreamingClientHttpRequest {
}
};
}
public boolean isTimeout() {
return isTimeout.get();
}
}
}

View File

@ -0,0 +1,117 @@
package org.springframework.http.client;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.Mockito.*;
import java.io.IOException;
import java.io.InputStream;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.net.http.HttpTimeoutException;
import java.time.Duration;
import java.util.concurrent.*;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
class JdkClientHttpRequestTest {
private HttpClient mockHttpClient;
private URI uri = URI.create("http://example.com");
private HttpMethod method = HttpMethod.GET;
private ExecutorService executor;
@BeforeEach
void setup() {
mockHttpClient = mock(HttpClient.class);
executor = Executors.newSingleThreadExecutor();
}
@AfterEach
void tearDown() {
executor.shutdownNow();
}
@Test
void executeInternal_withTimeout_shouldThrowHttpTimeoutException() throws Exception {
Duration timeout = Duration.ofMillis(10);
JdkClientHttpRequest request = new JdkClientHttpRequest(mockHttpClient, uri, method, executor, timeout);
CompletableFuture<HttpResponse<InputStream>> future = new CompletableFuture<>();
when(mockHttpClient.sendAsync(any(HttpRequest.class), any(HttpResponse.BodyHandler.class)))
.thenReturn(future);
HttpHeaders headers = new HttpHeaders();
CountDownLatch startLatch = new CountDownLatch(1);
// Cancellation thread waits for startLatch, then cancels the future after a delay
Thread canceller = new Thread(() -> {
try {
startLatch.await();
Thread.sleep(500);
future.cancel(true);
} catch (InterruptedException ignored) {
}
});
canceller.start();
IOException ex = assertThrows(IOException.class, () -> {
startLatch.countDown();
request.executeInternal(headers, null);
});
assertThat(ex)
.isInstanceOf(HttpTimeoutException.class)
.hasMessage("Request timed out");
canceller.join();
}
@Test
void executeInternal_withTimeout_shouldThrowIOException() throws Exception {
Duration timeout = Duration.ofMillis(500);
JdkClientHttpRequest request = new JdkClientHttpRequest(mockHttpClient, uri, method, executor, timeout);
CompletableFuture<HttpResponse<InputStream>> future = new CompletableFuture<>();
when(mockHttpClient.sendAsync(any(HttpRequest.class), any(HttpResponse.BodyHandler.class)))
.thenReturn(future);
HttpHeaders headers = new HttpHeaders();
CountDownLatch startLatch = new CountDownLatch(1);
Thread canceller = new Thread(() -> {
try {
startLatch.await();
Thread.sleep(10);
future.cancel(true);
} catch (InterruptedException ignored) {
}
});
canceller.start();
IOException ex = assertThrows(IOException.class, () -> {
startLatch.countDown();
request.executeInternal(headers, null);
});
assertThat(ex)
.isInstanceOf(IOException.class)
.hasMessage("Request was cancelled");
canceller.join();
}
}