Support multiple CoWebFilter changing the context

This commit ensures CoWebFilter merges the exchange
CoroutineContext with the filter one if needed.

Closes gh-31792
This commit is contained in:
Sébastien Deleuze 2023-12-11 11:22:36 +01:00
parent e2c2268c39
commit d75a7c3818
2 changed files with 65 additions and 3 deletions

View File

@ -22,6 +22,7 @@ import kotlinx.coroutines.currentCoroutineContext
import kotlinx.coroutines.reactor.awaitSingleOrNull
import kotlinx.coroutines.reactor.mono
import reactor.core.publisher.Mono
import kotlin.coroutines.CoroutineContext
/**
* Kotlin-specific implementation of the [WebFilter] interface that allows for
@ -34,7 +35,8 @@ import reactor.core.publisher.Mono
abstract class CoWebFilter : WebFilter {
final override fun filter(exchange: ServerWebExchange, chain: WebFilterChain): Mono<Void> {
return mono(Dispatchers.Unconfined) {
val context = exchange.attributes[COROUTINE_CONTEXT_ATTRIBUTE] as CoroutineContext?
return mono(context ?: Dispatchers.Unconfined) {
filter(exchange, object : CoWebFilterChain {
override suspend fun filter(exchange: ServerWebExchange) {
exchange.attributes[COROUTINE_CONTEXT_ATTRIBUTE] = currentCoroutineContext().minusKey(Job.Key)

View File

@ -26,6 +26,7 @@ import org.springframework.web.testfixture.http.server.reactive.MockServerHttpRe
import org.springframework.web.testfixture.server.MockServerWebExchange
import reactor.core.publisher.Mono
import reactor.test.StepVerifier
import kotlin.coroutines.AbstractCoroutineContextElement
import kotlin.coroutines.CoroutineContext
/**
@ -44,12 +45,26 @@ class CoWebFilterTests {
val filter = MyCoWebFilter()
val result = filter.filter(exchange, chain)
StepVerifier.create(result)
.verifyComplete()
StepVerifier.create(result).verifyComplete()
assertThat(exchange.attributes["foo"]).isEqualTo("bar")
}
@Test
fun multipleFilters() {
val exchange = MockServerWebExchange.from(MockServerHttpRequest.get("https://example.com"))
val chain = Mockito.mock(WebFilterChain::class.java)
given(chain.filter(exchange)).willAnswer { MyOtherCoWebFilter().filter(exchange,chain) }.willReturn(Mono.empty())
val result = MyCoWebFilter().filter(exchange, chain)
StepVerifier.create(result).verifyComplete()
assertThat(exchange.attributes["foo"]).isEqualTo("bar")
assertThat(exchange.attributes["foofoo"]).isEqualTo("barbar")
}
@Test
fun filterWithContext() {
val exchange = MockServerWebExchange.from(MockServerHttpRequest.get("https://example.com"))
@ -69,6 +84,28 @@ class CoWebFilterTests {
assertThat(coroutineName.name).isEqualTo("foo")
}
@Test
fun multipleFiltersWithContext() {
val exchange = MockServerWebExchange.from(MockServerHttpRequest.get("https://example.com"))
val chain = Mockito.mock(WebFilterChain::class.java)
given(chain.filter(exchange)).willAnswer { MyOtherCoWebFilterWithContext().filter(exchange,chain) }.willReturn(Mono.empty())
val filter = MyCoWebFilterWithContext()
val result = filter.filter(exchange, chain)
StepVerifier.create(result).verifyComplete()
val context = exchange.attributes[CoWebFilter.COROUTINE_CONTEXT_ATTRIBUTE] as CoroutineContext
assertThat(context).isNotNull()
val coroutineName = context[CoroutineName.Key] as CoroutineName
assertThat(coroutineName).isNotNull()
assertThat(coroutineName.name).isEqualTo("foo")
val coroutineDescription = context[CoroutineDescription.Key] as CoroutineDescription
assertThat(coroutineDescription).isNotNull()
assertThat(coroutineDescription.description).isEqualTo("foofoo")
}
}
@ -79,6 +116,13 @@ private class MyCoWebFilter : CoWebFilter() {
}
}
private class MyOtherCoWebFilter : CoWebFilter() {
override suspend fun filter(exchange: ServerWebExchange, chain: CoWebFilterChain) {
exchange.attributes["foofoo"] = "barbar"
chain.filter(exchange)
}
}
private class MyCoWebFilterWithContext : CoWebFilter() {
override suspend fun filter(exchange: ServerWebExchange, chain: CoWebFilterChain) {
withContext(CoroutineName("foo")) {
@ -86,3 +130,19 @@ private class MyCoWebFilterWithContext : CoWebFilter() {
}
}
}
private class MyOtherCoWebFilterWithContext : CoWebFilter() {
override suspend fun filter(exchange: ServerWebExchange, chain: CoWebFilterChain) {
withContext(CoroutineDescription("foofoo")) {
chain.filter(exchange)
}
}
}
data class CoroutineDescription(val description: String) : AbstractCoroutineContextElement(CoroutineDescription) {
companion object Key : CoroutineContext.Key<CoroutineDescription>
override fun toString(): String = "CoroutineDescription($description)"
}