Add context function to CoRouterFunctionDsl
This new function allows to customize the CoroutineContext potentially dynamically based on the incoming ServerRequest. Closes gh-27010
This commit is contained in:
parent
64ff37f42c
commit
38392233ba
|
@ -21,6 +21,7 @@ import kotlinx.coroutines.Job
|
||||||
import kotlinx.coroutines.currentCoroutineContext
|
import kotlinx.coroutines.currentCoroutineContext
|
||||||
import kotlinx.coroutines.reactor.awaitSingle
|
import kotlinx.coroutines.reactor.awaitSingle
|
||||||
import kotlinx.coroutines.reactor.mono
|
import kotlinx.coroutines.reactor.mono
|
||||||
|
import kotlinx.coroutines.withContext
|
||||||
import org.springframework.core.io.Resource
|
import org.springframework.core.io.Resource
|
||||||
import org.springframework.http.HttpMethod
|
import org.springframework.http.HttpMethod
|
||||||
import org.springframework.http.HttpStatusCode
|
import org.springframework.http.HttpStatusCode
|
||||||
|
@ -72,6 +73,8 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct
|
||||||
@PublishedApi
|
@PublishedApi
|
||||||
internal val builder = RouterFunctions.route()
|
internal val builder = RouterFunctions.route()
|
||||||
|
|
||||||
|
private var contextProvider: (suspend (ServerRequest) -> CoroutineContext)? = null
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Return a composed request predicate that tests against both this predicate AND
|
* Return a composed request predicate that tests against both this predicate AND
|
||||||
* the [other] predicate (String processed as a path predicate). When evaluating the
|
* the [other] predicate (String processed as a path predicate). When evaluating the
|
||||||
|
@ -510,9 +513,7 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct
|
||||||
*/
|
*/
|
||||||
fun resources(lookupFunction: suspend (ServerRequest) -> Resource?) {
|
fun resources(lookupFunction: suspend (ServerRequest) -> Resource?) {
|
||||||
builder.resources {
|
builder.resources {
|
||||||
mono(Dispatchers.Unconfined) {
|
asMono(it, handler = lookupFunction)
|
||||||
lookupFunction.invoke(it)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -534,7 +535,7 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct
|
||||||
*/
|
*/
|
||||||
fun filter(filterFunction: suspend (ServerRequest, suspend (ServerRequest) -> ServerResponse) -> ServerResponse) {
|
fun filter(filterFunction: suspend (ServerRequest, suspend (ServerRequest) -> ServerResponse) -> ServerResponse) {
|
||||||
builder.filter { serverRequest, handlerFunction ->
|
builder.filter { serverRequest, handlerFunction ->
|
||||||
mono(Dispatchers.Unconfined) {
|
asMono(serverRequest) {
|
||||||
filterFunction(serverRequest) { handlerRequest ->
|
filterFunction(serverRequest) { handlerRequest ->
|
||||||
if (handlerFunction is CoroutineContextAwareHandlerFunction<*>) {
|
if (handlerFunction is CoroutineContextAwareHandlerFunction<*>) {
|
||||||
handlerFunction.handle(currentCoroutineContext().minusKey(Job.Key), handlerRequest).awaitSingle()
|
handlerFunction.handle(currentCoroutineContext().minusKey(Job.Key), handlerRequest).awaitSingle()
|
||||||
|
@ -578,7 +579,7 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct
|
||||||
*/
|
*/
|
||||||
fun onError(predicate: (Throwable) -> Boolean, responseProvider: suspend (Throwable, ServerRequest) -> ServerResponse) {
|
fun onError(predicate: (Throwable) -> Boolean, responseProvider: suspend (Throwable, ServerRequest) -> ServerResponse) {
|
||||||
builder.onError(predicate) { throwable, request ->
|
builder.onError(predicate) { throwable, request ->
|
||||||
mono(Dispatchers.Unconfined) { responseProvider.invoke(throwable, request) }
|
asMono(request) { responseProvider.invoke(throwable, request) }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -591,7 +592,7 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct
|
||||||
*/
|
*/
|
||||||
inline fun <reified E : Throwable> onError(noinline responseProvider: suspend (Throwable, ServerRequest) -> ServerResponse) {
|
inline fun <reified E : Throwable> onError(noinline responseProvider: suspend (Throwable, ServerRequest) -> ServerResponse) {
|
||||||
builder.onError({it is E}) { throwable, request ->
|
builder.onError({it is E}) { throwable, request ->
|
||||||
mono(Dispatchers.Unconfined) { responseProvider.invoke(throwable, request) }
|
asMono(request) { responseProvider.invoke(throwable, request) }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -619,6 +620,19 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct
|
||||||
builder.withAttributes(attributesConsumer)
|
builder.withAttributes(attributesConsumer)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Allow to provide the default [CoroutineContext], potentially dynamically based on
|
||||||
|
* the incoming [ServerRequest].
|
||||||
|
* @param provider the [CoroutineContext] provider
|
||||||
|
* @since 6.1.0
|
||||||
|
*/
|
||||||
|
fun context(provider: suspend (ServerRequest) -> CoroutineContext) {
|
||||||
|
if (this.contextProvider != null) {
|
||||||
|
throw IllegalStateException("The Coroutine context provider should be defined not more than once")
|
||||||
|
}
|
||||||
|
this.contextProvider = provider
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Return a composed routing function created from all the registered routes.
|
* Return a composed routing function created from all the registered routes.
|
||||||
*/
|
*/
|
||||||
|
@ -627,8 +641,22 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct
|
||||||
return builder.build()
|
return builder.build()
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun <T : ServerResponse> asHandlerFunction(handler: suspend (ServerRequest) -> T) =
|
@PublishedApi
|
||||||
CoroutineContextAwareHandlerFunction(handler)
|
internal fun <T> asMono(request: ServerRequest, context: CoroutineContext = Dispatchers.Unconfined, handler: suspend (ServerRequest) -> T): Mono<T> {
|
||||||
|
return mono(context) {
|
||||||
|
contextProvider?.let {
|
||||||
|
withContext(it.invoke(request)) {
|
||||||
|
handler.invoke(request)
|
||||||
|
}
|
||||||
|
} ?: run {
|
||||||
|
handler.invoke(request)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun asHandlerFunction(handler: suspend (ServerRequest) -> ServerResponse) = CoroutineContextAwareHandlerFunction { request ->
|
||||||
|
handler.invoke(request)
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @see ServerResponse.from
|
* @see ServerResponse.from
|
||||||
|
@ -698,7 +726,7 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct
|
||||||
fun status(status: Int) = ServerResponse.status(status)
|
fun status(status: Int) = ServerResponse.status(status)
|
||||||
|
|
||||||
|
|
||||||
private class CoroutineContextAwareHandlerFunction<T : ServerResponse>(
|
private inner class CoroutineContextAwareHandlerFunction<T : ServerResponse>(
|
||||||
private val handler: suspend (ServerRequest) -> T
|
private val handler: suspend (ServerRequest) -> T
|
||||||
) : HandlerFunction<T> {
|
) : HandlerFunction<T> {
|
||||||
|
|
||||||
|
@ -706,7 +734,7 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct
|
||||||
return handle(Dispatchers.Unconfined, request)
|
return handle(Dispatchers.Unconfined, request)
|
||||||
}
|
}
|
||||||
|
|
||||||
fun handle(context: CoroutineContext, request: ServerRequest) = mono(context) {
|
fun handle(context: CoroutineContext, request: ServerRequest) = asMono(request, context) {
|
||||||
handler(request)
|
handler(request)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -16,11 +16,8 @@
|
||||||
|
|
||||||
package org.springframework.web.reactive.function.server
|
package org.springframework.web.reactive.function.server
|
||||||
|
|
||||||
import kotlinx.coroutines.CoroutineName
|
import kotlinx.coroutines.*
|
||||||
import kotlinx.coroutines.currentCoroutineContext
|
import org.assertj.core.api.Assertions.*
|
||||||
import kotlinx.coroutines.withContext
|
|
||||||
import org.assertj.core.api.Assertions.assertThat
|
|
||||||
import org.assertj.core.api.Assertions.assertThatExceptionOfType
|
|
||||||
import org.junit.jupiter.api.Test
|
import org.junit.jupiter.api.Test
|
||||||
import org.springframework.core.io.ClassPathResource
|
import org.springframework.core.io.ClassPathResource
|
||||||
import org.springframework.http.HttpHeaders.ACCEPT
|
import org.springframework.http.HttpHeaders.ACCEPT
|
||||||
|
@ -179,6 +176,48 @@ class CoRouterFunctionDslTests {
|
||||||
.verifyComplete()
|
.verifyComplete()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun contextProvider() {
|
||||||
|
val mockRequest = get("https://example.com/")
|
||||||
|
.header("Custom-Header", "foo")
|
||||||
|
.build()
|
||||||
|
val request = DefaultServerRequest(MockServerWebExchange.from(mockRequest), emptyList())
|
||||||
|
StepVerifier.create(routerWithContextProvider.route(request).flatMap { it.handle(request) })
|
||||||
|
.expectNextMatches { response ->
|
||||||
|
response.headers().getFirst("context")!!.contains("foo")
|
||||||
|
}
|
||||||
|
.verifyComplete()
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun contextProviderAndFilter() {
|
||||||
|
val mockRequest = get("https://example.com/")
|
||||||
|
.header("Custom-Header", "bar")
|
||||||
|
.build()
|
||||||
|
val request = DefaultServerRequest(MockServerWebExchange.from(mockRequest), emptyList())
|
||||||
|
StepVerifier.create(routerWithContextProvider.route(request).flatMap { it.handle(request) })
|
||||||
|
.expectNextMatches { response ->
|
||||||
|
response.headers().getFirst("context")!!.let {
|
||||||
|
it.contains("bar") && it.contains("Dispatchers.Default")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
.verifyComplete()
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun multipleContextProviders() {
|
||||||
|
assertThatIllegalStateException().isThrownBy {
|
||||||
|
coRouter {
|
||||||
|
context {
|
||||||
|
CoroutineName("foo")
|
||||||
|
}
|
||||||
|
context {
|
||||||
|
Dispatchers.Default
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun attributes() {
|
fun attributes() {
|
||||||
val visitor = AttributesTestVisitor()
|
val visitor = AttributesTestVisitor()
|
||||||
|
@ -251,6 +290,25 @@ class CoRouterFunctionDslTests {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private val routerWithContextProvider = coRouter {
|
||||||
|
context {
|
||||||
|
CoroutineName(it.headers().firstHeader("Custom-Header")!!)
|
||||||
|
}
|
||||||
|
GET("/") {
|
||||||
|
ok().header("context", currentCoroutineContext().toString()).buildAndAwait()
|
||||||
|
}
|
||||||
|
filter { request, next ->
|
||||||
|
if (request.headers().firstHeader("Custom-Header") == "bar") {
|
||||||
|
withContext(currentCoroutineContext() + Dispatchers.Default) {
|
||||||
|
next.invoke(request)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
next.invoke(request)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private val otherRouter = router {
|
private val otherRouter = router {
|
||||||
"/other" {
|
"/other" {
|
||||||
ok().build()
|
ok().build()
|
||||||
|
|
Loading…
Reference in New Issue