Pass current request to filter in RouterFunction extensions
This commit ensures that the proper request is passed in the Kotlin RouterFunction extension DSL. Closes gh-27086
This commit is contained in:
parent
a975b9d5da
commit
9d263668d5
|
|
@ -653,7 +653,7 @@ class RouterFunctionDsl internal constructor (private val init: RouterFunctionDs
|
||||||
fun filter(filterFunction: (ServerRequest, (ServerRequest) -> Mono<ServerResponse>) -> Mono<ServerResponse>) {
|
fun filter(filterFunction: (ServerRequest, (ServerRequest) -> Mono<ServerResponse>) -> Mono<ServerResponse>) {
|
||||||
builder.filter { request, next ->
|
builder.filter { request, next ->
|
||||||
filterFunction(request) {
|
filterFunction(request) {
|
||||||
next.handle(request)
|
next.handle(it)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -24,11 +24,13 @@ import org.springframework.http.HttpHeaders.*
|
||||||
import org.springframework.http.HttpMethod.*
|
import org.springframework.http.HttpMethod.*
|
||||||
import org.springframework.http.HttpStatus
|
import org.springframework.http.HttpStatus
|
||||||
import org.springframework.http.MediaType.*
|
import org.springframework.http.MediaType.*
|
||||||
|
import org.springframework.web.reactive.function.server.support.ServerRequestWrapper
|
||||||
import org.springframework.web.testfixture.http.server.reactive.MockServerHttpRequest.*
|
import org.springframework.web.testfixture.http.server.reactive.MockServerHttpRequest.*
|
||||||
import org.springframework.web.testfixture.server.MockServerWebExchange
|
import org.springframework.web.testfixture.server.MockServerWebExchange
|
||||||
import org.springframework.web.reactive.function.server.AttributesTestVisitor
|
import org.springframework.web.reactive.function.server.AttributesTestVisitor
|
||||||
import reactor.core.publisher.Mono
|
import reactor.core.publisher.Mono
|
||||||
import reactor.test.StepVerifier
|
import reactor.test.StepVerifier
|
||||||
|
import java.security.Principal
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Tests for [RouterFunctionDsl].
|
* Tests for [RouterFunctionDsl].
|
||||||
|
|
@ -169,6 +171,53 @@ class RouterFunctionDslTests {
|
||||||
assertThat(visitor.visitCount()).isEqualTo(7);
|
assertThat(visitor.visitCount()).isEqualTo(7);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun acceptFilterAndPOST() {
|
||||||
|
val mockRequest = post("https://example.com/filter")
|
||||||
|
.header(ACCEPT, APPLICATION_JSON_VALUE)
|
||||||
|
.build()
|
||||||
|
val request = DefaultServerRequest(MockServerWebExchange.from(mockRequest), emptyList())
|
||||||
|
StepVerifier.create(filteredRouter.route(request).flatMap { it.handle(request) })
|
||||||
|
.expectNextCount(1)
|
||||||
|
.verifyComplete()
|
||||||
|
}
|
||||||
|
|
||||||
|
private val filteredRouter = router {
|
||||||
|
POST("/filter", ::handleRequestWrapper)
|
||||||
|
|
||||||
|
filter (TestFilterProvider.provide())
|
||||||
|
before {
|
||||||
|
it
|
||||||
|
}
|
||||||
|
after { _, response ->
|
||||||
|
response
|
||||||
|
}
|
||||||
|
onError({it is IllegalStateException}) { _, _ ->
|
||||||
|
ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR).build()
|
||||||
|
}
|
||||||
|
onError<IllegalStateException> { _, _ ->
|
||||||
|
ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR).build()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private class TestServerRequestWrapper(delegate: ServerRequest, private val principalOverride: String = "foo"): ServerRequestWrapper(delegate) {
|
||||||
|
override fun principal(): Mono<out Principal> = Mono.just(Principal { principalOverride })
|
||||||
|
}
|
||||||
|
|
||||||
|
private object TestFilterProvider {
|
||||||
|
fun provide(): (ServerRequest, (ServerRequest) -> Mono<ServerResponse>) -> Mono<ServerResponse> = { request, next ->
|
||||||
|
next(TestServerRequestWrapper(request))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun handleRequestWrapper(req: ServerRequest): Mono<ServerResponse> {
|
||||||
|
return req.principal()
|
||||||
|
.flatMap {
|
||||||
|
assertThat(it.name).isEqualTo("foo")
|
||||||
|
ServerResponse.ok().build()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private fun sampleRouter() = router {
|
private fun sampleRouter() = router {
|
||||||
(GET("/foo/") or GET("/foos/")) { req -> handle(req) }
|
(GET("/foo/") or GET("/foos/")) { req -> handle(req) }
|
||||||
"/api".nest {
|
"/api".nest {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue