diff --git a/spring-web/src/main/java/org/springframework/web/client/DefaultRestClient.java b/spring-web/src/main/java/org/springframework/web/client/DefaultRestClient.java index e59824e329f..9bf267c4214 100644 --- a/spring-web/src/main/java/org/springframework/web/client/DefaultRestClient.java +++ b/spring-web/src/main/java/org/springframework/web/client/DefaultRestClient.java @@ -520,6 +520,7 @@ final class DefaultRestClient implements RestClient { ClientHttpResponse clientResponse = null; Observation observation = null; + Observation.Scope observationScope = null; URI uri = null; try { uri = initUri(); @@ -532,6 +533,7 @@ final class DefaultRestClient implements RestClient { observationContext.setUriTemplate((String) attributes.get(URI_TEMPLATE_ATTRIBUTE)); observation = ClientHttpObservationDocumentation.HTTP_CLIENT_EXCHANGES.observation(observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, observationRegistry).start(); + observationScope = observation.openScope(); if (this.body != null) { this.body.writeTo(clientRequest); } @@ -540,11 +542,14 @@ final class DefaultRestClient implements RestClient { } clientResponse = clientRequest.execute(); observationContext.setResponse(clientResponse); - ConvertibleClientHttpResponse convertibleWrapper = new DefaultConvertibleClientHttpResponse(clientResponse, observation); + ConvertibleClientHttpResponse convertibleWrapper = new DefaultConvertibleClientHttpResponse(clientResponse, observation, observationScope); return exchangeFunction.exchange(clientRequest, convertibleWrapper); } catch (IOException ex) { ResourceAccessException resourceAccessException = createResourceAccessException(uri, this.httpMethod, ex); + if (observationScope != null) { + observationScope.close(); + } if (observation != null) { observation.error(resourceAccessException); observation.stop(); @@ -552,6 +557,9 @@ final class DefaultRestClient implements RestClient { throw resourceAccessException; } catch (Throwable error) { + if (observationScope != null) { + observationScope.close(); + } if (observation != null) { observation.error(error); observation.stop(); @@ -561,6 +569,9 @@ final class DefaultRestClient implements RestClient { finally { if (close && clientResponse != null) { clientResponse.close(); + if (observationScope != null) { + observationScope.close(); + } if (observation != null) { observation.stop(); } @@ -771,10 +782,12 @@ final class DefaultRestClient implements RestClient { private final Observation observation; + private final Observation.Scope observationScope; - public DefaultConvertibleClientHttpResponse(ClientHttpResponse delegate, Observation observation) { + public DefaultConvertibleClientHttpResponse(ClientHttpResponse delegate, Observation observation, Observation.Scope observationScope) { this.delegate = delegate; this.observation = observation; + this.observationScope = observationScope; } @@ -815,6 +828,7 @@ final class DefaultRestClient implements RestClient { @Override public void close() { this.delegate.close(); + this.observationScope.close(); this.observation.stop(); } diff --git a/spring-web/src/test/java/org/springframework/web/client/RestClientObservationTests.java b/spring-web/src/test/java/org/springframework/web/client/RestClientObservationTests.java index 98c4b136f9a..8181151af9f 100644 --- a/spring-web/src/test/java/org/springframework/web/client/RestClientObservationTests.java +++ b/spring-web/src/test/java/org/springframework/web/client/RestClientObservationTests.java @@ -27,15 +27,19 @@ import io.micrometer.observation.Observation; import io.micrometer.observation.ObservationHandler; import io.micrometer.observation.tck.TestObservationRegistry; import io.micrometer.observation.tck.TestObservationRegistryAssert; +import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; +import org.springframework.http.HttpRequest; import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; import org.springframework.http.client.ClientHttpRequest; +import org.springframework.http.client.ClientHttpRequestExecution; import org.springframework.http.client.ClientHttpRequestFactory; +import org.springframework.http.client.ClientHttpRequestInterceptor; import org.springframework.http.client.ClientHttpResponse; import org.springframework.http.client.observation.ClientRequestObservationContext; import org.springframework.http.client.observation.ClientRequestObservationConvention; @@ -73,12 +77,15 @@ class RestClientObservationTests { @BeforeEach void setupEach() { - this.client = RestClient.builder() + this.client = createBuilder().build(); + this.observationRegistry.observationConfig().observationHandler(new ContextAssertionObservationHandler()); + } + + RestClient.Builder createBuilder() { + return RestClient.builder() .messageConverters(converters -> converters.add(0, this.converter)) .requestFactory(this.requestFactory) - .observationRegistry(this.observationRegistry) - .build(); - this.observationRegistry.observationConfig().observationHandler(new ContextAssertionObservationHandler()); + .observationRegistry(this.observationRegistry); } @Test @@ -238,6 +245,22 @@ class RestClientObservationTests { assertThatHttpObservation().hasLowCardinalityKeyValue("outcome", "SUCCESS"); } + @Test + void openScopeWithObservation() throws Exception { + this.client = createBuilder().requestInterceptor(new ObservationContextInterceptor(this.observationRegistry)) + .defaultStatusHandler(new ObservationErrorHandler(this.observationRegistry)).build(); + mockSentRequest(GET, "https://example.org"); + mockResponseStatus(HttpStatus.OK); + mockResponseBody("Hello World", MediaType.TEXT_PLAIN); + + client.get().uri("https://example.org").retrieve().toBodilessEntity(); + } + + @AfterEach + void checkAfter() { + assertThat(this.observationRegistry.getCurrentObservationScope()).isNull(); + } + private void mockSentRequest(HttpMethod method, String uri) throws Exception { mockSentRequest(method, uri, new HttpHeaders()); @@ -288,4 +311,38 @@ class RestClientObservationTests { } + static class ObservationContextInterceptor implements ClientHttpRequestInterceptor { + + private final TestObservationRegistry observationRegistry; + + public ObservationContextInterceptor(TestObservationRegistry observationRegistry) { + this.observationRegistry = observationRegistry; + } + + @Override + public ClientHttpResponse intercept(HttpRequest request, byte[] body, ClientHttpRequestExecution execution) throws IOException { + assertThat(this.observationRegistry.getCurrentObservationScope()).isNotNull(); + return execution.execute(request, body); + } + } + + static class ObservationErrorHandler implements ResponseErrorHandler { + + final TestObservationRegistry observationRegistry; + + ObservationErrorHandler(TestObservationRegistry observationRegistry) { + this.observationRegistry = observationRegistry; + } + + @Override + public boolean hasError(ClientHttpResponse response) throws IOException { + return true; + } + + @Override + public void handleError(ClientHttpResponse response) throws IOException { + assertThat(this.observationRegistry.getCurrentObservationScope()).isNotNull(); + } + } + }