Add context function to CoRouterFunctionDsl
This new function allows to customize the CoroutineContext potentially dynamically based on the incoming ServerRequest. Closes gh-27010
This commit is contained in:
parent
64ff37f42c
commit
38392233ba
|
@ -21,6 +21,7 @@ import kotlinx.coroutines.Job
|
|||
import kotlinx.coroutines.currentCoroutineContext
|
||||
import kotlinx.coroutines.reactor.awaitSingle
|
||||
import kotlinx.coroutines.reactor.mono
|
||||
import kotlinx.coroutines.withContext
|
||||
import org.springframework.core.io.Resource
|
||||
import org.springframework.http.HttpMethod
|
||||
import org.springframework.http.HttpStatusCode
|
||||
|
@ -72,6 +73,8 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct
|
|||
@PublishedApi
|
||||
internal val builder = RouterFunctions.route()
|
||||
|
||||
private var contextProvider: (suspend (ServerRequest) -> CoroutineContext)? = null
|
||||
|
||||
/**
|
||||
* Return a composed request predicate that tests against both this predicate AND
|
||||
* the [other] predicate (String processed as a path predicate). When evaluating the
|
||||
|
@ -510,9 +513,7 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct
|
|||
*/
|
||||
fun resources(lookupFunction: suspend (ServerRequest) -> Resource?) {
|
||||
builder.resources {
|
||||
mono(Dispatchers.Unconfined) {
|
||||
lookupFunction.invoke(it)
|
||||
}
|
||||
asMono(it, handler = lookupFunction)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -534,7 +535,7 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct
|
|||
*/
|
||||
fun filter(filterFunction: suspend (ServerRequest, suspend (ServerRequest) -> ServerResponse) -> ServerResponse) {
|
||||
builder.filter { serverRequest, handlerFunction ->
|
||||
mono(Dispatchers.Unconfined) {
|
||||
asMono(serverRequest) {
|
||||
filterFunction(serverRequest) { handlerRequest ->
|
||||
if (handlerFunction is CoroutineContextAwareHandlerFunction<*>) {
|
||||
handlerFunction.handle(currentCoroutineContext().minusKey(Job.Key), handlerRequest).awaitSingle()
|
||||
|
@ -578,7 +579,7 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct
|
|||
*/
|
||||
fun onError(predicate: (Throwable) -> Boolean, responseProvider: suspend (Throwable, ServerRequest) -> ServerResponse) {
|
||||
builder.onError(predicate) { throwable, request ->
|
||||
mono(Dispatchers.Unconfined) { responseProvider.invoke(throwable, request) }
|
||||
asMono(request) { responseProvider.invoke(throwable, request) }
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -591,7 +592,7 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct
|
|||
*/
|
||||
inline fun <reified E : Throwable> onError(noinline responseProvider: suspend (Throwable, ServerRequest) -> ServerResponse) {
|
||||
builder.onError({it is E}) { throwable, request ->
|
||||
mono(Dispatchers.Unconfined) { responseProvider.invoke(throwable, request) }
|
||||
asMono(request) { responseProvider.invoke(throwable, request) }
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -619,6 +620,19 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct
|
|||
builder.withAttributes(attributesConsumer)
|
||||
}
|
||||
|
||||
/**
|
||||
* Allow to provide the default [CoroutineContext], potentially dynamically based on
|
||||
* the incoming [ServerRequest].
|
||||
* @param provider the [CoroutineContext] provider
|
||||
* @since 6.1.0
|
||||
*/
|
||||
fun context(provider: suspend (ServerRequest) -> CoroutineContext) {
|
||||
if (this.contextProvider != null) {
|
||||
throw IllegalStateException("The Coroutine context provider should be defined not more than once")
|
||||
}
|
||||
this.contextProvider = provider
|
||||
}
|
||||
|
||||
/**
|
||||
* Return a composed routing function created from all the registered routes.
|
||||
*/
|
||||
|
@ -627,8 +641,22 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct
|
|||
return builder.build()
|
||||
}
|
||||
|
||||
private fun <T : ServerResponse> asHandlerFunction(handler: suspend (ServerRequest) -> T) =
|
||||
CoroutineContextAwareHandlerFunction(handler)
|
||||
@PublishedApi
|
||||
internal fun <T> asMono(request: ServerRequest, context: CoroutineContext = Dispatchers.Unconfined, handler: suspend (ServerRequest) -> T): Mono<T> {
|
||||
return mono(context) {
|
||||
contextProvider?.let {
|
||||
withContext(it.invoke(request)) {
|
||||
handler.invoke(request)
|
||||
}
|
||||
} ?: run {
|
||||
handler.invoke(request)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun asHandlerFunction(handler: suspend (ServerRequest) -> ServerResponse) = CoroutineContextAwareHandlerFunction { request ->
|
||||
handler.invoke(request)
|
||||
}
|
||||
|
||||
/**
|
||||
* @see ServerResponse.from
|
||||
|
@ -698,7 +726,7 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct
|
|||
fun status(status: Int) = ServerResponse.status(status)
|
||||
|
||||
|
||||
private class CoroutineContextAwareHandlerFunction<T : ServerResponse>(
|
||||
private inner class CoroutineContextAwareHandlerFunction<T : ServerResponse>(
|
||||
private val handler: suspend (ServerRequest) -> T
|
||||
) : HandlerFunction<T> {
|
||||
|
||||
|
@ -706,7 +734,7 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct
|
|||
return handle(Dispatchers.Unconfined, request)
|
||||
}
|
||||
|
||||
fun handle(context: CoroutineContext, request: ServerRequest) = mono(context) {
|
||||
fun handle(context: CoroutineContext, request: ServerRequest) = asMono(request, context) {
|
||||
handler(request)
|
||||
}
|
||||
|
||||
|
|
|
@ -16,11 +16,8 @@
|
|||
|
||||
package org.springframework.web.reactive.function.server
|
||||
|
||||
import kotlinx.coroutines.CoroutineName
|
||||
import kotlinx.coroutines.currentCoroutineContext
|
||||
import kotlinx.coroutines.withContext
|
||||
import org.assertj.core.api.Assertions.assertThat
|
||||
import org.assertj.core.api.Assertions.assertThatExceptionOfType
|
||||
import kotlinx.coroutines.*
|
||||
import org.assertj.core.api.Assertions.*
|
||||
import org.junit.jupiter.api.Test
|
||||
import org.springframework.core.io.ClassPathResource
|
||||
import org.springframework.http.HttpHeaders.ACCEPT
|
||||
|
@ -179,6 +176,48 @@ class CoRouterFunctionDslTests {
|
|||
.verifyComplete()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun contextProvider() {
|
||||
val mockRequest = get("https://example.com/")
|
||||
.header("Custom-Header", "foo")
|
||||
.build()
|
||||
val request = DefaultServerRequest(MockServerWebExchange.from(mockRequest), emptyList())
|
||||
StepVerifier.create(routerWithContextProvider.route(request).flatMap { it.handle(request) })
|
||||
.expectNextMatches { response ->
|
||||
response.headers().getFirst("context")!!.contains("foo")
|
||||
}
|
||||
.verifyComplete()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun contextProviderAndFilter() {
|
||||
val mockRequest = get("https://example.com/")
|
||||
.header("Custom-Header", "bar")
|
||||
.build()
|
||||
val request = DefaultServerRequest(MockServerWebExchange.from(mockRequest), emptyList())
|
||||
StepVerifier.create(routerWithContextProvider.route(request).flatMap { it.handle(request) })
|
||||
.expectNextMatches { response ->
|
||||
response.headers().getFirst("context")!!.let {
|
||||
it.contains("bar") && it.contains("Dispatchers.Default")
|
||||
}
|
||||
}
|
||||
.verifyComplete()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun multipleContextProviders() {
|
||||
assertThatIllegalStateException().isThrownBy {
|
||||
coRouter {
|
||||
context {
|
||||
CoroutineName("foo")
|
||||
}
|
||||
context {
|
||||
Dispatchers.Default
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun attributes() {
|
||||
val visitor = AttributesTestVisitor()
|
||||
|
@ -251,6 +290,25 @@ class CoRouterFunctionDslTests {
|
|||
}
|
||||
}
|
||||
|
||||
private val routerWithContextProvider = coRouter {
|
||||
context {
|
||||
CoroutineName(it.headers().firstHeader("Custom-Header")!!)
|
||||
}
|
||||
GET("/") {
|
||||
ok().header("context", currentCoroutineContext().toString()).buildAndAwait()
|
||||
}
|
||||
filter { request, next ->
|
||||
if (request.headers().firstHeader("Custom-Header") == "bar") {
|
||||
withContext(currentCoroutineContext() + Dispatchers.Default) {
|
||||
next.invoke(request)
|
||||
}
|
||||
}
|
||||
else {
|
||||
next.invoke(request)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private val otherRouter = router {
|
||||
"/other" {
|
||||
ok().build()
|
||||
|
|
Loading…
Reference in New Issue