Improve parity between Java and Kotlin router DSL

This commit adds following functions to the Kotlin DSL:
add, filter, before, after and onError.

Closes gh-23524
This commit is contained in:
Sebastien Deleuze 2019-09-17 12:04:37 +02:00
parent 7a1a8e1623
commit 1dfe304da4
7 changed files with 341 additions and 27 deletions

View File

@ -17,6 +17,7 @@
package org.springframework.web.reactive.function.server
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.reactive.awaitFirst
import kotlinx.coroutines.reactor.mono
import org.springframework.core.io.Resource
import org.springframework.http.HttpMethod
@ -64,7 +65,8 @@ fun coRouter(routes: (CoRouterFunctionDsl.() -> Unit)) =
*/
class CoRouterFunctionDsl(private val init: (CoRouterFunctionDsl.() -> Unit)) {
private val builder = RouterFunctions.route()
@PublishedApi
internal val builder = RouterFunctions.route()
/**
* Return a composed request predicate that tests against both this predicate AND
@ -510,6 +512,80 @@ class CoRouterFunctionDsl(private val init: (CoRouterFunctionDsl.() -> Unit)) {
}
}
/**
* Merge externally defined router functions into this one.
* @param routerFunction the router function to be added
* @since 5.2
*/
fun add(routerFunction: RouterFunction<ServerResponse>) {
builder.add(routerFunction)
}
/**
* Filters all routes created by this router with the given filter function. Filter
* functions are typically used to address cross-cutting concerns, such as logging,
* security, etc.
* @param filterFunction the function to filter all routes built by this router
* @since 5.2
*/
fun filter(filterFunction: suspend (ServerRequest, suspend (ServerRequest) -> ServerResponse) -> ServerResponse) {
builder.filter { serverRequest, handlerFunction ->
mono(Dispatchers.Unconfined) {
filterFunction(serverRequest) {
handlerFunction.handle(serverRequest).awaitFirst()
}
}
}
}
/**
* Filter the request object for all routes created by this builder with the given request
* processing function. Filters are typically used to address cross-cutting concerns, such
* as logging, security, etc.
* @param requestProcessor a function that transforms the request
* @since 5.2
*/
fun before(requestProcessor: (ServerRequest) -> ServerRequest) {
builder.before(requestProcessor)
}
/**
* Filter the response object for all routes created by this builder with the given response
* processing function. Filters are typically used to address cross-cutting concerns, such
* as logging, security, etc.
* @param responseProcessor a function that transforms the response
* @since 5.2
*/
fun after(responseProcessor: (ServerRequest, ServerResponse) -> ServerResponse) {
builder.after(responseProcessor)
}
/**
* Filters all exceptions that match the predicate by applying the given response provider
* function.
* @param predicate the type of exception to filter
* @param responseProvider a function that creates a response
* @since 5.2
*/
fun onError(predicate: (Throwable) -> Boolean, responseProvider: suspend (Throwable, ServerRequest) -> ServerResponse) {
builder.onError(predicate) { throwable, request ->
mono(Dispatchers.Unconfined) { responseProvider.invoke(throwable, request) }
}
}
/**
* Filters all exceptions that match the predicate by applying the given response provider
* function.
* @param E the type of exception to filter
* @param responseProvider a function that creates a response
* @since 5.2
*/
inline fun <reified E : Throwable> onError(noinline responseProvider: suspend (Throwable, ServerRequest) -> ServerResponse) {
builder.onError({it is E}) { throwable, request ->
mono(Dispatchers.Unconfined) { responseProvider.invoke(throwable, request) }
}
}
/**
* Return a composed routing function created from all the registered routes.
*/

View File

@ -62,7 +62,8 @@ fun router(routes: RouterFunctionDsl.() -> Unit) = RouterFunctionDsl(routes).bui
*/
class RouterFunctionDsl(private val init: RouterFunctionDsl.() -> Unit) {
private val builder = RouterFunctions.route()
@PublishedApi
internal val builder = RouterFunctions.route()
/**
* Return a composed request predicate that tests against both this predicate AND
@ -505,6 +506,83 @@ class RouterFunctionDsl(private val init: RouterFunctionDsl.() -> Unit) {
builder.resources(lookupFunction)
}
/**
* Merge externally defined router functions into this one.
* @param routerFunction the router function to be added
* @since 5.2
*/
fun add(routerFunction: RouterFunction<ServerResponse>) {
builder.add(routerFunction)
}
/**
* Filters all routes created by this router with the given filter function. Filter
* functions are typically used to address cross-cutting concerns, such as logging,
* security, etc.
* @param filterFunction the function to filter all routes built by this router
* @since 5.2
*/
fun filter(filterFunction: (ServerRequest, (ServerRequest) -> Mono<ServerResponse>) -> Mono<ServerResponse>) {
builder.filter { request, next ->
filterFunction(request) {
next.handle(request)
}
}
}
/**
* Filter the request object for all routes created by this builder with the given request
* processing function. Filters are typically used to address cross-cutting concerns, such
* as logging, security, etc.
* @param requestProcessor a function that transforms the request
* @since 5.2
*/
fun before(requestProcessor: (ServerRequest) -> ServerRequest) {
builder.before(requestProcessor)
}
/**
* Filter the response object for all routes created by this builder with the given response
* processing function. Filters are typically used to address cross-cutting concerns, such
* as logging, security, etc.
* @param responseProcessor a function that transforms the response
* @since 5.2
*/
fun after(responseProcessor: (ServerRequest, ServerResponse) -> ServerResponse) {
builder.after(responseProcessor)
}
/**
* Filters all exceptions that match the predicate by applying the given response provider
* function.
* @param predicate the type of exception to filter
* @param responseProvider a function that creates a response
* @since 5.2
*/
fun onError(predicate: (Throwable) -> Boolean, responseProvider: (Throwable, ServerRequest) -> Mono<ServerResponse>) {
builder.onError(predicate, responseProvider)
}
/**
* Filters all exceptions that match the predicate by applying the given response provider
* function.
* @param E the type of exception to filter
* @param responseProvider a function that creates a response
* @since 5.2
*/
inline fun <reified E : Throwable> onError(noinline responseProvider: (Throwable, ServerRequest) -> Mono<ServerResponse>) {
builder.onError({it is E}, responseProvider)
}
/**
* Return a composed routing function created from all the registered routes.
* @since 5.1
*/
internal fun build(): RouterFunction<ServerResponse> {
init()
return builder.build()
}
/**
* Create a builder with the status code and headers of the given response.
* @param other the response to copy the status and headers from
@ -621,13 +699,4 @@ class RouterFunctionDsl(private val init: RouterFunctionDsl.() -> Unit) {
fun unprocessableEntity(): ServerResponse.BodyBuilder =
ServerResponse.unprocessableEntity()
/**
* Return a composed routing function created from all the registered routes.
* @since 5.1
*/
internal fun build(): RouterFunction<ServerResponse> {
init()
return builder.build()
}
}

View File

@ -21,6 +21,7 @@ import org.junit.jupiter.api.Test
import org.springframework.core.io.ClassPathResource
import org.springframework.http.HttpHeaders.*
import org.springframework.http.HttpMethod.*
import org.springframework.http.HttpStatus
import org.springframework.http.MediaType.*
import org.springframework.web.reactive.function.server.MockServerRequest.builder
import reactor.test.StepVerifier
@ -172,6 +173,28 @@ class CoRouterFunctionDslTests {
}
path("/baz", ::handle)
GET("/rendering") { RenderingResponse.create("index").buildAndAwait() }
add(otherRouter)
}
private val otherRouter = router {
"/other" {
ok().build()
}
filter { request, next ->
next(request)
}
before {
it
}
after { _, response ->
response
}
onError({it is IllegalStateException}) { _, _ ->
ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR).build()
}
onError<IllegalStateException> { _, _ ->
ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR).build()
}
}
@Suppress("UNUSED_PARAMETER")

