diff --git a/spring-webflux/src/main/kotlin/org/springframework/web/reactive/function/server/CoRouterFunctionDsl.kt b/spring-webflux/src/main/kotlin/org/springframework/web/reactive/function/server/CoRouterFunctionDsl.kt index 6974faee6d6..f04000ce46d 100644 --- a/spring-webflux/src/main/kotlin/org/springframework/web/reactive/function/server/CoRouterFunctionDsl.kt +++ b/spring-webflux/src/main/kotlin/org/springframework/web/reactive/function/server/CoRouterFunctionDsl.kt @@ -531,8 +531,8 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct fun filter(filterFunction: suspend (ServerRequest, suspend (ServerRequest) -> ServerResponse) -> ServerResponse) { builder.filter { serverRequest, handlerFunction -> mono(Dispatchers.Unconfined) { - filterFunction(serverRequest) { - handlerFunction.handle(serverRequest).awaitSingle() + filterFunction(serverRequest) { handlerRequest -> + handlerFunction.handle(handlerRequest).awaitSingle() } } } diff --git a/spring-webflux/src/test/kotlin/org/springframework/web/reactive/function/server/CoRouterFunctionDslTests.kt b/spring-webflux/src/test/kotlin/org/springframework/web/reactive/function/server/CoRouterFunctionDslTests.kt index 1a2bc064463..bdeae8b00af 100644 --- a/spring-webflux/src/test/kotlin/org/springframework/web/reactive/function/server/CoRouterFunctionDslTests.kt +++ b/spring-webflux/src/test/kotlin/org/springframework/web/reactive/function/server/CoRouterFunctionDslTests.kt @@ -152,6 +152,16 @@ class CoRouterFunctionDslTests { } } + @Test + fun filtering() { + val mockRequest = get("https://example.com/filter").build() + val request = DefaultServerRequest(MockServerWebExchange.from(mockRequest), emptyList()) + StepVerifier.create(sampleRouter().route(request).flatMap { it.handle(request) }) + .expectNextMatches { response -> + response.headers().getFirst("foo") == "bar" + } + .verifyComplete() + } private fun sampleRouter() = coRouter { (GET("/foo/") or GET("/foos/")) { req -> handle(req) } @@ -186,6 +196,18 @@ class CoRouterFunctionDslTests { path("/baz", ::handle) GET("/rendering") { RenderingResponse.create("index").buildAndAwait() } add(otherRouter) + add(filterRouter) + } + + private val filterRouter = coRouter { + "/filter" { request -> + ok().header("foo", request.headers().firstHeader("foo")).buildAndAwait() + } + + filter { request, next -> + val newRequest = ServerRequest.from(request).apply { header("foo", "bar") }.build() + next(newRequest) + } } private val otherRouter = router { diff --git a/spring-webmvc/src/main/kotlin/org/springframework/web/servlet/function/RouterFunctionDsl.kt b/spring-webmvc/src/main/kotlin/org/springframework/web/servlet/function/RouterFunctionDsl.kt index 68661676731..88381315df0 100644 --- a/spring-webmvc/src/main/kotlin/org/springframework/web/servlet/function/RouterFunctionDsl.kt +++ b/spring-webmvc/src/main/kotlin/org/springframework/web/servlet/function/RouterFunctionDsl.kt @@ -649,8 +649,8 @@ class RouterFunctionDsl internal constructor (private val init: (RouterFunctionD */ fun filter(filterFunction: (ServerRequest, (ServerRequest) -> ServerResponse) -> ServerResponse) { builder.filter { request, next -> - filterFunction(request) { - next.handle(request) + filterFunction(request) { handlerRequest -> + next.handle(handlerRequest) } } } diff --git a/spring-webmvc/src/test/kotlin/org/springframework/web/servlet/function/RouterFunctionDslTests.kt b/spring-webmvc/src/test/kotlin/org/springframework/web/servlet/function/RouterFunctionDslTests.kt index 7898ded3ed4..750d05d01e3 100644 --- a/spring-webmvc/src/test/kotlin/org/springframework/web/servlet/function/RouterFunctionDslTests.kt +++ b/spring-webmvc/src/test/kotlin/org/springframework/web/servlet/function/RouterFunctionDslTests.kt @@ -127,6 +127,13 @@ class RouterFunctionDslTests { } } + @Test + fun filtering() { + val servletRequest = PathPatternsTestUtils.initRequest("GET", "/filter", true) + val request = DefaultServerRequest(servletRequest, emptyList()) + assertThat(sampleRouter().route(request).get().handle(request).headers().getFirst("foo")).isEqualTo("bar") + } + private fun sampleRouter() = router { (GET("/foo/") or GET("/foos/")) { req -> handle(req) } "/api".nest { @@ -160,6 +167,18 @@ class RouterFunctionDslTests { path("/baz", ::handle) GET("/rendering") { RenderingResponse.create("index").build() } add(otherRouter) + add(filterRouter) + } + + private val filterRouter = router { + "/filter" { request -> + ok().header("foo", request.headers().firstHeader("foo")).build() + } + + filter { request, next -> + val newRequest = ServerRequest.from(request).apply { header("foo", "bar") }.build() + next(newRequest) + } } private val otherRouter = router {