diff --git a/spring-webflux/src/main/kotlin/org/springframework/web/reactive/function/server/RouterFunctionExtensions.kt b/spring-webflux/src/main/kotlin/org/springframework/web/reactive/function/server/RouterFunctionExtensions.kt index 825c5256944..102b165119a 100644 --- a/spring-webflux/src/main/kotlin/org/springframework/web/reactive/function/server/RouterFunctionExtensions.kt +++ b/spring-webflux/src/main/kotlin/org/springframework/web/reactive/function/server/RouterFunctionExtensions.kt @@ -9,21 +9,33 @@ import reactor.core.publisher.Mono * Provide a routing DSL for [RouterFunctions] and [RouterFunction] in order to be able to * write idiomatic Kotlin code as below: * - * * ```kotlin - * fun route(request: ServerRequest) = route(request) { - * accept(TEXT_HTML).apply { - * (GET("/user/") or GET("/users/")) { findAllView() } - * GET("/user/{login}", ::findViewById) - * } - * accept(APPLICATION_JSON).apply { - * (GET("/api/user/") or GET("/api/users/")) { findAll() } - * POST("/api/user/", ::create) - * POST("/api/user/{login}", ::findOne) + * ```kotlin + * import org.springframework.web.reactive.function.server.RequestPredicates.* + * ... + * + * @Controller + * class FooController : RouterFunction { + * + * override fun route(req: ServerRequest) = route(req) { + * html().apply { + * (GET("/user/") or GET("/users/")) { findAllView() } + * GET("/user/{login}", this@FooController::findViewById) + * } + * json().apply { + * (GET("/api/user/") or GET("/api/users/")) { findAll() } + * POST("/api/user/", this@FooController::create) + * } * } + * + * fun findAllView() = ... + * fun findViewById(req: ServerRequest) = ... + * fun findAll() = ... + * fun create(req: ServerRequest) = * } * ``` * * @since 5.0 + * @see Kotlin issue about supporting ::foo for member functions * @author Sebastien Deleuze * @author Yevhenii Melnyk */ @@ -46,7 +58,7 @@ class RouterDsl { } fun GET(pattern: String, f: (ServerRequest) -> Mono) { - routes += RouterFunctions.route(RequestPredicates.GET(pattern), HandlerFunction { f(it) } ) + routes += RouterFunctions.route(RequestPredicates.GET(pattern), HandlerFunction { f(it) }) } fun HEAD(pattern: String, f: (ServerRequest) -> Mono) { @@ -81,7 +93,7 @@ class RouterDsl { routes += RouterFunctions.route(RequestPredicates.contentType(mediaType), HandlerFunction { f(it) }) } - fun headers(headerPredicate: (ServerRequest.Headers)->Boolean, f: (ServerRequest) -> Mono) { + fun headers(headerPredicate: (ServerRequest.Headers) -> Boolean, f: (ServerRequest) -> Mono) { routes += RouterFunctions.route(RequestPredicates.headers(headerPredicate), HandlerFunction { f(it) }) } @@ -93,15 +105,38 @@ class RouterDsl { routes += RouterFunctions.route(RequestPredicates.path(pattern), HandlerFunction { f(it) }) } + fun pathExtension(extension: String, f: (ServerRequest) -> Mono) { + routes += RouterFunctions.route(RequestPredicates.pathExtension(extension), HandlerFunction { f(it) }) + } + + fun pathExtension(predicate: (String) -> Boolean, f: (ServerRequest) -> Mono) { + routes += RouterFunctions.route(RequestPredicates.pathExtension(predicate), HandlerFunction { f(it) }) + } + + fun queryParam(name: String, predicate: (String) -> Boolean, f: (ServerRequest) -> Mono) { + routes += RouterFunctions.route(RequestPredicates.queryParam(name, predicate), HandlerFunction { f(it) }) + } + + fun json(f: (ServerRequest) -> Mono) { + routes += RouterFunctions.route(RequestPredicates.json(), HandlerFunction { f(it) }) + } + + fun html(f: (ServerRequest) -> Mono) { + routes += RouterFunctions.route(RequestPredicates.html(), HandlerFunction { f(it) }) + } + + fun xml(f: (ServerRequest) -> Mono) { + routes += RouterFunctions.route(RequestPredicates.xml(), HandlerFunction { f(it) }) + } + fun resources(path: String, location: Resource) { - routes += RouterFunctions.resources(path, location) + routes += RouterFunctions.resources(path, location) } fun resources(lookupFunction: (ServerRequest) -> Mono) { - routes += RouterFunctions.resources(lookupFunction) + routes += RouterFunctions.resources(lookupFunction) } - @Suppress("UNCHECKED_CAST") fun router(): RouterFunction { return routes().reduce(RouterFunction<*>::and) as RouterFunction diff --git a/spring-webflux/src/test/kotlin/org/springframework/web/reactive/function/server/RouterFunctionExtensionsTests.kt b/spring-webflux/src/test/kotlin/org/springframework/web/reactive/function/server/RouterFunctionExtensionsTests.kt index f8995fc5e2e..f03a842ada5 100644 --- a/spring-webflux/src/test/kotlin/org/springframework/web/reactive/function/server/RouterFunctionExtensionsTests.kt +++ b/spring-webflux/src/test/kotlin/org/springframework/web/reactive/function/server/RouterFunctionExtensionsTests.kt @@ -97,10 +97,13 @@ class RouterFunctionExtensionsTests { override fun route(req: ServerRequest) = route(req) { (GET("/foo/") or GET("/foos/")) { handle(req) } accept(APPLICATION_JSON).apply { - POST("/api/foo/", ::handle) - PUT("/api/foo/", ::handle) + POST("/api/foo/", this@FooController::handleFromClass) + PUT("/api/foo/") { handleFromClass(req) } DELETE("/api/foo/", ::handle) } + html().apply { + GET("/page", this@FooController::handleFromClass) + } accept(APPLICATION_ATOM_XML, ::handle) contentType(APPLICATION_OCTET_STREAM) { handle(req) } method(HttpMethod.PATCH) { handle(req) } @@ -120,8 +123,10 @@ class RouterFunctionExtensionsTests { } path("/baz") { handle(req) } } + + fun handleFromClass(req: ServerRequest) = ok().build() } } -private fun handle(req: ServerRequest) = ok().build() +fun handle(req: ServerRequest) = ok().build()