View File

@ -21,6 +21,7 @@ import org.junit.jupiter.api.Test
import org.springframework.core.io.ClassPathResource
import org.springframework.http.HttpHeaders.*
import org.springframework.http.HttpMethod.*
import org.springframework.http.HttpStatus
import org.springframework.http.MediaType.*
import org.springframework.web.reactive.function.server.MockServerRequest.builder
import reactor.core.publisher.Mono
@ -173,6 +174,28 @@ class RouterFunctionDslTests {
}
path("/baz", ::handle)
GET("/rendering") { RenderingResponse.create("index").build() }
add(otherRouter)
}
private val otherRouter = router {
"/other" {
ok().build()
}
filter { request, next ->
next(request)
}
before {
it
}
after { _, response ->
response
}
onError({it is IllegalStateException}) { _, _ ->
ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR).build()
}
onError<IllegalStateException> { _, _ ->
ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR).build()
}
}
@Suppress("UNUSED_PARAMETER")

View File

@ -60,7 +60,8 @@ fun router(routes: (RouterFunctionDsl.() -> Unit)) = RouterFunctionDsl(routes).b
*/
class RouterFunctionDsl(private val init: (RouterFunctionDsl.() -> Unit)) {
private val builder = RouterFunctions.route()
@PublishedApi
internal val builder = RouterFunctions.route()
/**
* Return a composed request predicate that tests against both this predicate AND
@ -504,6 +505,74 @@ class RouterFunctionDsl(private val init: (RouterFunctionDsl.() -> Unit)) {
}
}
/**
* Merge externally defined router functions into this one.
* @param routerFunction the router function to be added
* @since 5.2
*/
fun add(routerFunction: RouterFunction<ServerResponse>) {
builder.add(routerFunction)
}
/**
* Filters all routes created by this router with the given filter function. Filter
* functions are typically used to address cross-cutting concerns, such as logging,
* security, etc.
* @param filterFunction the function to filter all routes built by this router
* @since 5.2
*/
fun filter(filterFunction: (ServerRequest, (ServerRequest) -> ServerResponse) -> ServerResponse) {
builder.filter { request, next ->
filterFunction(request) {
next.handle(request)
}
}
}
/**
* Filter the request object for all routes created by this builder with the given request
* processing function. Filters are typically used to address cross-cutting concerns, such
* as logging, security, etc.
* @param requestProcessor a function that transforms the request
* @since 5.2
*/
fun before(requestProcessor: (ServerRequest) -> ServerRequest) {
builder.before(requestProcessor)
}
/**
* Filter the response object for all routes created by this builder with the given response
* processing function. Filters are typically used to address cross-cutting concerns, such
* as logging, security, etc.
* @param responseProcessor a function that transforms the response
* @since 5.2
*/
fun after(responseProcessor: (ServerRequest, ServerResponse) -> ServerResponse) {
builder.after(responseProcessor)
}
/**
* Filters all exceptions that match the predicate by applying the given response provider
* function.
* @param predicate the type of exception to filter
* @param responseProvider a function that creates a response
* @since 5.2
*/
fun onError(predicate: (Throwable) -> Boolean, responseProvider: (Throwable, ServerRequest) -> ServerResponse) {
builder.onError(predicate, responseProvider)
}
/**
* Filters all exceptions that match the predicate by applying the given response provider
* function.
* @param E the type of exception to filter
* @param responseProvider a function that creates a response
* @since 5.2
*/
inline fun <reified E : Throwable> onError(noinline responseProvider: (Throwable, ServerRequest) -> ServerResponse) {
builder.onError({it is E}, responseProvider)
}
/**
* Return a composed routing function created from all the registered routes.
*/

View File

@ -16,11 +16,13 @@
package org.springframework.web.servlet.function
import org.assertj.core.api.Assertions.*
import org.assertj.core.api.Assertions.assertThat
import org.assertj.core.api.Assertions.assertThatExceptionOfType
import org.junit.jupiter.api.Test
import org.springframework.core.io.ClassPathResource
import org.springframework.http.HttpHeaders.*
import org.springframework.http.HttpMethod.*
import org.springframework.http.HttpStatus
import org.springframework.http.MediaType.*
import org.springframework.mock.web.test.MockHttpServletRequest
@ -124,7 +126,6 @@ class RouterFunctionDslTests {
}
}
private fun sampleRouter() = router {
(GET("/foo/") or GET("/foos/")) { req -> handle(req) }
"/api".nest {
@ -157,6 +158,28 @@ class RouterFunctionDslTests {
}
path("/baz", ::handle)
GET("/rendering") { RenderingResponse.create("index").build() }
add(otherRouter)
}
private val otherRouter = router {
"/other" {
ok().build()
}
filter { request, next ->
next(request)
}
before {
it
}
after { _, response ->
response
}
onError({it is IllegalStateException}) { _, _ ->
ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR).build()
}
onError<IllegalStateException> { _, _ ->
ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR).build()
}
}
@Suppress("UNUSED_PARAMETER")

