Support multiple CoWebFilter changing the context
This commit ensures CoWebFilter merges the exchange CoroutineContext with the filter one if needed. Closes gh-31792
This commit is contained in:
parent
e2c2268c39
commit
d75a7c3818
|
@ -22,6 +22,7 @@ import kotlinx.coroutines.currentCoroutineContext
|
|||
import kotlinx.coroutines.reactor.awaitSingleOrNull
|
||||
import kotlinx.coroutines.reactor.mono
|
||||
import reactor.core.publisher.Mono
|
||||
import kotlin.coroutines.CoroutineContext
|
||||
|
||||
/**
|
||||
* Kotlin-specific implementation of the [WebFilter] interface that allows for
|
||||
|
@ -34,7 +35,8 @@ import reactor.core.publisher.Mono
|
|||
abstract class CoWebFilter : WebFilter {
|
||||
|
||||
final override fun filter(exchange: ServerWebExchange, chain: WebFilterChain): Mono<Void> {
|
||||
return mono(Dispatchers.Unconfined) {
|
||||
val context = exchange.attributes[COROUTINE_CONTEXT_ATTRIBUTE] as CoroutineContext?
|
||||
return mono(context ?: Dispatchers.Unconfined) {
|
||||
filter(exchange, object : CoWebFilterChain {
|
||||
override suspend fun filter(exchange: ServerWebExchange) {
|
||||
exchange.attributes[COROUTINE_CONTEXT_ATTRIBUTE] = currentCoroutineContext().minusKey(Job.Key)
|
||||
|
|
|
@ -26,6 +26,7 @@ import org.springframework.web.testfixture.http.server.reactive.MockServerHttpRe
|
|||
import org.springframework.web.testfixture.server.MockServerWebExchange
|
||||
import reactor.core.publisher.Mono
|
||||
import reactor.test.StepVerifier
|
||||
import kotlin.coroutines.AbstractCoroutineContextElement
|
||||
import kotlin.coroutines.CoroutineContext
|
||||
|
||||
/**
|
||||
|
@ -44,12 +45,26 @@ class CoWebFilterTests {
|
|||
val filter = MyCoWebFilter()
|
||||
val result = filter.filter(exchange, chain)
|
||||
|
||||
StepVerifier.create(result)
|
||||
.verifyComplete()
|
||||
StepVerifier.create(result).verifyComplete()
|
||||
|
||||
assertThat(exchange.attributes["foo"]).isEqualTo("bar")
|
||||
}
|
||||
|
||||
@Test
|
||||
fun multipleFilters() {
|
||||
val exchange = MockServerWebExchange.from(MockServerHttpRequest.get("https://example.com"))
|
||||
|
||||
val chain = Mockito.mock(WebFilterChain::class.java)
|
||||
given(chain.filter(exchange)).willAnswer { MyOtherCoWebFilter().filter(exchange,chain) }.willReturn(Mono.empty())
|
||||
|
||||
val result = MyCoWebFilter().filter(exchange, chain)
|
||||
|
||||
StepVerifier.create(result).verifyComplete()
|
||||
|
||||
assertThat(exchange.attributes["foo"]).isEqualTo("bar")
|
||||
assertThat(exchange.attributes["foofoo"]).isEqualTo("barbar")
|
||||
}
|
||||
|
||||
@Test
|
||||
fun filterWithContext() {
|
||||
val exchange = MockServerWebExchange.from(MockServerHttpRequest.get("https://example.com"))
|
||||
|
@ -69,6 +84,28 @@ class CoWebFilterTests {
|
|||
assertThat(coroutineName.name).isEqualTo("foo")
|
||||
}
|
||||
|
||||
@Test
|
||||
fun multipleFiltersWithContext() {
|
||||
val exchange = MockServerWebExchange.from(MockServerHttpRequest.get("https://example.com"))
|
||||
|
||||
val chain = Mockito.mock(WebFilterChain::class.java)
|
||||
given(chain.filter(exchange)).willAnswer { MyOtherCoWebFilterWithContext().filter(exchange,chain) }.willReturn(Mono.empty())
|
||||
|
||||
val filter = MyCoWebFilterWithContext()
|
||||
val result = filter.filter(exchange, chain)
|
||||
|
||||
StepVerifier.create(result).verifyComplete()
|
||||
|
||||
val context = exchange.attributes[CoWebFilter.COROUTINE_CONTEXT_ATTRIBUTE] as CoroutineContext
|
||||
assertThat(context).isNotNull()
|
||||
val coroutineName = context[CoroutineName.Key] as CoroutineName
|
||||
assertThat(coroutineName).isNotNull()
|
||||
assertThat(coroutineName.name).isEqualTo("foo")
|
||||
val coroutineDescription = context[CoroutineDescription.Key] as CoroutineDescription
|
||||
assertThat(coroutineDescription).isNotNull()
|
||||
assertThat(coroutineDescription.description).isEqualTo("foofoo")
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
@ -79,6 +116,13 @@ private class MyCoWebFilter : CoWebFilter() {
|
|||
}
|
||||
}
|
||||
|
||||
private class MyOtherCoWebFilter : CoWebFilter() {
|
||||
override suspend fun filter(exchange: ServerWebExchange, chain: CoWebFilterChain) {
|
||||
exchange.attributes["foofoo"] = "barbar"
|
||||
chain.filter(exchange)
|
||||
}
|
||||
}
|
||||
|
||||
private class MyCoWebFilterWithContext : CoWebFilter() {
|
||||
override suspend fun filter(exchange: ServerWebExchange, chain: CoWebFilterChain) {
|
||||
withContext(CoroutineName("foo")) {
|
||||
|
@ -86,3 +130,19 @@ private class MyCoWebFilterWithContext : CoWebFilter() {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
private class MyOtherCoWebFilterWithContext : CoWebFilter() {
|
||||
override suspend fun filter(exchange: ServerWebExchange, chain: CoWebFilterChain) {
|
||||
withContext(CoroutineDescription("foofoo")) {
|
||||
chain.filter(exchange)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
data class CoroutineDescription(val description: String) : AbstractCoroutineContextElement(CoroutineDescription) {
|
||||
|
||||
companion object Key : CoroutineContext.Key<CoroutineDescription>
|
||||
|
||||
override fun toString(): String = "CoroutineDescription($description)"
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue