diff --git a/spring-webflux/spring-webflux.gradle b/spring-webflux/spring-webflux.gradle index 93c7ee5b529..e05cd2b9461 100644 --- a/spring-webflux/spring-webflux.gradle +++ b/spring-webflux/spring-webflux.gradle @@ -40,6 +40,7 @@ dependencies { testImplementation(testFixtures(project(":spring-web"))) testImplementation("com.fasterxml:aalto-xml") testImplementation("com.squareup.okhttp3:mockwebserver") + testImplementation("io.micrometer:context-propagation") testImplementation("io.micrometer:micrometer-observation-test") testImplementation("io.projectreactor:reactor-test") testImplementation("io.reactivex.rxjava3:rxjava") diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/DefaultWebClient.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/DefaultWebClient.java index 01ea08c5098..443ba3018f9 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/DefaultWebClient.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/DefaultWebClient.java @@ -463,8 +463,9 @@ final class DefaultWebClient implements WebClient { ClientRequest request = requestBuilder.build(); observationContext.setUriTemplate((String) request.attribute(URI_TEMPLATE_ATTRIBUTE).orElse(null)); observationContext.setRequest(request); - Mono responseMono = filterFunction.apply(exchangeFunction) - .exchange(request) + final ExchangeFilterFunction finalFilterFunction = filterFunction; + Mono responseMono = Mono.defer( + () -> finalFilterFunction.apply(exchangeFunction).exchange(request)) .checkpoint("Request to " + WebClientUtils.getRequestDescription(request.method(), request.url()) + " [DefaultWebClient]") diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/function/client/WebClientObservationTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/function/client/WebClientObservationTests.java index 73a35047e14..ad662822f53 100644 --- a/spring-webflux/src/test/java/org/springframework/web/reactive/function/client/WebClientObservationTests.java +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/function/client/WebClientObservationTests.java @@ -27,10 +27,12 @@ import io.micrometer.observation.ObservationHandler; import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; import io.micrometer.observation.tck.TestObservationRegistry; import io.micrometer.observation.tck.TestObservationRegistryAssert; +import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.mockito.ArgumentCaptor; import reactor.core.publisher.Flux; +import reactor.core.publisher.Hooks; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; @@ -63,6 +65,7 @@ class WebClientObservationTests { @BeforeEach void setup() { + Hooks.enableAutomaticContextPropagation(); ClientResponse mockResponse = mock(); when(mockResponse.statusCode()).thenReturn(HttpStatus.OK); when(mockResponse.headers()).thenReturn(new MockClientHeaders()); @@ -74,6 +77,11 @@ class WebClientObservationTests { this.observationRegistry.observationConfig().observationHandler(new HeaderInjectingHandler()); } + @AfterEach + void cleanUp() { + Hooks.disableAutomaticContextPropagation(); + } + @Test void recordsObservationForSuccessfulExchange() { this.builder.build().get().uri("/resource/{id}", 42) @@ -148,6 +156,19 @@ class WebClientObservationTests { verifyAndGetRequest(); } + @Test + void setsCurrentObservationInScope() { + ExchangeFilterFunction assertionFilter = (request, chain) -> { + Observation currentObservation = observationRegistry.getCurrentObservation(); + assertThat(currentObservation).isNotNull(); + assertThat(currentObservation.getContext()).isInstanceOf(ClientRequestObservationContext.class); + return chain.exchange(request); + }; + this.builder.filter(assertionFilter).build().get().uri("/resource/{id}", 42) + .retrieve().bodyToMono(Void.class).block(Duration.ofSeconds(5)); + verifyAndGetRequest(); + } + @Test void recordsObservationWithResponseDetailsWhenFilterFunctionErrors() { ExchangeFilterFunction errorFunction = (req, next) -> next.exchange(req).then(Mono.error(new IllegalStateException()));