View File

@ -751,17 +751,17 @@ For instance, consider the following example:
[source,java,indent=0,subs="verbatim,quotes",role="primary"]
.Java
----
RouterFunction<ServerResponse> route = route()
.path("/person", b1 -> b1
.nest(accept(APPLICATION_JSON), b2 -> b2
.GET("/{id}", handler::getPerson)
.GET("", handler::listPeople)
.before(request -> ServerRequest.from(request) // <1>
.header("X-RequestHeader", "Value")
.build()))
.POST("/person", handler::createPerson))
.after((request, response) -> logResponse(response)) // <2>
.build();
RouterFunction<ServerResponse> route = route()
.path("/person", b1 -> b1
.nest(accept(APPLICATION_JSON), b2 -> b2
.GET("/{id}", handler::getPerson)
.GET("", handler::listPeople)
.before(request -> ServerRequest.from(request) // <1>
.header("X-RequestHeader", "Value")
.build()))
.POST("/person", handler::createPerson))
.after((request, response) -> logResponse(response)) // <2>
.build();
----
<1> The `before` filter that adds a custom request header is only applied to the two GET routes.
<2> The `after` filter that logs the response is applied to all routes, including the nested ones.
@ -769,8 +769,23 @@ RouterFunction<ServerResponse> route = route()
[source,kotlin,indent=0,subs="verbatim,quotes",role="secondary"]
.Kotlin
----
// TODO when https://github.com/spring-projects/spring-framework/issues/23526 will be fixed
val route = router {
"/person".nest {
GET("/{id}", handler::getPerson)
GET("", handler::listPeople)
before { // <1>
ServerRequest.from(it)
.header("X-RequestHeader", "Value").build()
}
POST("/person", handler::createPerson)
after { _, response -> // <2>
logResponse(response)
}
}
}
----
<1> The `before` filter that adds a custom request header is only applied to the two GET routes.
<2> The `after` filter that logs the response is applied to all routes, including the nested ones.
The `filter` method on the router builder takes a `HandlerFilterFunction`: a
@ -807,7 +822,23 @@ The following example shows how to do so:
[source,kotlin,indent=0,subs="verbatim,quotes",role="secondary"]
.Kotlin
----
// TODO when https://github.com/spring-projects/spring-framework/issues/23526 will be fixed
val securityManager: SecurityManager = ...
val route = router {
("/person" and accept(APPLICATION_JSON)).nest {
GET("/{id}", handler::getPerson)
GET("", handler::listPeople)
POST("/person", handler::createPerson)
filter { request, next ->
if (securityManager.allowAccessTo(request.path())) {
next(request)
}
else {
status(UNAUTHORIZED).build();
}
}
}
}
----
The preceding example demonstrates that invoking the `next.handle(ServerRequest)` is optional.