|
|
|
@ -26,6 +26,7 @@ import org.springframework.web.testfixture.http.server.reactive.MockServerHttpRe |
|
|
|
import org.springframework.web.testfixture.server.MockServerWebExchange |
|
|
|
import org.springframework.web.testfixture.server.MockServerWebExchange |
|
|
|
import reactor.core.publisher.Mono |
|
|
|
import reactor.core.publisher.Mono |
|
|
|
import reactor.test.StepVerifier |
|
|
|
import reactor.test.StepVerifier |
|
|
|
|
|
|
|
import kotlin.coroutines.AbstractCoroutineContextElement |
|
|
|
import kotlin.coroutines.CoroutineContext |
|
|
|
import kotlin.coroutines.CoroutineContext |
|
|
|
|
|
|
|
|
|
|
|
/** |
|
|
|
/** |
|
|
|
@ -44,10 +45,24 @@ class CoWebFilterTests { |
|
|
|
val filter = MyCoWebFilter() |
|
|
|
val filter = MyCoWebFilter() |
|
|
|
val result = filter.filter(exchange, chain) |
|
|
|
val result = filter.filter(exchange, chain) |
|
|
|
|
|
|
|
|
|
|
|
StepVerifier.create(result) |
|
|
|
StepVerifier.create(result).verifyComplete() |
|
|
|
.verifyComplete() |
|
|
|
|
|
|
|
|
|
|
|
assertThat(exchange.attributes["foo"]).isEqualTo("bar") |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@Test |
|
|
|
|
|
|
|
fun multipleFilters() { |
|
|
|
|
|
|
|
val exchange = MockServerWebExchange.from(MockServerHttpRequest.get("https://example.com")) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
val chain = Mockito.mock(WebFilterChain::class.java) |
|
|
|
|
|
|
|
given(chain.filter(exchange)).willAnswer { MyOtherCoWebFilter().filter(exchange,chain) }.willReturn(Mono.empty()) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
val result = MyCoWebFilter().filter(exchange, chain) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
StepVerifier.create(result).verifyComplete() |
|
|
|
|
|
|
|
|
|
|
|
assertThat(exchange.attributes["foo"]).isEqualTo("bar") |
|
|
|
assertThat(exchange.attributes["foo"]).isEqualTo("bar") |
|
|
|
|
|
|
|
assertThat(exchange.attributes["foofoo"]).isEqualTo("barbar") |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
@Test |
|
|
|
@Test |
|
|
|
@ -69,6 +84,28 @@ class CoWebFilterTests { |
|
|
|
assertThat(coroutineName.name).isEqualTo("foo") |
|
|
|
assertThat(coroutineName.name).isEqualTo("foo") |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@Test |
|
|
|
|
|
|
|
fun multipleFiltersWithContext() { |
|
|
|
|
|
|
|
val exchange = MockServerWebExchange.from(MockServerHttpRequest.get("https://example.com")) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
val chain = Mockito.mock(WebFilterChain::class.java) |
|
|
|
|
|
|
|
given(chain.filter(exchange)).willAnswer { MyOtherCoWebFilterWithContext().filter(exchange,chain) }.willReturn(Mono.empty()) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
val filter = MyCoWebFilterWithContext() |
|
|
|
|
|
|
|
val result = filter.filter(exchange, chain) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
StepVerifier.create(result).verifyComplete() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
val context = exchange.attributes[CoWebFilter.COROUTINE_CONTEXT_ATTRIBUTE] as CoroutineContext |
|
|
|
|
|
|
|
assertThat(context).isNotNull() |
|
|
|
|
|
|
|
val coroutineName = context[CoroutineName.Key] as CoroutineName |
|
|
|
|
|
|
|
assertThat(coroutineName).isNotNull() |
|
|
|
|
|
|
|
assertThat(coroutineName.name).isEqualTo("foo") |
|
|
|
|
|
|
|
val coroutineDescription = context[CoroutineDescription.Key] as CoroutineDescription |
|
|
|
|
|
|
|
assertThat(coroutineDescription).isNotNull() |
|
|
|
|
|
|
|
assertThat(coroutineDescription.description).isEqualTo("foofoo") |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -79,6 +116,13 @@ private class MyCoWebFilter : CoWebFilter() { |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
private class MyOtherCoWebFilter : CoWebFilter() { |
|
|
|
|
|
|
|
override suspend fun filter(exchange: ServerWebExchange, chain: CoWebFilterChain) { |
|
|
|
|
|
|
|
exchange.attributes["foofoo"] = "barbar" |
|
|
|
|
|
|
|
chain.filter(exchange) |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
private class MyCoWebFilterWithContext : CoWebFilter() { |
|
|
|
private class MyCoWebFilterWithContext : CoWebFilter() { |
|
|
|
override suspend fun filter(exchange: ServerWebExchange, chain: CoWebFilterChain) { |
|
|
|
override suspend fun filter(exchange: ServerWebExchange, chain: CoWebFilterChain) { |
|
|
|
withContext(CoroutineName("foo")) { |
|
|
|
withContext(CoroutineName("foo")) { |
|
|
|
@ -86,3 +130,19 @@ private class MyCoWebFilterWithContext : CoWebFilter() { |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
private class MyOtherCoWebFilterWithContext : CoWebFilter() { |
|
|
|
|
|
|
|
override suspend fun filter(exchange: ServerWebExchange, chain: CoWebFilterChain) { |
|
|
|
|
|
|
|
withContext(CoroutineDescription("foofoo")) { |
|
|
|
|
|
|
|
chain.filter(exchange) |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
data class CoroutineDescription(val description: String) : AbstractCoroutineContextElement(CoroutineDescription) { |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
companion object Key : CoroutineContext.Key<CoroutineDescription> |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
override fun toString(): String = "CoroutineDescription($description)" |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|