Browse Source

Propagate CoroutineContext in CoWebFilter

This provides an elegant and dynamic way to customize the
CoroutineContext in WebFlux with the annotation programming
model.

Closes gh-27522
pull/31186/head
Sébastien Deleuze 2 years ago
parent
commit
b0aa004d9d
  1. 10
      spring-web/src/main/kotlin/org/springframework/web/server/CoWebFilter.kt
  2. 32
      spring-web/src/test/kotlin/org/springframework/web/server/CoWebFilterTests.kt
  3. 26
      spring-webflux/src/main/java/org/springframework/web/reactive/result/method/InvocableHandlerMethod.java

10
spring-web/src/main/kotlin/org/springframework/web/server/CoWebFilter.kt

@ -17,6 +17,8 @@
package org.springframework.web.server package org.springframework.web.server
import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.Job
import kotlinx.coroutines.currentCoroutineContext
import kotlinx.coroutines.reactor.awaitSingleOrNull import kotlinx.coroutines.reactor.awaitSingleOrNull
import kotlinx.coroutines.reactor.mono import kotlinx.coroutines.reactor.mono
import reactor.core.publisher.Mono import reactor.core.publisher.Mono
@ -26,6 +28,7 @@ import reactor.core.publisher.Mono
* using coroutines. * using coroutines.
* *
* @author Arjen Poutsma * @author Arjen Poutsma
* @author Sebastien Deleuze
* @since 6.0.5 * @since 6.0.5
*/ */
abstract class CoWebFilter : WebFilter { abstract class CoWebFilter : WebFilter {
@ -34,6 +37,7 @@ abstract class CoWebFilter : WebFilter {
return mono(Dispatchers.Unconfined) { return mono(Dispatchers.Unconfined) {
filter(exchange, object : CoWebFilterChain { filter(exchange, object : CoWebFilterChain {
override suspend fun filter(exchange: ServerWebExchange) { override suspend fun filter(exchange: ServerWebExchange) {
exchange.attributes[COROUTINE_CONTEXT_ATTRIBUTE] = currentCoroutineContext().minusKey(Job.Key)
chain.filter(exchange).awaitSingleOrNull() chain.filter(exchange).awaitSingleOrNull()
} }
})}.then() })}.then()
@ -47,6 +51,12 @@ abstract class CoWebFilter : WebFilter {
*/ */
protected abstract suspend fun filter(exchange: ServerWebExchange, chain: CoWebFilterChain) protected abstract suspend fun filter(exchange: ServerWebExchange, chain: CoWebFilterChain)
companion object {
@JvmField
val COROUTINE_CONTEXT_ATTRIBUTE = CoWebFilter::class.java.getName() + ".context"
}
} }
/** /**

32
spring-web/src/test/kotlin/org/springframework/web/server/CoWebFilterTests.kt

@ -16,6 +16,8 @@
package org.springframework.web.server package org.springframework.web.server
import kotlinx.coroutines.CoroutineName
import kotlinx.coroutines.withContext
import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThat
import org.junit.jupiter.api.Test import org.junit.jupiter.api.Test
import org.mockito.BDDMockito.given import org.mockito.BDDMockito.given
@ -24,9 +26,11 @@ import org.springframework.web.testfixture.http.server.reactive.MockServerHttpRe
import org.springframework.web.testfixture.server.MockServerWebExchange import org.springframework.web.testfixture.server.MockServerWebExchange
import reactor.core.publisher.Mono import reactor.core.publisher.Mono
import reactor.test.StepVerifier import reactor.test.StepVerifier
import kotlin.coroutines.CoroutineContext
/** /**
* @author Arjen Poutsma * @author Arjen Poutsma
* @author Sebastien Deleuze
*/ */
class CoWebFilterTests { class CoWebFilterTests {
@ -45,6 +49,26 @@ class CoWebFilterTests {
assertThat(exchange.attributes["foo"]).isEqualTo("bar") assertThat(exchange.attributes["foo"]).isEqualTo("bar")
} }
@Test
fun filterWithContext() {
val exchange = MockServerWebExchange.from(MockServerHttpRequest.get("https://example.com"))
val chain = Mockito.mock(WebFilterChain::class.java)
given(chain.filter(exchange)).willReturn(Mono.empty())
val filter = MyCoWebFilterWithContext()
val result = filter.filter(exchange, chain)
StepVerifier.create(result).verifyComplete()
val context = exchange.attributes[CoWebFilter.COROUTINE_CONTEXT_ATTRIBUTE] as CoroutineContext
assertThat(context).isNotNull()
val coroutineName = context[CoroutineName.Key] as CoroutineName
assertThat(coroutineName).isNotNull()
assertThat(coroutineName.name).isEqualTo("foo")
}
} }
@ -54,3 +78,11 @@ private class MyCoWebFilter : CoWebFilter() {
chain.filter(exchange) chain.filter(exchange)
} }
} }
private class MyCoWebFilterWithContext : CoWebFilter() {
override suspend fun filter(exchange: ServerWebExchange, chain: CoWebFilterChain) {
withContext(CoroutineName("foo")) {
chain.filter(exchange)
}
}
}

