Pass current request to filter in RouterFunction extensions

This commit ensures that the proper request is passed in the
Kotlin RouterFunction extension DSL.

Closes gh-27086
This commit is contained in:
ijonathanc 2021-06-21 11:38:08 +01:00 committed by Arjen Poutsma
parent a975b9d5da
commit 9d263668d5
2 changed files with 50 additions and 1 deletions

View File

@ -653,7 +653,7 @@ class RouterFunctionDsl internal constructor (private val init: RouterFunctionDs
fun filter(filterFunction: (ServerRequest, (ServerRequest) -> Mono<ServerResponse>) -> Mono<ServerResponse>) {
builder.filter { request, next ->
filterFunction(request) {
next.handle(request)
next.handle(it)
}
}
}

View File

@ -24,11 +24,13 @@ import org.springframework.http.HttpHeaders.*
import org.springframework.http.HttpMethod.*
import org.springframework.http.HttpStatus
import org.springframework.http.MediaType.*
import org.springframework.web.reactive.function.server.support.ServerRequestWrapper
import org.springframework.web.testfixture.http.server.reactive.MockServerHttpRequest.*
import org.springframework.web.testfixture.server.MockServerWebExchange
import org.springframework.web.reactive.function.server.AttributesTestVisitor
import reactor.core.publisher.Mono
import reactor.test.StepVerifier
import java.security.Principal
/**
* Tests for [RouterFunctionDsl].
@ -169,6 +171,53 @@ class RouterFunctionDslTests {
assertThat(visitor.visitCount()).isEqualTo(7);
}
@Test
fun acceptFilterAndPOST() {
val mockRequest = post("https://example.com/filter")
.header(ACCEPT, APPLICATION_JSON_VALUE)
.build()
val request = DefaultServerRequest(MockServerWebExchange.from(mockRequest), emptyList())
StepVerifier.create(filteredRouter.route(request).flatMap { it.handle(request) })
.expectNextCount(1)
.verifyComplete()
}
private val filteredRouter = router {
POST("/filter", ::handleRequestWrapper)
filter (TestFilterProvider.provide())
before {
it
}
after { _, response ->
response
}
onError({it is IllegalStateException}) { _, _ ->
ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR).build()
}
onError<IllegalStateException> { _, _ ->
ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR).build()
}
}
private class TestServerRequestWrapper(delegate: ServerRequest, private val principalOverride: String = "foo"): ServerRequestWrapper(delegate) {
override fun principal(): Mono<out Principal> = Mono.just(Principal { principalOverride })
}
private object TestFilterProvider {
fun provide(): (ServerRequest, (ServerRequest) -> Mono<ServerResponse>) -> Mono<ServerResponse> = { request, next ->
next(TestServerRequestWrapper(request))
}
}
private fun handleRequestWrapper(req: ServerRequest): Mono<ServerResponse> {
return req.principal()
.flatMap {
assertThat(it.name).isEqualTo("foo")
ServerResponse.ok().build()
}
}
private fun sampleRouter() = router {
(GET("/foo/") or GET("/foos/")) { req -> handle(req) }
"/api".nest {