From a01c6d57bb23bf885bed07060c98ffb8828521ff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Deleuze?= Date: Wed, 13 Dec 2023 15:54:23 +0100 Subject: [PATCH] Inherit parent context in coRouter DSL This commit also allows context override, as it is useful for the nested router use case. Closes gh-31831 --- .../function/server/CoRouterFunctionDsl.kt | 5 +- .../server/CoRouterFunctionDslTests.kt | 69 +++++++++++++++++++ 2 files changed, 70 insertions(+), 4 deletions(-) diff --git a/spring-webflux/src/main/kotlin/org/springframework/web/reactive/function/server/CoRouterFunctionDsl.kt b/spring-webflux/src/main/kotlin/org/springframework/web/reactive/function/server/CoRouterFunctionDsl.kt index eb63a9f9339..62febc28f2d 100644 --- a/spring-webflux/src/main/kotlin/org/springframework/web/reactive/function/server/CoRouterFunctionDsl.kt +++ b/spring-webflux/src/main/kotlin/org/springframework/web/reactive/function/server/CoRouterFunctionDsl.kt @@ -144,7 +144,7 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct * @see RouterFunctions.nest */ fun RequestPredicate.nest(r: (CoRouterFunctionDsl.() -> Unit)) { - builder.add(nest(this, CoRouterFunctionDsl(r).build())) + builder.add(nest(this, CoRouterFunctionDsl(r).also { it.contextProvider = contextProvider }.build())) } @@ -628,9 +628,6 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct * @since 6.1 */ fun context(provider: suspend (ServerRequest) -> CoroutineContext) { - if (this.contextProvider != null) { - throw IllegalStateException("The Coroutine context provider should not be defined more than once") - } this.contextProvider = provider } diff --git a/spring-webflux/src/test/kotlin/org/springframework/web/reactive/function/server/CoRouterFunctionDslTests.kt b/spring-webflux/src/test/kotlin/org/springframework/web/reactive/function/server/CoRouterFunctionDslTests.kt index c2e48f4883a..e0a75566b72 100644 --- a/spring-webflux/src/test/kotlin/org/springframework/web/reactive/function/server/CoRouterFunctionDslTests.kt +++ b/spring-webflux/src/test/kotlin/org/springframework/web/reactive/function/server/CoRouterFunctionDslTests.kt @@ -193,6 +193,45 @@ class CoRouterFunctionDslTests { .verifyComplete() } + @Test + fun nestedContextProvider() { + val mockRequest = get("https://example.com/nested/") + .header("Custom-Header", "foo") + .build() + val request = DefaultServerRequest(MockServerWebExchange.from(mockRequest), emptyList()) + StepVerifier.create(nestedRouterWithContextProvider.route(request).flatMap { it.handle(request) }) + .expectNextMatches { response -> + response.headers().getFirst("context")!!.contains("foo") + } + .verifyComplete() + } + + @Test + fun nestedContextProviderWithOverride() { + val mockRequest = get("https://example.com/nested/") + .header("Custom-Header", "foo") + .build() + val request = DefaultServerRequest(MockServerWebExchange.from(mockRequest), emptyList()) + StepVerifier.create(nestedRouterWithContextProviderOverride.route(request).flatMap { it.handle(request) }) + .expectNextMatches { response -> + response.headers().getFirst("context")!!.contains("foo") + } + .verifyComplete() + } + + @Test + fun doubleNestedContextProvider() { + val mockRequest = get("https://example.com/nested/nested/") + .header("Custom-Header", "foo") + .build() + val request = DefaultServerRequest(MockServerWebExchange.from(mockRequest), emptyList()) + StepVerifier.create(nestedRouterWithContextProvider.route(request).flatMap { it.handle(request) }) + .expectNextMatches { response -> + response.headers().getFirst("context")!!.contains("foo") + } + .verifyComplete() + } + @Test fun contextProviderAndFilter() { val mockRequest = get("https://example.com/") @@ -323,6 +362,36 @@ class CoRouterFunctionDslTests { } } + private val nestedRouterWithContextProvider = coRouter { + context { + CoroutineName(it.headers().firstHeader("Custom-Header")!!) + } + "/nested".nest { + GET("/") { + ok().header("context", currentCoroutineContext().toString()).buildAndAwait() + } + "/nested".nest { + GET("/") { + ok().header("context", currentCoroutineContext().toString()).buildAndAwait() + } + } + } + } + + private val nestedRouterWithContextProviderOverride = coRouter { + context { + CoroutineName("parent-context") + } + "/nested".nest { + context { + CoroutineName(it.headers().firstHeader("Custom-Header")!!) + } + GET("/") { + ok().header("context", currentCoroutineContext().toString()).buildAndAwait() + } + } + } + private val routerWithoutContext = coRouter { GET("/") { ok().header("context", currentCoroutineContext().toString()).buildAndAwait()