Support multiple CoWebFilter changing the context
This commit ensures CoWebFilter merges the exchange CoroutineContext with the filter one if needed. Closes gh-31792
This commit is contained in:
parent
e2c2268c39
commit
d75a7c3818
|
@ -22,6 +22,7 @@ import kotlinx.coroutines.currentCoroutineContext
|
||||||
import kotlinx.coroutines.reactor.awaitSingleOrNull
|
import kotlinx.coroutines.reactor.awaitSingleOrNull
|
||||||
import kotlinx.coroutines.reactor.mono
|
import kotlinx.coroutines.reactor.mono
|
||||||
import reactor.core.publisher.Mono
|
import reactor.core.publisher.Mono
|
||||||
|
import kotlin.coroutines.CoroutineContext
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Kotlin-specific implementation of the [WebFilter] interface that allows for
|
* Kotlin-specific implementation of the [WebFilter] interface that allows for
|
||||||
|
@ -34,7 +35,8 @@ import reactor.core.publisher.Mono
|
||||||
abstract class CoWebFilter : WebFilter {
|
abstract class CoWebFilter : WebFilter {
|
||||||
|
|
||||||
final override fun filter(exchange: ServerWebExchange, chain: WebFilterChain): Mono<Void> {
|
final override fun filter(exchange: ServerWebExchange, chain: WebFilterChain): Mono<Void> {
|
||||||
return mono(Dispatchers.Unconfined) {
|
val context = exchange.attributes[COROUTINE_CONTEXT_ATTRIBUTE] as CoroutineContext?
|
||||||
|
return mono(context ?: Dispatchers.Unconfined) {
|
||||||
filter(exchange, object : CoWebFilterChain {
|
filter(exchange, object : CoWebFilterChain {
|
||||||
override suspend fun filter(exchange: ServerWebExchange) {
|
override suspend fun filter(exchange: ServerWebExchange) {
|
||||||
exchange.attributes[COROUTINE_CONTEXT_ATTRIBUTE] = currentCoroutineContext().minusKey(Job.Key)
|
exchange.attributes[COROUTINE_CONTEXT_ATTRIBUTE] = currentCoroutineContext().minusKey(Job.Key)
|
||||||
|
|
|
@ -26,6 +26,7 @@ import org.springframework.web.testfixture.http.server.reactive.MockServerHttpRe
|
||||||
import org.springframework.web.testfixture.server.MockServerWebExchange
|
import org.springframework.web.testfixture.server.MockServerWebExchange
|
||||||
import reactor.core.publisher.Mono
|
import reactor.core.publisher.Mono
|
||||||
import reactor.test.StepVerifier
|
import reactor.test.StepVerifier
|
||||||
|
import kotlin.coroutines.AbstractCoroutineContextElement
|
||||||
import kotlin.coroutines.CoroutineContext
|
import kotlin.coroutines.CoroutineContext
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -44,12 +45,26 @@ class CoWebFilterTests {
|
||||||
val filter = MyCoWebFilter()
|
val filter = MyCoWebFilter()
|
||||||
val result = filter.filter(exchange, chain)
|
val result = filter.filter(exchange, chain)
|
||||||
|
|
||||||
StepVerifier.create(result)
|
StepVerifier.create(result).verifyComplete()
|
||||||
.verifyComplete()
|
|
||||||
|
|
||||||
assertThat(exchange.attributes["foo"]).isEqualTo("bar")
|
assertThat(exchange.attributes["foo"]).isEqualTo("bar")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun multipleFilters() {
|
||||||
|
val exchange = MockServerWebExchange.from(MockServerHttpRequest.get("https://example.com"))
|
||||||
|
|
||||||
|
val chain = Mockito.mock(WebFilterChain::class.java)
|
||||||
|
given(chain.filter(exchange)).willAnswer { MyOtherCoWebFilter().filter(exchange,chain) }.willReturn(Mono.empty())
|
||||||
|
|
||||||
|
val result = MyCoWebFilter().filter(exchange, chain)
|
||||||
|
|
||||||
|
StepVerifier.create(result).verifyComplete()
|
||||||
|
|
||||||
|
assertThat(exchange.attributes["foo"]).isEqualTo("bar")
|
||||||
|
assertThat(exchange.attributes["foofoo"]).isEqualTo("barbar")
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun filterWithContext() {
|
fun filterWithContext() {
|
||||||
val exchange = MockServerWebExchange.from(MockServerHttpRequest.get("https://example.com"))
|
val exchange = MockServerWebExchange.from(MockServerHttpRequest.get("https://example.com"))
|
||||||
|
@ -69,6 +84,28 @@ class CoWebFilterTests {
|
||||||
assertThat(coroutineName.name).isEqualTo("foo")
|
assertThat(coroutineName.name).isEqualTo("foo")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun multipleFiltersWithContext() {
|
||||||
|
val exchange = MockServerWebExchange.from(MockServerHttpRequest.get("https://example.com"))
|
||||||
|
|
||||||
|
val chain = Mockito.mock(WebFilterChain::class.java)
|
||||||
|
given(chain.filter(exchange)).willAnswer { MyOtherCoWebFilterWithContext().filter(exchange,chain) }.willReturn(Mono.empty())
|
||||||
|
|
||||||
|
val filter = MyCoWebFilterWithContext()
|
||||||
|
val result = filter.filter(exchange, chain)
|
||||||
|
|
||||||
|
StepVerifier.create(result).verifyComplete()
|
||||||
|
|
||||||
|
val context = exchange.attributes[CoWebFilter.COROUTINE_CONTEXT_ATTRIBUTE] as CoroutineContext
|
||||||
|
assertThat(context).isNotNull()
|
||||||
|
val coroutineName = context[CoroutineName.Key] as CoroutineName
|
||||||
|
assertThat(coroutineName).isNotNull()
|
||||||
|
assertThat(coroutineName.name).isEqualTo("foo")
|
||||||
|
val coroutineDescription = context[CoroutineDescription.Key] as CoroutineDescription
|
||||||
|
assertThat(coroutineDescription).isNotNull()
|
||||||
|
assertThat(coroutineDescription.description).isEqualTo("foofoo")
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -79,6 +116,13 @@ private class MyCoWebFilter : CoWebFilter() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private class MyOtherCoWebFilter : CoWebFilter() {
|
||||||
|
override suspend fun filter(exchange: ServerWebExchange, chain: CoWebFilterChain) {
|
||||||
|
exchange.attributes["foofoo"] = "barbar"
|
||||||
|
chain.filter(exchange)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private class MyCoWebFilterWithContext : CoWebFilter() {
|
private class MyCoWebFilterWithContext : CoWebFilter() {
|
||||||
override suspend fun filter(exchange: ServerWebExchange, chain: CoWebFilterChain) {
|
override suspend fun filter(exchange: ServerWebExchange, chain: CoWebFilterChain) {
|
||||||
withContext(CoroutineName("foo")) {
|
withContext(CoroutineName("foo")) {
|
||||||
|
@ -86,3 +130,19 @@ private class MyCoWebFilterWithContext : CoWebFilter() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private class MyOtherCoWebFilterWithContext : CoWebFilter() {
|
||||||
|
override suspend fun filter(exchange: ServerWebExchange, chain: CoWebFilterChain) {
|
||||||
|
withContext(CoroutineDescription("foofoo")) {
|
||||||
|
chain.filter(exchange)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
data class CoroutineDescription(val description: String) : AbstractCoroutineContextElement(CoroutineDescription) {
|
||||||
|
|
||||||
|
companion object Key : CoroutineContext.Key<CoroutineDescription>
|
||||||
|
|
||||||
|
override fun toString(): String = "CoroutineDescription($description)"
|
||||||
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue