Fix parameter bug of handler inside the filterFunction DSL
Co-authored-by: hojongs <hojong.jjh@gmail.com> Co-authored-by: bjh970913 <bjh970913@gmail.com> Closes gh-26838
This commit is contained in:
parent
0468ef46ac
commit
07ba95739b
|
|
@ -531,8 +531,8 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct
|
||||||
fun filter(filterFunction: suspend (ServerRequest, suspend (ServerRequest) -> ServerResponse) -> ServerResponse) {
|
fun filter(filterFunction: suspend (ServerRequest, suspend (ServerRequest) -> ServerResponse) -> ServerResponse) {
|
||||||
builder.filter { serverRequest, handlerFunction ->
|
builder.filter { serverRequest, handlerFunction ->
|
||||||
mono(Dispatchers.Unconfined) {
|
mono(Dispatchers.Unconfined) {
|
||||||
filterFunction(serverRequest) {
|
filterFunction(serverRequest) { handlerRequest ->
|
||||||
handlerFunction.handle(serverRequest).awaitSingle()
|
handlerFunction.handle(handlerRequest).awaitSingle()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -152,6 +152,16 @@ class CoRouterFunctionDslTests {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun filtering() {
|
||||||
|
val mockRequest = get("https://example.com/filter").build()
|
||||||
|
val request = DefaultServerRequest(MockServerWebExchange.from(mockRequest), emptyList())
|
||||||
|
StepVerifier.create(sampleRouter().route(request).flatMap { it.handle(request) })
|
||||||
|
.expectNextMatches { response ->
|
||||||
|
response.headers().getFirst("foo") == "bar"
|
||||||
|
}
|
||||||
|
.verifyComplete()
|
||||||
|
}
|
||||||
|
|
||||||
private fun sampleRouter() = coRouter {
|
private fun sampleRouter() = coRouter {
|
||||||
(GET("/foo/") or GET("/foos/")) { req -> handle(req) }
|
(GET("/foo/") or GET("/foos/")) { req -> handle(req) }
|
||||||
|
|
@ -186,6 +196,18 @@ class CoRouterFunctionDslTests {
|
||||||
path("/baz", ::handle)
|
path("/baz", ::handle)
|
||||||
GET("/rendering") { RenderingResponse.create("index").buildAndAwait() }
|
GET("/rendering") { RenderingResponse.create("index").buildAndAwait() }
|
||||||
add(otherRouter)
|
add(otherRouter)
|
||||||
|
add(filterRouter)
|
||||||
|
}
|
||||||
|
|
||||||
|
private val filterRouter = coRouter {
|
||||||
|
"/filter" { request ->
|
||||||
|
ok().header("foo", request.headers().firstHeader("foo")).buildAndAwait()
|
||||||
|
}
|
||||||
|
|
||||||
|
filter { request, next ->
|
||||||
|
val newRequest = ServerRequest.from(request).apply { header("foo", "bar") }.build()
|
||||||
|
next(newRequest)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private val otherRouter = router {
|
private val otherRouter = router {
|
||||||
|
|
|
||||||
|
|
@ -649,8 +649,8 @@ class RouterFunctionDsl internal constructor (private val init: (RouterFunctionD
|
||||||
*/
|
*/
|
||||||
fun filter(filterFunction: (ServerRequest, (ServerRequest) -> ServerResponse) -> ServerResponse) {
|
fun filter(filterFunction: (ServerRequest, (ServerRequest) -> ServerResponse) -> ServerResponse) {
|
||||||
builder.filter { request, next ->
|
builder.filter { request, next ->
|
||||||
filterFunction(request) {
|
filterFunction(request) { handlerRequest ->
|
||||||
next.handle(request)
|
next.handle(handlerRequest)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -127,6 +127,13 @@ class RouterFunctionDslTests {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun filtering() {
|
||||||
|
val servletRequest = PathPatternsTestUtils.initRequest("GET", "/filter", true)
|
||||||
|
val request = DefaultServerRequest(servletRequest, emptyList())
|
||||||
|
assertThat(sampleRouter().route(request).get().handle(request).headers().getFirst("foo")).isEqualTo("bar")
|
||||||
|
}
|
||||||
|
|
||||||
private fun sampleRouter() = router {
|
private fun sampleRouter() = router {
|
||||||
(GET("/foo/") or GET("/foos/")) { req -> handle(req) }
|
(GET("/foo/") or GET("/foos/")) { req -> handle(req) }
|
||||||
"/api".nest {
|
"/api".nest {
|
||||||
|
|
@ -160,6 +167,18 @@ class RouterFunctionDslTests {
|
||||||
path("/baz", ::handle)
|
path("/baz", ::handle)
|
||||||
GET("/rendering") { RenderingResponse.create("index").build() }
|
GET("/rendering") { RenderingResponse.create("index").build() }
|
||||||
add(otherRouter)
|
add(otherRouter)
|
||||||
|
add(filterRouter)
|
||||||
|
}
|
||||||
|
|
||||||
|
private val filterRouter = router {
|
||||||
|
"/filter" { request ->
|
||||||
|
ok().header("foo", request.headers().firstHeader("foo")).build()
|
||||||
|
}
|
||||||
|
|
||||||
|
filter { request, next ->
|
||||||
|
val newRequest = ServerRequest.from(request).apply { header("foo", "bar") }.build()
|
||||||
|
next(newRequest)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private val otherRouter = router {
|
private val otherRouter = router {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue