Inherit parent context in coRouter DSL

This commit also allows context override, as it
is useful for the nested router use case.

Closes gh-31831
This commit is contained in:
Sébastien Deleuze 2023-12-13 15:54:23 +01:00
parent 8d4deca2a6
commit a01c6d57bb
2 changed files with 70 additions and 4 deletions

View File

@ -144,7 +144,7 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct
* @see RouterFunctions.nest
*/
fun RequestPredicate.nest(r: (CoRouterFunctionDsl.() -> Unit)) {
builder.add(nest(this, CoRouterFunctionDsl(r).build()))
builder.add(nest(this, CoRouterFunctionDsl(r).also { it.contextProvider = contextProvider }.build()))
}
@ -628,9 +628,6 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct
* @since 6.1
*/
fun context(provider: suspend (ServerRequest) -> CoroutineContext) {
if (this.contextProvider != null) {
throw IllegalStateException("The Coroutine context provider should not be defined more than once")
}
this.contextProvider = provider
}

View File

@ -193,6 +193,45 @@ class CoRouterFunctionDslTests {
.verifyComplete()
}
@Test
fun nestedContextProvider() {
val mockRequest = get("https://example.com/nested/")
.header("Custom-Header", "foo")
.build()
val request = DefaultServerRequest(MockServerWebExchange.from(mockRequest), emptyList())
StepVerifier.create(nestedRouterWithContextProvider.route(request).flatMap { it.handle(request) })
.expectNextMatches { response ->
response.headers().getFirst("context")!!.contains("foo")
}
.verifyComplete()
}
@Test
fun nestedContextProviderWithOverride() {
val mockRequest = get("https://example.com/nested/")
.header("Custom-Header", "foo")
.build()
val request = DefaultServerRequest(MockServerWebExchange.from(mockRequest), emptyList())
StepVerifier.create(nestedRouterWithContextProviderOverride.route(request).flatMap { it.handle(request) })
.expectNextMatches { response ->
response.headers().getFirst("context")!!.contains("foo")
}
.verifyComplete()
}
@Test
fun doubleNestedContextProvider() {
val mockRequest = get("https://example.com/nested/nested/")
.header("Custom-Header", "foo")
.build()
val request = DefaultServerRequest(MockServerWebExchange.from(mockRequest), emptyList())
StepVerifier.create(nestedRouterWithContextProvider.route(request).flatMap { it.handle(request) })
.expectNextMatches { response ->
response.headers().getFirst("context")!!.contains("foo")
}
.verifyComplete()
}
@Test
fun contextProviderAndFilter() {
val mockRequest = get("https://example.com/")
@ -323,6 +362,36 @@ class CoRouterFunctionDslTests {
}
}
private val nestedRouterWithContextProvider = coRouter {
context {
CoroutineName(it.headers().firstHeader("Custom-Header")!!)
}
"/nested".nest {
GET("/") {
ok().header("context", currentCoroutineContext().toString()).buildAndAwait()
}
"/nested".nest {
GET("/") {
ok().header("context", currentCoroutineContext().toString()).buildAndAwait()
}
}
}
}
private val nestedRouterWithContextProviderOverride = coRouter {
context {
CoroutineName("parent-context")
}
"/nested".nest {
context {
CoroutineName(it.headers().firstHeader("Custom-Header")!!)
}
GET("/") {
ok().header("context", currentCoroutineContext().toString()).buildAndAwait()
}
}
}
private val routerWithoutContext = coRouter {
GET("/") {
ok().header("context", currentCoroutineContext().toString()).buildAndAwait()