26
spring-webflux/src/main/java/org/springframework/web/reactive/result/method/InvocableHandlerMethod.java

@ -26,6 +26,7 @@ import java.util.Map;
import java.util.Objects; import java.util.Objects;
import java.util.stream.Stream; import java.util.stream.Stream;
import kotlin.coroutines.CoroutineContext;
import kotlin.reflect.KFunction; import kotlin.reflect.KFunction;
import kotlin.reflect.KParameter; import kotlin.reflect.KParameter;
import kotlin.reflect.jvm.KCallablesJvm; import kotlin.reflect.jvm.KCallablesJvm;
@ -48,6 +49,7 @@ import org.springframework.validation.method.MethodValidator;
import org.springframework.web.method.HandlerMethod; import org.springframework.web.method.HandlerMethod;
import org.springframework.web.reactive.BindingContext; import org.springframework.web.reactive.BindingContext;
import org.springframework.web.reactive.HandlerResult; import org.springframework.web.reactive.HandlerResult;
import org.springframework.web.server.CoWebFilter;
import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.ServerWebExchange;
/** /**
@ -152,7 +154,7 @@ public class InvocableHandlerMethod extends HandlerMethod {
* @param providedArgs optional list of argument values to match by type * @param providedArgs optional list of argument values to match by type
* @return a Mono with a {@link HandlerResult} * @return a Mono with a {@link HandlerResult}
*/ */
@SuppressWarnings({"KotlinInternalInJava", "unchecked"}) @SuppressWarnings("unchecked")
public Mono<HandlerResult> invoke( public Mono<HandlerResult> invoke(
ServerWebExchange exchange, BindingContext bindingContext, Object... providedArgs) { ServerWebExchange exchange, BindingContext bindingContext, Object... providedArgs) {
@ -167,12 +169,7 @@ public class InvocableHandlerMethod extends HandlerMethod {
boolean isSuspendingFunction = KotlinDetector.isSuspendingFunction(method); boolean isSuspendingFunction = KotlinDetector.isSuspendingFunction(method);
try { try {
if (KotlinDetector.isKotlinReflectPresent() && KotlinDetector.isKotlinType(method.getDeclaringClass())) { if (KotlinDetector.isKotlinReflectPresent() && KotlinDetector.isKotlinType(method.getDeclaringClass())) {
if (isSuspendingFunction) { value = KotlinDelegate.invokeFunction(method, getBean(), args, isSuspendingFunction, exchange);
value = CoroutinesUtils.invokeSuspendingFunction(method, getBean(), args);
}
else {
value = KotlinDelegate.invokeFunction(method, getBean(), args);
}
} }
else { else {
value = method.invoke(getBean(), args); value = method.invoke(getBean(), args);
@ -297,7 +294,19 @@ public class InvocableHandlerMethod extends HandlerMethod {
@Nullable @Nullable
@SuppressWarnings("deprecation") @SuppressWarnings("deprecation")
public static Object invokeFunction(Method method, Object target, Object[] args) { public static Object invokeFunction(Method method, Object target, Object[] args, boolean isSuspendingFunction,
ServerWebExchange exchange) {
if (isSuspendingFunction) {
Object coroutineContext = exchange.getAttribute(CoWebFilter.COROUTINE_CONTEXT_ATTRIBUTE);
if (coroutineContext == null) {
return CoroutinesUtils.invokeSuspendingFunction(method, target, args);
}
else {
return CoroutinesUtils.invokeSuspendingFunction((CoroutineContext) coroutineContext, method, target, args);
}
}
else {
KFunction<?> function = Objects.requireNonNull(ReflectJvmMapping.getKotlinFunction(method)); KFunction<?> function = Objects.requireNonNull(ReflectJvmMapping.getKotlinFunction(method));
if (method.isAccessible() && !KCallablesJvm.isAccessible(function)) { if (method.isAccessible() && !KCallablesJvm.isAccessible(function)) {
KCallablesJvm.setAccessible(function, true); KCallablesJvm.setAccessible(function, true);
@ -318,5 +327,6 @@ public class InvocableHandlerMethod extends HandlerMethod {
return function.callBy(argMap); return function.callBy(argMap);
} }
} }
}
} }

Loading…
Cancel
Save