Improve parity between Java and Kotlin router DSL
This commit adds following functions to the Kotlin DSL: add, filter, before, after and onError. Closes gh-23524
This commit is contained in:
parent
7a1a8e1623
commit
1dfe304da4
|
@ -17,6 +17,7 @@
|
|||
package org.springframework.web.reactive.function.server
|
||||
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.reactive.awaitFirst
|
||||
import kotlinx.coroutines.reactor.mono
|
||||
import org.springframework.core.io.Resource
|
||||
import org.springframework.http.HttpMethod
|
||||
|
@ -64,7 +65,8 @@ fun coRouter(routes: (CoRouterFunctionDsl.() -> Unit)) =
|
|||
*/
|
||||
class CoRouterFunctionDsl(private val init: (CoRouterFunctionDsl.() -> Unit)) {
|
||||
|
||||
private val builder = RouterFunctions.route()
|
||||
@PublishedApi
|
||||
internal val builder = RouterFunctions.route()
|
||||
|
||||
/**
|
||||
* Return a composed request predicate that tests against both this predicate AND
|
||||
|
@ -510,6 +512,80 @@ class CoRouterFunctionDsl(private val init: (CoRouterFunctionDsl.() -> Unit)) {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Merge externally defined router functions into this one.
|
||||
* @param routerFunction the router function to be added
|
||||
* @since 5.2
|
||||
*/
|
||||
fun add(routerFunction: RouterFunction<ServerResponse>) {
|
||||
builder.add(routerFunction)
|
||||
}
|
||||
|
||||
/**
|
||||
* Filters all routes created by this router with the given filter function. Filter
|
||||
* functions are typically used to address cross-cutting concerns, such as logging,
|
||||
* security, etc.
|
||||
* @param filterFunction the function to filter all routes built by this router
|
||||
* @since 5.2
|
||||
*/
|
||||
fun filter(filterFunction: suspend (ServerRequest, suspend (ServerRequest) -> ServerResponse) -> ServerResponse) {
|
||||
builder.filter { serverRequest, handlerFunction ->
|
||||
mono(Dispatchers.Unconfined) {
|
||||
filterFunction(serverRequest) {
|
||||
handlerFunction.handle(serverRequest).awaitFirst()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Filter the request object for all routes created by this builder with the given request
|
||||
* processing function. Filters are typically used to address cross-cutting concerns, such
|
||||
* as logging, security, etc.
|
||||
* @param requestProcessor a function that transforms the request
|
||||
* @since 5.2
|
||||
*/
|
||||
fun before(requestProcessor: (ServerRequest) -> ServerRequest) {
|
||||
builder.before(requestProcessor)
|
||||
}
|
||||
|
||||
/**
|
||||
* Filter the response object for all routes created by this builder with the given response
|
||||
* processing function. Filters are typically used to address cross-cutting concerns, such
|
||||
* as logging, security, etc.
|
||||
* @param responseProcessor a function that transforms the response
|
||||
* @since 5.2
|
||||
*/
|
||||
fun after(responseProcessor: (ServerRequest, ServerResponse) -> ServerResponse) {
|
||||
builder.after(responseProcessor)
|
||||
}
|
||||
|
||||
/**
|
||||
* Filters all exceptions that match the predicate by applying the given response provider
|
||||
* function.
|
||||
* @param predicate the type of exception to filter
|
||||
* @param responseProvider a function that creates a response
|
||||
* @since 5.2
|
||||
*/
|
||||
fun onError(predicate: (Throwable) -> Boolean, responseProvider: suspend (Throwable, ServerRequest) -> ServerResponse) {
|
||||
builder.onError(predicate) { throwable, request ->
|
||||
mono(Dispatchers.Unconfined) { responseProvider.invoke(throwable, request) }
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Filters all exceptions that match the predicate by applying the given response provider
|
||||
* function.
|
||||
* @param E the type of exception to filter
|
||||
* @param responseProvider a function that creates a response
|
||||
* @since 5.2
|
||||
*/
|
||||
inline fun <reified E : Throwable> onError(noinline responseProvider: suspend (Throwable, ServerRequest) -> ServerResponse) {
|
||||
builder.onError({it is E}) { throwable, request ->
|
||||
mono(Dispatchers.Unconfined) { responseProvider.invoke(throwable, request) }
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Return a composed routing function created from all the registered routes.
|
||||
*/
|
||||
|
|
|
@ -62,7 +62,8 @@ fun router(routes: RouterFunctionDsl.() -> Unit) = RouterFunctionDsl(routes).bui
|
|||
*/
|
||||
class RouterFunctionDsl(private val init: RouterFunctionDsl.() -> Unit) {
|
||||
|
||||
private val builder = RouterFunctions.route()
|
||||
@PublishedApi
|
||||
internal val builder = RouterFunctions.route()
|
||||
|
||||
/**
|
||||
* Return a composed request predicate that tests against both this predicate AND
|
||||
|
@ -505,6 +506,83 @@ class RouterFunctionDsl(private val init: RouterFunctionDsl.() -> Unit) {
|
|||
builder.resources(lookupFunction)
|
||||
}
|
||||
|
||||
/**
|
||||
* Merge externally defined router functions into this one.
|
||||
* @param routerFunction the router function to be added
|
||||
* @since 5.2
|
||||
*/
|
||||
fun add(routerFunction: RouterFunction<ServerResponse>) {
|
||||
builder.add(routerFunction)
|
||||
}
|
||||
|
||||
/**
|
||||
* Filters all routes created by this router with the given filter function. Filter
|
||||
* functions are typically used to address cross-cutting concerns, such as logging,
|
||||
* security, etc.
|
||||
* @param filterFunction the function to filter all routes built by this router
|
||||
* @since 5.2
|
||||
*/
|
||||
fun filter(filterFunction: (ServerRequest, (ServerRequest) -> Mono<ServerResponse>) -> Mono<ServerResponse>) {
|
||||
builder.filter { request, next ->
|
||||
filterFunction(request) {
|
||||
next.handle(request)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Filter the request object for all routes created by this builder with the given request
|
||||
* processing function. Filters are typically used to address cross-cutting concerns, such
|
||||
* as logging, security, etc.
|
||||
* @param requestProcessor a function that transforms the request
|
||||
* @since 5.2
|
||||
*/
|
||||
fun before(requestProcessor: (ServerRequest) -> ServerRequest) {
|
||||
builder.before(requestProcessor)
|
||||
}
|
||||
|
||||
/**
|
||||
* Filter the response object for all routes created by this builder with the given response
|
||||
* processing function. Filters are typically used to address cross-cutting concerns, such
|
||||
* as logging, security, etc.
|
||||
* @param responseProcessor a function that transforms the response
|
||||
* @since 5.2
|
||||
*/
|
||||
fun after(responseProcessor: (ServerRequest, ServerResponse) -> ServerResponse) {
|
||||
builder.after(responseProcessor)
|
||||
}
|
||||
|
||||
/**
|
||||
* Filters all exceptions that match the predicate by applying the given response provider
|
||||
* function.
|
||||
* @param predicate the type of exception to filter
|
||||
* @param responseProvider a function that creates a response
|
||||
* @since 5.2
|
||||
*/
|
||||
fun onError(predicate: (Throwable) -> Boolean, responseProvider: (Throwable, ServerRequest) -> Mono<ServerResponse>) {
|
||||
builder.onError(predicate, responseProvider)
|
||||
}
|
||||
|
||||
/**
|
||||
* Filters all exceptions that match the predicate by applying the given response provider
|
||||
* function.
|
||||
* @param E the type of exception to filter
|
||||
* @param responseProvider a function that creates a response
|
||||
* @since 5.2
|
||||
*/
|
||||
inline fun <reified E : Throwable> onError(noinline responseProvider: (Throwable, ServerRequest) -> Mono<ServerResponse>) {
|
||||
builder.onError({it is E}, responseProvider)
|
||||
}
|
||||
|
||||
/**
|
||||
* Return a composed routing function created from all the registered routes.
|
||||
* @since 5.1
|
||||
*/
|
||||
internal fun build(): RouterFunction<ServerResponse> {
|
||||
init()
|
||||
return builder.build()
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a builder with the status code and headers of the given response.
|
||||
* @param other the response to copy the status and headers from
|
||||
|
@ -621,13 +699,4 @@ class RouterFunctionDsl(private val init: RouterFunctionDsl.() -> Unit) {
|
|||
fun unprocessableEntity(): ServerResponse.BodyBuilder =
|
||||
ServerResponse.unprocessableEntity()
|
||||
|
||||
/**
|
||||
* Return a composed routing function created from all the registered routes.
|
||||
* @since 5.1
|
||||
*/
|
||||
internal fun build(): RouterFunction<ServerResponse> {
|
||||
init()
|
||||
return builder.build()
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -21,6 +21,7 @@ import org.junit.jupiter.api.Test
|
|||
import org.springframework.core.io.ClassPathResource
|
||||
import org.springframework.http.HttpHeaders.*
|
||||
import org.springframework.http.HttpMethod.*
|
||||
import org.springframework.http.HttpStatus
|
||||
import org.springframework.http.MediaType.*
|
||||
import org.springframework.web.reactive.function.server.MockServerRequest.builder
|
||||
import reactor.test.StepVerifier
|
||||
|
@ -172,6 +173,28 @@ class CoRouterFunctionDslTests {
|
|||
}
|
||||
path("/baz", ::handle)
|
||||
GET("/rendering") { RenderingResponse.create("index").buildAndAwait() }
|
||||
add(otherRouter)
|
||||
}
|
||||
|
||||
private val otherRouter = router {
|
||||
"/other" {
|
||||
ok().build()
|
||||
}
|
||||
filter { request, next ->
|
||||
next(request)
|
||||
}
|
||||
before {
|
||||
it
|
||||
}
|
||||
after { _, response ->
|
||||
response
|
||||
}
|
||||
onError({it is IllegalStateException}) { _, _ ->
|
||||
ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR).build()
|
||||
}
|
||||
onError<IllegalStateException> { _, _ ->
|
||||
ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR).build()
|
||||
}
|
||||
}
|
||||
|
||||
@Suppress("UNUSED_PARAMETER")
|
||||
|
|
|
@ -21,6 +21,7 @@ import org.junit.jupiter.api.Test
|
|||
import org.springframework.core.io.ClassPathResource
|
||||
import org.springframework.http.HttpHeaders.*
|
||||
import org.springframework.http.HttpMethod.*
|
||||
import org.springframework.http.HttpStatus
|
||||
import org.springframework.http.MediaType.*
|
||||
import org.springframework.web.reactive.function.server.MockServerRequest.builder
|
||||
import reactor.core.publisher.Mono
|
||||
|
@ -173,6 +174,28 @@ class RouterFunctionDslTests {
|
|||
}
|
||||
path("/baz", ::handle)
|
||||
GET("/rendering") { RenderingResponse.create("index").build() }
|
||||
add(otherRouter)
|
||||
}
|
||||
|
||||
private val otherRouter = router {
|
||||
"/other" {
|
||||
ok().build()
|
||||
}
|
||||
filter { request, next ->
|
||||
next(request)
|
||||
}
|
||||
before {
|
||||
it
|
||||
}
|
||||
after { _, response ->
|
||||
response
|
||||
}
|
||||
onError({it is IllegalStateException}) { _, _ ->
|
||||
ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR).build()
|
||||
}
|
||||
onError<IllegalStateException> { _, _ ->
|
||||
ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR).build()
|
||||
}
|
||||
}
|
||||
|
||||
@Suppress("UNUSED_PARAMETER")
|
||||
|
|
|
@ -60,7 +60,8 @@ fun router(routes: (RouterFunctionDsl.() -> Unit)) = RouterFunctionDsl(routes).b
|
|||
*/
|
||||
class RouterFunctionDsl(private val init: (RouterFunctionDsl.() -> Unit)) {
|
||||
|
||||
private val builder = RouterFunctions.route()
|
||||
@PublishedApi
|
||||
internal val builder = RouterFunctions.route()
|
||||
|
||||
/**
|
||||
* Return a composed request predicate that tests against both this predicate AND
|
||||
|
@ -504,6 +505,74 @@ class RouterFunctionDsl(private val init: (RouterFunctionDsl.() -> Unit)) {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Merge externally defined router functions into this one.
|
||||
* @param routerFunction the router function to be added
|
||||
* @since 5.2
|
||||
*/
|
||||
fun add(routerFunction: RouterFunction<ServerResponse>) {
|
||||
builder.add(routerFunction)
|
||||
}
|
||||
|
||||
/**
|
||||
* Filters all routes created by this router with the given filter function. Filter
|
||||
* functions are typically used to address cross-cutting concerns, such as logging,
|
||||
* security, etc.
|
||||
* @param filterFunction the function to filter all routes built by this router
|
||||
* @since 5.2
|
||||
*/
|
||||
fun filter(filterFunction: (ServerRequest, (ServerRequest) -> ServerResponse) -> ServerResponse) {
|
||||
builder.filter { request, next ->
|
||||
filterFunction(request) {
|
||||
next.handle(request)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Filter the request object for all routes created by this builder with the given request
|
||||
* processing function. Filters are typically used to address cross-cutting concerns, such
|
||||
* as logging, security, etc.
|
||||
* @param requestProcessor a function that transforms the request
|
||||
* @since 5.2
|
||||
*/
|
||||
fun before(requestProcessor: (ServerRequest) -> ServerRequest) {
|
||||
builder.before(requestProcessor)
|
||||
}
|
||||
|
||||
/**
|
||||
* Filter the response object for all routes created by this builder with the given response
|
||||
* processing function. Filters are typically used to address cross-cutting concerns, such
|
||||
* as logging, security, etc.
|
||||
* @param responseProcessor a function that transforms the response
|
||||
* @since 5.2
|
||||
*/
|
||||
fun after(responseProcessor: (ServerRequest, ServerResponse) -> ServerResponse) {
|
||||
builder.after(responseProcessor)
|
||||
}
|
||||
|
||||
/**
|
||||
* Filters all exceptions that match the predicate by applying the given response provider
|
||||
* function.
|
||||
* @param predicate the type of exception to filter
|
||||
* @param responseProvider a function that creates a response
|
||||
* @since 5.2
|
||||
*/
|
||||
fun onError(predicate: (Throwable) -> Boolean, responseProvider: (Throwable, ServerRequest) -> ServerResponse) {
|
||||
builder.onError(predicate, responseProvider)
|
||||
}
|
||||
|
||||
/**
|
||||
* Filters all exceptions that match the predicate by applying the given response provider
|
||||
* function.
|
||||
* @param E the type of exception to filter
|
||||
* @param responseProvider a function that creates a response
|
||||
* @since 5.2
|
||||
*/
|
||||
inline fun <reified E : Throwable> onError(noinline responseProvider: (Throwable, ServerRequest) -> ServerResponse) {
|
||||
builder.onError({it is E}, responseProvider)
|
||||
}
|
||||
|
||||
/**
|
||||
* Return a composed routing function created from all the registered routes.
|
||||
*/
|
||||
|
|
|
@ -16,11 +16,13 @@
|
|||
|
||||
package org.springframework.web.servlet.function
|
||||
|
||||
import org.assertj.core.api.Assertions.*
|
||||
import org.assertj.core.api.Assertions.assertThat
|
||||
import org.assertj.core.api.Assertions.assertThatExceptionOfType
|
||||
import org.junit.jupiter.api.Test
|
||||
import org.springframework.core.io.ClassPathResource
|
||||
import org.springframework.http.HttpHeaders.*
|
||||
import org.springframework.http.HttpMethod.*
|
||||
import org.springframework.http.HttpStatus
|
||||
import org.springframework.http.MediaType.*
|
||||
import org.springframework.mock.web.test.MockHttpServletRequest
|
||||
|
||||
|
@ -124,7 +126,6 @@ class RouterFunctionDslTests {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
private fun sampleRouter() = router {
|
||||
(GET("/foo/") or GET("/foos/")) { req -> handle(req) }
|
||||
"/api".nest {
|
||||
|
@ -157,6 +158,28 @@ class RouterFunctionDslTests {
|
|||
}
|
||||
path("/baz", ::handle)
|
||||
GET("/rendering") { RenderingResponse.create("index").build() }
|
||||
add(otherRouter)
|
||||
}
|
||||
|
||||
private val otherRouter = router {
|
||||
"/other" {
|
||||
ok().build()
|
||||
}
|
||||
filter { request, next ->
|
||||
next(request)
|
||||
}
|
||||
before {
|
||||
it
|
||||
}
|
||||
after { _, response ->
|
||||
response
|
||||
}
|
||||
onError({it is IllegalStateException}) { _, _ ->
|
||||
ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR).build()
|
||||
}
|
||||
onError<IllegalStateException> { _, _ ->
|
||||
ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR).build()
|
||||
}
|
||||
}
|
||||
|
||||
@Suppress("UNUSED_PARAMETER")
|
||||
|
|
|
@ -751,17 +751,17 @@ For instance, consider the following example:
|
|||
[source,java,indent=0,subs="verbatim,quotes",role="primary"]
|
||||
.Java
|
||||
----
|
||||
RouterFunction<ServerResponse> route = route()
|
||||
.path("/person", b1 -> b1
|
||||
.nest(accept(APPLICATION_JSON), b2 -> b2
|
||||
.GET("/{id}", handler::getPerson)
|
||||
.GET("", handler::listPeople)
|
||||
.before(request -> ServerRequest.from(request) // <1>
|
||||
.header("X-RequestHeader", "Value")
|
||||
.build()))
|
||||
.POST("/person", handler::createPerson))
|
||||
.after((request, response) -> logResponse(response)) // <2>
|
||||
.build();
|
||||
RouterFunction<ServerResponse> route = route()
|
||||
.path("/person", b1 -> b1
|
||||
.nest(accept(APPLICATION_JSON), b2 -> b2
|
||||
.GET("/{id}", handler::getPerson)
|
||||
.GET("", handler::listPeople)
|
||||
.before(request -> ServerRequest.from(request) // <1>
|
||||
.header("X-RequestHeader", "Value")
|
||||
.build()))
|
||||
.POST("/person", handler::createPerson))
|
||||
.after((request, response) -> logResponse(response)) // <2>
|
||||
.build();
|
||||
----
|
||||
<1> The `before` filter that adds a custom request header is only applied to the two GET routes.
|
||||
<2> The `after` filter that logs the response is applied to all routes, including the nested ones.
|
||||
|
@ -769,8 +769,23 @@ RouterFunction<ServerResponse> route = route()
|
|||
[source,kotlin,indent=0,subs="verbatim,quotes",role="secondary"]
|
||||
.Kotlin
|
||||
----
|
||||
// TODO when https://github.com/spring-projects/spring-framework/issues/23526 will be fixed
|
||||
val route = router {
|
||||
"/person".nest {
|
||||
GET("/{id}", handler::getPerson)
|
||||
GET("", handler::listPeople)
|
||||
before { // <1>
|
||||
ServerRequest.from(it)
|
||||
.header("X-RequestHeader", "Value").build()
|
||||
}
|
||||
POST("/person", handler::createPerson)
|
||||
after { _, response -> // <2>
|
||||
logResponse(response)
|
||||
}
|
||||
}
|
||||
}
|
||||
----
|
||||
<1> The `before` filter that adds a custom request header is only applied to the two GET routes.
|
||||
<2> The `after` filter that logs the response is applied to all routes, including the nested ones.
|
||||
|
||||
|
||||
The `filter` method on the router builder takes a `HandlerFilterFunction`: a
|
||||
|
@ -807,7 +822,23 @@ The following example shows how to do so:
|
|||
[source,kotlin,indent=0,subs="verbatim,quotes",role="secondary"]
|
||||
.Kotlin
|
||||
----
|
||||
// TODO when https://github.com/spring-projects/spring-framework/issues/23526 will be fixed
|
||||
val securityManager: SecurityManager = ...
|
||||
|
||||
val route = router {
|
||||
("/person" and accept(APPLICATION_JSON)).nest {
|
||||
GET("/{id}", handler::getPerson)
|
||||
GET("", handler::listPeople)
|
||||
POST("/person", handler::createPerson)
|
||||
filter { request, next ->
|
||||
if (securityManager.allowAccessTo(request.path())) {
|
||||
next(request)
|
||||
}
|
||||
else {
|
||||
status(UNAUTHORIZED).build();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
----
|
||||
|
||||
The preceding example demonstrates that invoking the `next.handle(ServerRequest)` is optional.
|
||||
|
|
Loading…
Reference in New Issue