Pass current request to filter in RouterFunction extensions
This commit ensures that the proper request is passed in the Kotlin RouterFunction extension DSL. Closes gh-27086
This commit is contained in:
parent
a975b9d5da
commit
9d263668d5
|
|
@ -653,7 +653,7 @@ class RouterFunctionDsl internal constructor (private val init: RouterFunctionDs
|
|||
fun filter(filterFunction: (ServerRequest, (ServerRequest) -> Mono<ServerResponse>) -> Mono<ServerResponse>) {
|
||||
builder.filter { request, next ->
|
||||
filterFunction(request) {
|
||||
next.handle(request)
|
||||
next.handle(it)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -24,11 +24,13 @@ import org.springframework.http.HttpHeaders.*
|
|||
import org.springframework.http.HttpMethod.*
|
||||
import org.springframework.http.HttpStatus
|
||||
import org.springframework.http.MediaType.*
|
||||
import org.springframework.web.reactive.function.server.support.ServerRequestWrapper
|
||||
import org.springframework.web.testfixture.http.server.reactive.MockServerHttpRequest.*
|
||||
import org.springframework.web.testfixture.server.MockServerWebExchange
|
||||
import org.springframework.web.reactive.function.server.AttributesTestVisitor
|
||||
import reactor.core.publisher.Mono
|
||||
import reactor.test.StepVerifier
|
||||
import java.security.Principal
|
||||
|
||||
/**
|
||||
* Tests for [RouterFunctionDsl].
|
||||
|
|
@ -169,6 +171,53 @@ class RouterFunctionDslTests {
|
|||
assertThat(visitor.visitCount()).isEqualTo(7);
|
||||
}
|
||||
|
||||
@Test
|
||||
fun acceptFilterAndPOST() {
|
||||
val mockRequest = post("https://example.com/filter")
|
||||
.header(ACCEPT, APPLICATION_JSON_VALUE)
|
||||
.build()
|
||||
val request = DefaultServerRequest(MockServerWebExchange.from(mockRequest), emptyList())
|
||||
StepVerifier.create(filteredRouter.route(request).flatMap { it.handle(request) })
|
||||
.expectNextCount(1)
|
||||
.verifyComplete()
|
||||
}
|
||||
|
||||
private val filteredRouter = router {
|
||||
POST("/filter", ::handleRequestWrapper)
|
||||
|
||||
filter (TestFilterProvider.provide())
|
||||
before {
|
||||
it
|
||||
}
|
||||
after { _, response ->
|
||||
response
|
||||
}
|
||||
onError({it is IllegalStateException}) { _, _ ->
|
||||
ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR).build()
|
||||
}
|
||||
onError<IllegalStateException> { _, _ ->
|
||||
ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR).build()
|
||||
}
|
||||
}
|
||||
|
||||
private class TestServerRequestWrapper(delegate: ServerRequest, private val principalOverride: String = "foo"): ServerRequestWrapper(delegate) {
|
||||
override fun principal(): Mono<out Principal> = Mono.just(Principal { principalOverride })
|
||||
}
|
||||
|
||||
private object TestFilterProvider {
|
||||
fun provide(): (ServerRequest, (ServerRequest) -> Mono<ServerResponse>) -> Mono<ServerResponse> = { request, next ->
|
||||
next(TestServerRequestWrapper(request))
|
||||
}
|
||||
}
|
||||
|
||||
private fun handleRequestWrapper(req: ServerRequest): Mono<ServerResponse> {
|
||||
return req.principal()
|
||||
.flatMap {
|
||||
assertThat(it.name).isEqualTo("foo")
|
||||
ServerResponse.ok().build()
|
||||
}
|
||||
}
|
||||
|
||||
private fun sampleRouter() = router {
|
||||
(GET("/foo/") or GET("/foos/")) { req -> handle(req) }
|
||||
"/api".nest {
|
||||
|
|
|
|||
Loading…
Reference in New Issue