Propagate CoroutineContext in coRouter filters
Closes gh-26977
This commit is contained in:
parent
bcf11e8919
commit
d47c7f9552
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright 2002-2022 the original author or authors.
|
||||
* Copyright 2002-2023 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -17,6 +17,8 @@
|
|||
package org.springframework.web.reactive.function.server
|
||||
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.Job
|
||||
import kotlinx.coroutines.currentCoroutineContext
|
||||
import kotlinx.coroutines.reactor.awaitSingle
|
||||
import kotlinx.coroutines.reactor.mono
|
||||
import org.springframework.core.io.Resource
|
||||
|
@ -24,7 +26,9 @@ import org.springframework.http.HttpMethod
|
|||
import org.springframework.http.HttpStatusCode
|
||||
import org.springframework.http.MediaType
|
||||
import org.springframework.web.reactive.function.server.RouterFunctions.nest
|
||||
import reactor.core.publisher.Mono
|
||||
import java.net.URI
|
||||
import kotlin.coroutines.CoroutineContext
|
||||
|
||||
/**
|
||||
* Allow to create easily a WebFlux.fn [RouterFunction] with a [Coroutines router Kotlin DSL][CoRouterFunctionDsl].
|
||||
|
@ -532,7 +536,12 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct
|
|||
builder.filter { serverRequest, handlerFunction ->
|
||||
mono(Dispatchers.Unconfined) {
|
||||
filterFunction(serverRequest) { handlerRequest ->
|
||||
handlerFunction.handle(handlerRequest).awaitSingle()
|
||||
if (handlerFunction is CoroutineContextAwareHandlerFunction<*>) {
|
||||
handlerFunction.handle(currentCoroutineContext().minusKey(Job.Key), handlerRequest).awaitSingle()
|
||||
}
|
||||
else {
|
||||
handlerFunction.handle(handlerRequest).awaitSingle()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -618,11 +627,8 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct
|
|||
return builder.build()
|
||||
}
|
||||
|
||||
private fun asHandlerFunction(init: suspend (ServerRequest) -> ServerResponse) = HandlerFunction {
|
||||
mono(Dispatchers.Unconfined) {
|
||||
init(it)
|
||||
}
|
||||
}
|
||||
private fun <T : ServerResponse> asHandlerFunction(handler: suspend (ServerRequest) -> T) =
|
||||
CoroutineContextAwareHandlerFunction(handler)
|
||||
|
||||
/**
|
||||
* @see ServerResponse.from
|
||||
|
@ -691,6 +697,21 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct
|
|||
*/
|
||||
fun status(status: Int) = ServerResponse.status(status)
|
||||
|
||||
|
||||
private class CoroutineContextAwareHandlerFunction<T : ServerResponse>(
|
||||
private val handler: suspend (ServerRequest) -> T
|
||||
) : HandlerFunction<T> {
|
||||
|
||||
override fun handle(request: ServerRequest): Mono<T> {
|
||||
return handle(Dispatchers.Unconfined, request)
|
||||
}
|
||||
|
||||
fun handle(context: CoroutineContext, request: ServerRequest) = mono(context) {
|
||||
handler(request)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright 2002-2022 the original author or authors.
|
||||
* Copyright 2002-2023 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -16,17 +16,20 @@
|
|||
|
||||
package org.springframework.web.reactive.function.server
|
||||
|
||||
import kotlinx.coroutines.CoroutineName
|
||||
import kotlinx.coroutines.currentCoroutineContext
|
||||
import kotlinx.coroutines.withContext
|
||||
import org.assertj.core.api.Assertions.assertThat
|
||||
import org.assertj.core.api.Assertions.assertThatExceptionOfType
|
||||
import org.junit.jupiter.api.Test
|
||||
import org.springframework.core.io.ClassPathResource
|
||||
import org.springframework.http.HttpHeaders.*
|
||||
import org.springframework.http.HttpMethod.*
|
||||
import org.springframework.http.HttpHeaders.ACCEPT
|
||||
import org.springframework.http.HttpHeaders.CONTENT_TYPE
|
||||
import org.springframework.http.HttpMethod.PATCH
|
||||
import org.springframework.http.HttpStatus
|
||||
import org.springframework.http.MediaType.*
|
||||
import org.springframework.web.testfixture.http.server.reactive.MockServerHttpRequest.*
|
||||
import org.springframework.web.testfixture.server.MockServerWebExchange
|
||||
import org.springframework.web.reactive.function.server.AttributesTestVisitor
|
||||
import reactor.test.StepVerifier
|
||||
|
||||
/**
|
||||
|
@ -165,6 +168,17 @@ class CoRouterFunctionDslTests {
|
|||
.verifyComplete()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun filteringWithContext() {
|
||||
val mockRequest = get("https://example.com/").build()
|
||||
val request = DefaultServerRequest(MockServerWebExchange.from(mockRequest), emptyList())
|
||||
StepVerifier.create(filterRouterWithContext.route(request).flatMap { it.handle(request) })
|
||||
.expectNextMatches { response ->
|
||||
response.headers().getFirst("context")!!.contains("Filter context")
|
||||
}
|
||||
.verifyComplete()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun attributes() {
|
||||
val visitor = AttributesTestVisitor()
|
||||
|
@ -226,6 +240,17 @@ class CoRouterFunctionDslTests {
|
|||
}
|
||||
}
|
||||
|
||||
private val filterRouterWithContext = coRouter {
|
||||
filter { request, next ->
|
||||
withContext(CoroutineName("Filter context")) {
|
||||
next(request)
|
||||
}
|
||||
}
|
||||
GET("/") {
|
||||
ok().header("context", currentCoroutineContext().toString()).buildAndAwait()
|
||||
}
|
||||
}
|
||||
|
||||
private val otherRouter = router {
|
||||
"/other" {
|
||||
ok().build()
|
||||
|
|
Loading…
Reference in New Issue