diff --git a/spring-webflux/src/main/kotlin/org/springframework/web/reactive/function/server/RouterFunctionExtensions.kt b/spring-webflux/src/main/kotlin/org/springframework/web/reactive/function/server/RouterFunctionExtensions.kt index fa213bf212b..a0a0eecd1bf 100644 --- a/spring-webflux/src/main/kotlin/org/springframework/web/reactive/function/server/RouterFunctionExtensions.kt +++ b/spring-webflux/src/main/kotlin/org/springframework/web/reactive/function/server/RouterFunctionExtensions.kt @@ -19,6 +19,7 @@ package org.springframework.web.reactive.function.server import org.springframework.core.io.Resource import org.springframework.http.HttpMethod import org.springframework.http.MediaType +import org.springframework.web.reactive.function.server.RequestPredicates.pathPrefix import reactor.core.publisher.Mono /** @@ -26,18 +27,16 @@ import reactor.core.publisher.Mono * write idiomatic Kotlin code as below: * * ```kotlin - * import org.springframework.web.reactive.function.server.RequestPredicates.* - * ... * * @Controller * class FooController : RouterFunction { * * override fun route(req: ServerRequest) = route(req) { - * html().apply { + * accept(TEXT_HTML).apply { * (GET("/user/") or GET("/users/")) { findAllView() } * GET("/user/{login}", this@FooController::findViewById) * } - * json().apply { + * accept(APPLICATION_JSON).apply { * (GET("/api/user/") or GET("/api/users/")) { findAll() } * POST("/api/user/", this@FooController::create) * } @@ -52,24 +51,39 @@ import reactor.core.publisher.Mono * * @since 5.0 * @see Kotlin issue about supporting ::foo for member functions - * @author Sebastien Deleuze + * @author Sebastien De leuze * @author Yevhenii Melnyk */ -fun RouterFunction<*>.route(request: ServerRequest, configure: Routes.() -> Unit) = - Routes().apply(configure).invoke(request) -class Routes { +typealias Routes = RouterDsl.() -> Unit + +fun RouterFunction<*>.route(request: ServerRequest, configure: Routes) = + RouterDsl().apply(configure).invoke(request) + +class RouterDsl { val routes = mutableListOf>() + infix fun RequestPredicate.and(other: String): RequestPredicate = this.and(pathPrefix(other)) + + infix fun RequestPredicate.or(other: String): RequestPredicate = this.or(pathPrefix(other)) + + infix fun String.and(other: RequestPredicate): RequestPredicate = pathPrefix(this).and(other) + + infix fun String.or(other: RequestPredicate): RequestPredicate = pathPrefix(this).or(other) + infix fun RequestPredicate.and(other: RequestPredicate): RequestPredicate = this.and(other) infix fun RequestPredicate.or(other: RequestPredicate): RequestPredicate = this.or(other) operator fun RequestPredicate.not(): RequestPredicate = this.negate() - fun RequestPredicate.route(r: Routes.() -> Unit) { - routes += RouterFunctions.nest(this, Routes().apply(r).router()) + fun RequestPredicate.route(r: Routes) { + routes += RouterFunctions.nest(this, RouterDsl().apply(r).router()) + } + + fun String.route(r: Routes) { + routes += RouterFunctions.nest(pathPrefix(this), RouterDsl().apply(r).router()) } operator fun RequestPredicate.invoke(f: (ServerRequest) -> Mono) { @@ -80,62 +94,98 @@ class Routes { routes += RouterFunctions.route(RequestPredicates.GET(pattern), HandlerFunction { f(it) }) } + fun GET(pattern: String) = RequestPredicates.GET(pattern) + fun HEAD(pattern: String, f: (ServerRequest) -> Mono) { routes += RouterFunctions.route(RequestPredicates.HEAD(pattern), HandlerFunction { f(it) }) } + fun HEAD(pattern: String) = RequestPredicates.HEAD(pattern) + fun POST(pattern: String, f: (ServerRequest) -> Mono) { routes += RouterFunctions.route(RequestPredicates.POST(pattern), HandlerFunction { f(it) }) } + fun POST(pattern: String) = RequestPredicates.POST(pattern) + fun PUT(pattern: String, f: (ServerRequest) -> Mono) { routes += RouterFunctions.route(RequestPredicates.PUT(pattern), HandlerFunction { f(it) }) } + fun PUT(pattern: String) = RequestPredicates.PUT(pattern) + fun PATCH(pattern: String, f: (ServerRequest) -> Mono) { routes += RouterFunctions.route(RequestPredicates.PATCH(pattern), HandlerFunction { f(it) }) } + fun PATCH(pattern: String) = RequestPredicates.PATCH(pattern) + fun DELETE(pattern: String, f: (ServerRequest) -> Mono) { routes += RouterFunctions.route(RequestPredicates.DELETE(pattern), HandlerFunction { f(it) }) } + fun DELETE(pattern: String) = RequestPredicates.DELETE(pattern) + + fun OPTIONS(pattern: String, f: (ServerRequest) -> Mono) { routes += RouterFunctions.route(RequestPredicates.OPTIONS(pattern), HandlerFunction { f(it) }) } + fun OPTIONS(pattern: String) = RequestPredicates.OPTIONS(pattern) + fun accept(mediaType: MediaType, f: (ServerRequest) -> Mono) { routes += RouterFunctions.route(RequestPredicates.accept(mediaType), HandlerFunction { f(it) }) } + fun accept(mediaType: MediaType) = RequestPredicates.accept(mediaType) + fun contentType(mediaType: MediaType, f: (ServerRequest) -> Mono) { routes += RouterFunctions.route(RequestPredicates.contentType(mediaType), HandlerFunction { f(it) }) } + fun contentType(mediaType: MediaType) = RequestPredicates.contentType(mediaType) + fun headers(headerPredicate: (ServerRequest.Headers) -> Boolean, f: (ServerRequest) -> Mono) { routes += RouterFunctions.route(RequestPredicates.headers(headerPredicate), HandlerFunction { f(it) }) } + fun headers(headerPredicate: (ServerRequest.Headers) -> Boolean) = RequestPredicates.headers(headerPredicate) + fun method(httpMethod: HttpMethod, f: (ServerRequest) -> Mono) { routes += RouterFunctions.route(RequestPredicates.method(httpMethod), HandlerFunction { f(it) }) } + fun method(httpMethod: HttpMethod) = RequestPredicates.method(httpMethod) + fun path(pattern: String, f: (ServerRequest) -> Mono) { routes += RouterFunctions.route(RequestPredicates.path(pattern), HandlerFunction { f(it) }) } + fun path(pattern: String) = RequestPredicates.path(pattern) + fun pathExtension(extension: String, f: (ServerRequest) -> Mono) { routes += RouterFunctions.route(RequestPredicates.pathExtension(extension), HandlerFunction { f(it) }) } + fun pathExtension(extension: String) = RequestPredicates.pathExtension(extension) + fun pathExtension(predicate: (String) -> Boolean, f: (ServerRequest) -> Mono) { routes += RouterFunctions.route(RequestPredicates.pathExtension(predicate), HandlerFunction { f(it) }) } + fun pathExtension(predicate: (String) -> Boolean) = RequestPredicates.pathExtension(predicate) + + fun queryParam(name: String, predicate: (String) -> Boolean, f: (ServerRequest) -> Mono) { routes += RouterFunctions.route(RequestPredicates.queryParam(name, predicate), HandlerFunction { f(it) }) } + fun queryParam(name: String, predicate: (String) -> Boolean) = RequestPredicates.queryParam(name, predicate) + + operator fun String.invoke(f: (ServerRequest) -> Mono) { + routes += RouterFunctions.route(RequestPredicates.path(this), HandlerFunction { f(it) }) + } + fun resources(path: String, location: Resource) { routes += RouterFunctions.resources(path, location) } diff --git a/spring-webflux/src/test/kotlin/org/springframework/web/reactive/function/server/RouterFunctionExtensionsTests.kt b/spring-webflux/src/test/kotlin/org/springframework/web/reactive/function/server/RouterFunctionExtensionsTests.kt index 3750bc4fb75..71644dbd228 100644 --- a/spring-webflux/src/test/kotlin/org/springframework/web/reactive/function/server/RouterFunctionExtensionsTests.kt +++ b/spring-webflux/src/test/kotlin/org/springframework/web/reactive/function/server/RouterFunctionExtensionsTests.kt @@ -18,14 +18,10 @@ package org.springframework.web.reactive.function.server import org.junit.Test import org.springframework.core.io.ClassPathResource -import org.springframework.http.HttpHeaders.ACCEPT -import org.springframework.http.HttpHeaders.CONTENT_TYPE -import org.springframework.http.HttpMethod -import org.springframework.http.HttpMethod.PATCH -import org.springframework.http.HttpMethod.POST +import org.springframework.http.HttpHeaders.* +import org.springframework.http.HttpMethod.* import org.springframework.http.MediaType.* import org.springframework.web.reactive.function.server.MockServerRequest.builder -import org.springframework.web.reactive.function.server.RequestPredicates.* import org.springframework.web.reactive.function.server.ServerResponse.ok import reactor.core.publisher.Mono import reactor.test.StepVerifier @@ -112,14 +108,14 @@ class RouterFunctionExtensionsTests { override fun route(req: ServerRequest) = route(req) { (GET("/foo/") or GET("/foos/")) { handle(req) } - (pathPrefix("/api") and accept(APPLICATION_JSON)).route { + "/api".route { POST("/foo/") { handleFromClass(req) } PUT("/foo/") { handleFromClass(req) } - DELETE("/foo/") { handleFromClass(req) } + "/foo/" { handleFromClass(req) } } accept(APPLICATION_ATOM_XML, ::handle) contentType(APPLICATION_OCTET_STREAM) { handle(req) } - method(HttpMethod.PATCH) { handle(req) } + method(PATCH) { handle(req) } headers({ it.accept().contains(APPLICATION_JSON) }).route { GET("/api/foo/", ::handle) } @@ -142,4 +138,3 @@ class RouterFunctionExtensionsTests { } fun handle(req: ServerRequest) = ok().build() -