From d75a7c38180b217b19bfa890bc4fd4e63f4416b0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Deleuze?= Date: Mon, 11 Dec 2023 11:22:36 +0100 Subject: [PATCH] Support multiple CoWebFilter changing the context This commit ensures CoWebFilter merges the exchange CoroutineContext with the filter one if needed. Closes gh-31792 --- .../springframework/web/server/CoWebFilter.kt | 4 +- .../web/server/CoWebFilterTests.kt | 64 ++++++++++++++++++- 2 files changed, 65 insertions(+), 3 deletions(-) diff --git a/spring-web/src/main/kotlin/org/springframework/web/server/CoWebFilter.kt b/spring-web/src/main/kotlin/org/springframework/web/server/CoWebFilter.kt index 5ff3ecae1c..a708e3c74a 100644 --- a/spring-web/src/main/kotlin/org/springframework/web/server/CoWebFilter.kt +++ b/spring-web/src/main/kotlin/org/springframework/web/server/CoWebFilter.kt @@ -22,6 +22,7 @@ import kotlinx.coroutines.currentCoroutineContext import kotlinx.coroutines.reactor.awaitSingleOrNull import kotlinx.coroutines.reactor.mono import reactor.core.publisher.Mono +import kotlin.coroutines.CoroutineContext /** * Kotlin-specific implementation of the [WebFilter] interface that allows for @@ -34,7 +35,8 @@ import reactor.core.publisher.Mono abstract class CoWebFilter : WebFilter { final override fun filter(exchange: ServerWebExchange, chain: WebFilterChain): Mono { - return mono(Dispatchers.Unconfined) { + val context = exchange.attributes[COROUTINE_CONTEXT_ATTRIBUTE] as CoroutineContext? + return mono(context ?: Dispatchers.Unconfined) { filter(exchange, object : CoWebFilterChain { override suspend fun filter(exchange: ServerWebExchange) { exchange.attributes[COROUTINE_CONTEXT_ATTRIBUTE] = currentCoroutineContext().minusKey(Job.Key) diff --git a/spring-web/src/test/kotlin/org/springframework/web/server/CoWebFilterTests.kt b/spring-web/src/test/kotlin/org/springframework/web/server/CoWebFilterTests.kt index 872f3a4bc0..e73c6efe6e 100644 --- a/spring-web/src/test/kotlin/org/springframework/web/server/CoWebFilterTests.kt +++ b/spring-web/src/test/kotlin/org/springframework/web/server/CoWebFilterTests.kt @@ -26,6 +26,7 @@ import org.springframework.web.testfixture.http.server.reactive.MockServerHttpRe import org.springframework.web.testfixture.server.MockServerWebExchange import reactor.core.publisher.Mono import reactor.test.StepVerifier +import kotlin.coroutines.AbstractCoroutineContextElement import kotlin.coroutines.CoroutineContext /** @@ -44,12 +45,26 @@ class CoWebFilterTests { val filter = MyCoWebFilter() val result = filter.filter(exchange, chain) - StepVerifier.create(result) - .verifyComplete() + StepVerifier.create(result).verifyComplete() assertThat(exchange.attributes["foo"]).isEqualTo("bar") } + @Test + fun multipleFilters() { + val exchange = MockServerWebExchange.from(MockServerHttpRequest.get("https://example.com")) + + val chain = Mockito.mock(WebFilterChain::class.java) + given(chain.filter(exchange)).willAnswer { MyOtherCoWebFilter().filter(exchange,chain) }.willReturn(Mono.empty()) + + val result = MyCoWebFilter().filter(exchange, chain) + + StepVerifier.create(result).verifyComplete() + + assertThat(exchange.attributes["foo"]).isEqualTo("bar") + assertThat(exchange.attributes["foofoo"]).isEqualTo("barbar") + } + @Test fun filterWithContext() { val exchange = MockServerWebExchange.from(MockServerHttpRequest.get("https://example.com")) @@ -69,6 +84,28 @@ class CoWebFilterTests { assertThat(coroutineName.name).isEqualTo("foo") } + @Test + fun multipleFiltersWithContext() { + val exchange = MockServerWebExchange.from(MockServerHttpRequest.get("https://example.com")) + + val chain = Mockito.mock(WebFilterChain::class.java) + given(chain.filter(exchange)).willAnswer { MyOtherCoWebFilterWithContext().filter(exchange,chain) }.willReturn(Mono.empty()) + + val filter = MyCoWebFilterWithContext() + val result = filter.filter(exchange, chain) + + StepVerifier.create(result).verifyComplete() + + val context = exchange.attributes[CoWebFilter.COROUTINE_CONTEXT_ATTRIBUTE] as CoroutineContext + assertThat(context).isNotNull() + val coroutineName = context[CoroutineName.Key] as CoroutineName + assertThat(coroutineName).isNotNull() + assertThat(coroutineName.name).isEqualTo("foo") + val coroutineDescription = context[CoroutineDescription.Key] as CoroutineDescription + assertThat(coroutineDescription).isNotNull() + assertThat(coroutineDescription.description).isEqualTo("foofoo") + } + } @@ -79,6 +116,13 @@ private class MyCoWebFilter : CoWebFilter() { } } +private class MyOtherCoWebFilter : CoWebFilter() { + override suspend fun filter(exchange: ServerWebExchange, chain: CoWebFilterChain) { + exchange.attributes["foofoo"] = "barbar" + chain.filter(exchange) + } +} + private class MyCoWebFilterWithContext : CoWebFilter() { override suspend fun filter(exchange: ServerWebExchange, chain: CoWebFilterChain) { withContext(CoroutineName("foo")) { @@ -86,3 +130,19 @@ private class MyCoWebFilterWithContext : CoWebFilter() { } } } + +private class MyOtherCoWebFilterWithContext : CoWebFilter() { + override suspend fun filter(exchange: ServerWebExchange, chain: CoWebFilterChain) { + withContext(CoroutineDescription("foofoo")) { + chain.filter(exchange) + } + } +} + +data class CoroutineDescription(val description: String) : AbstractCoroutineContextElement(CoroutineDescription) { + + companion object Key : CoroutineContext.Key + + override fun toString(): String = "CoroutineDescription($description)" +} +