diff --git a/spring-webflux/src/main/kotlin/org/springframework/web/reactive/function/server/RouterFunctionDsl.kt b/spring-webflux/src/main/kotlin/org/springframework/web/reactive/function/server/RouterFunctionDsl.kt index c6f35ac2f23..ae878582030 100644 --- a/spring-webflux/src/main/kotlin/org/springframework/web/reactive/function/server/RouterFunctionDsl.kt +++ b/spring-webflux/src/main/kotlin/org/springframework/web/reactive/function/server/RouterFunctionDsl.kt @@ -653,7 +653,7 @@ class RouterFunctionDsl internal constructor (private val init: RouterFunctionDs fun filter(filterFunction: (ServerRequest, (ServerRequest) -> Mono) -> Mono) { builder.filter { request, next -> filterFunction(request) { - next.handle(request) + next.handle(it) } } } diff --git a/spring-webflux/src/test/kotlin/org/springframework/web/reactive/function/server/RouterFunctionDslTests.kt b/spring-webflux/src/test/kotlin/org/springframework/web/reactive/function/server/RouterFunctionDslTests.kt index 4392b04fbf1..4edd902ca9d 100644 --- a/spring-webflux/src/test/kotlin/org/springframework/web/reactive/function/server/RouterFunctionDslTests.kt +++ b/spring-webflux/src/test/kotlin/org/springframework/web/reactive/function/server/RouterFunctionDslTests.kt @@ -24,11 +24,13 @@ import org.springframework.http.HttpHeaders.* import org.springframework.http.HttpMethod.* import org.springframework.http.HttpStatus import org.springframework.http.MediaType.* +import org.springframework.web.reactive.function.server.support.ServerRequestWrapper import org.springframework.web.testfixture.http.server.reactive.MockServerHttpRequest.* import org.springframework.web.testfixture.server.MockServerWebExchange import org.springframework.web.reactive.function.server.AttributesTestVisitor import reactor.core.publisher.Mono import reactor.test.StepVerifier +import java.security.Principal /** * Tests for [RouterFunctionDsl]. @@ -169,6 +171,53 @@ class RouterFunctionDslTests { assertThat(visitor.visitCount()).isEqualTo(7); } + @Test + fun acceptFilterAndPOST() { + val mockRequest = post("https://example.com/filter") + .header(ACCEPT, APPLICATION_JSON_VALUE) + .build() + val request = DefaultServerRequest(MockServerWebExchange.from(mockRequest), emptyList()) + StepVerifier.create(filteredRouter.route(request).flatMap { it.handle(request) }) + .expectNextCount(1) + .verifyComplete() + } + + private val filteredRouter = router { + POST("/filter", ::handleRequestWrapper) + + filter (TestFilterProvider.provide()) + before { + it + } + after { _, response -> + response + } + onError({it is IllegalStateException}) { _, _ -> + ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR).build() + } + onError { _, _ -> + ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR).build() + } + } + + private class TestServerRequestWrapper(delegate: ServerRequest, private val principalOverride: String = "foo"): ServerRequestWrapper(delegate) { + override fun principal(): Mono = Mono.just(Principal { principalOverride }) + } + + private object TestFilterProvider { + fun provide(): (ServerRequest, (ServerRequest) -> Mono) -> Mono = { request, next -> + next(TestServerRequestWrapper(request)) + } + } + + private fun handleRequestWrapper(req: ServerRequest): Mono { + return req.principal() + .flatMap { + assertThat(it.name).isEqualTo("foo") + ServerResponse.ok().build() + } + } + private fun sampleRouter() = router { (GET("/foo/") or GET("/foos/")) { req -> handle(req) } "/api".nest {