Pickup CoroutineContext saved by CoWebFilter in coRouter

Closes gh-31793
This commit is contained in:
Sébastien Deleuze 2023-12-11 19:03:35 +01:00
parent 570074259d
commit aabe4d0b07
2 changed files with 32 additions and 1 deletions

View File

@ -27,6 +27,7 @@ import org.springframework.http.HttpMethod
import org.springframework.http.HttpStatusCode
import org.springframework.http.MediaType
import org.springframework.web.reactive.function.server.RouterFunctions.nest
import org.springframework.web.server.CoWebFilter
import reactor.core.publisher.Mono
import java.net.URI
import kotlin.coroutines.CoroutineContext
@ -731,7 +732,8 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct
) : HandlerFunction<T> {
override fun handle(request: ServerRequest): Mono<T> {
return handle(Dispatchers.Unconfined, request)
val context = request.attributes()[CoWebFilter.COROUTINE_CONTEXT_ATTRIBUTE] as CoroutineContext?
return handle(context ?: Dispatchers.Unconfined, request)
}
fun handle(context: CoroutineContext, request: ServerRequest) = asMono(request, context) {

View File

@ -25,7 +25,11 @@ import org.springframework.http.HttpHeaders.CONTENT_TYPE
import org.springframework.http.HttpMethod.PATCH
import org.springframework.http.HttpStatus
import org.springframework.http.MediaType.*
import org.springframework.web.server.CoWebFilter
import org.springframework.web.server.CoWebFilterChain
import org.springframework.web.server.ServerWebExchange
import org.springframework.web.testfixture.http.server.reactive.MockServerHttpRequest.*
import org.springframework.web.testfixture.http.server.reactive.MockServerHttpResponse
import org.springframework.web.testfixture.server.MockServerWebExchange
import reactor.test.StepVerifier
@ -204,6 +208,16 @@ class CoRouterFunctionDslTests {
.verifyComplete()
}
@Test
fun webFilterAndContext() {
val strategies = HandlerStrategies.builder().webFilter(MyCoWebFilterWithContext()).build()
val httpHandler = RouterFunctions.toHttpHandler(routerWithoutContext, strategies)
val mockRequest = get("https://example.com/").build()
val mockResponse = MockServerHttpResponse()
StepVerifier.create(httpHandler.handle(mockRequest, mockResponse)).verifyComplete()
assertThat(mockResponse.headers.getFirst("context")).contains("Filter context")
}
@Test
fun multipleContextProviders() {
assertThatIllegalStateException().isThrownBy {
@ -309,6 +323,12 @@ class CoRouterFunctionDslTests {
}
}
private val routerWithoutContext = coRouter {
GET("/") {
ok().header("context", currentCoroutineContext().toString()).buildAndAwait()
}
}
private val otherRouter = router {
"/other" {
ok().build()
@ -369,3 +389,12 @@ class CoRouterFunctionDslTests {
@Suppress("UNUSED_PARAMETER")
private suspend fun handle(req: ServerRequest) = ServerResponse.ok().buildAndAwait()
private class MyCoWebFilterWithContext : CoWebFilter() {
override suspend fun filter(exchange: ServerWebExchange, chain: CoWebFilterChain) {
withContext(CoroutineName("Filter context")) {
chain.filter(exchange)
}
}
}