Browse Source

Provide orNull extensions for WebFlux ServerRequest

Closes gh-23761
pull/23992/head
Sébastien Deleuze 6 years ago
parent
commit
6fa9871a70
  1. 56
      spring-webflux/src/main/kotlin/org/springframework/web/reactive/function/server/ServerRequestExtensions.kt
  2. 90
      spring-webflux/src/test/kotlin/org/springframework/web/reactive/function/server/ServerRequestExtensionsTests.kt

56
spring-webflux/src/main/kotlin/org/springframework/web/reactive/function/server/ServerRequestExtensions.kt

@ -21,11 +21,14 @@ import kotlinx.coroutines.reactive.awaitFirstOrNull
import kotlinx.coroutines.reactive.awaitSingle import kotlinx.coroutines.reactive.awaitSingle
import kotlinx.coroutines.reactive.asFlow import kotlinx.coroutines.reactive.asFlow
import org.springframework.core.ParameterizedTypeReference import org.springframework.core.ParameterizedTypeReference
import org.springframework.http.MediaType
import org.springframework.http.codec.multipart.Part import org.springframework.http.codec.multipart.Part
import org.springframework.util.CollectionUtils
import org.springframework.util.MultiValueMap import org.springframework.util.MultiValueMap
import org.springframework.web.server.WebSession import org.springframework.web.server.WebSession
import reactor.core.publisher.Flux import reactor.core.publisher.Flux
import reactor.core.publisher.Mono import reactor.core.publisher.Mono
import java.net.InetSocketAddress
import java.security.Principal import java.security.Principal
/** /**
@ -112,3 +115,56 @@ suspend fun ServerRequest.awaitPrincipal(): Principal? =
*/ */
suspend fun ServerRequest.awaitSession(): WebSession = suspend fun ServerRequest.awaitSession(): WebSession =
session().awaitSingle() session().awaitSingle()
/**
* Nullable variant of [ServerRequest.remoteAddress]
*
* @author Sebastien Deleuze
* @since 5.2.2
*/
fun ServerRequest.remoteAddressOrNull(): InetSocketAddress? = remoteAddress().orElse(null)
/**
* Nullable variant of [ServerRequest.attribute]
*
* @author Sebastien Deleuze
* @since 5.2.2
*/
fun ServerRequest.attributeOrNull(name: String): Any? = attributes()[name]
/**
* Nullable variant of [ServerRequest.queryParam]
*
* @author Sebastien Deleuze
* @since 5.2.2
*/
fun ServerRequest.queryParamOrNull(name: String): String? {
val queryParamValues = queryParams()[name]
return if (CollectionUtils.isEmpty(queryParamValues)) {
null
} else {
var value: String? = queryParamValues!![0]
if (value == null) {
value = ""
}
value
}
}
/**
* Nullable variant of [ServerRequest.Headers.contentLength]
*
* @author Sebastien Deleuze
* @since 5.2.2
*/
fun ServerRequest.Headers.contentLengthOrNull(): Long? =
contentLength().run { if (isPresent) asLong else null }
/**
* Nullable variant of [ServerRequest.Headers.contentType]
*
* @author Sebastien Deleuze
* @since 5.2.2
*/
fun ServerRequest.Headers.contentTypeOrNull(): MediaType? =
contentType().orElse(null)

90
spring-webflux/src/test/kotlin/org/springframework/web/reactive/function/server/ServerRequestExtensionsTests.kt

@ -23,11 +23,15 @@ import kotlinx.coroutines.runBlocking
import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThat
import org.junit.jupiter.api.Test import org.junit.jupiter.api.Test
import org.springframework.core.ParameterizedTypeReference import org.springframework.core.ParameterizedTypeReference
import org.springframework.http.MediaType
import org.springframework.http.codec.multipart.Part import org.springframework.http.codec.multipart.Part
import org.springframework.util.CollectionUtils
import org.springframework.util.MultiValueMap import org.springframework.util.MultiValueMap
import org.springframework.web.server.WebSession import org.springframework.web.server.WebSession
import reactor.core.publisher.Mono import reactor.core.publisher.Mono
import java.net.InetSocketAddress
import java.security.Principal import java.security.Principal
import java.util.*
/** /**
* Mock object based tests for [ServerRequest] Kotlin extensions. * Mock object based tests for [ServerRequest] Kotlin extensions.
@ -38,6 +42,8 @@ class ServerRequestExtensionsTests {
val request = mockk<ServerRequest>(relaxed = true) val request = mockk<ServerRequest>(relaxed = true)
val headers = mockk<ServerRequest.Headers>(relaxed = true)
@Test @Test
fun `bodyToMono with reified type parameters`() { fun `bodyToMono with reified type parameters`() {
request.bodyToMono<List<Foo>>() request.bodyToMono<List<Foo>>()
@ -108,6 +114,90 @@ class ServerRequestExtensionsTests {
} }
} }
@Test
fun `remoteAddressOrNull with value`() {
val remoteAddress = InetSocketAddress(1234)
every { request.remoteAddress() } returns Optional.of(remoteAddress)
assertThat(remoteAddress).isEqualTo(request.remoteAddressOrNull())
verify { request.remoteAddress() }
}
@Test
fun `remoteAddressOrNull with null`() {
every { request.remoteAddress() } returns Optional.empty()
assertThat(request.remoteAddressOrNull()).isNull()
verify { request.remoteAddress() }
}
@Test
fun `attributeOrNull with value`() {
every { request.attributes() } returns mapOf("foo" to "bar")
assertThat(request.attributeOrNull("foo")).isEqualTo("bar")
verify { request.attributes() }
}
@Test
fun `attributeOrNull with null`() {
every { request.attributes() } returns mapOf("foo" to "bar")
assertThat(request.attributeOrNull("baz")).isNull()
verify { request.attributes() }
}
@Test
fun `queryParamOrNull with value`() {
every { request.queryParams() } returns CollectionUtils.toMultiValueMap(mapOf("foo" to listOf("bar")))
assertThat(request.queryParamOrNull("foo")).isEqualTo("bar")
verify { request.queryParams() }
}
@Test
fun `queryParamOrNull with values`() {
every { request.queryParams() } returns CollectionUtils.toMultiValueMap(mapOf("foo" to listOf("bar", "bar")))
assertThat(request.queryParamOrNull("foo")).isEqualTo("bar")
verify { request.queryParams() }
}
@Test
fun `queryParamOrNull with null value`() {
every { request.queryParams() } returns CollectionUtils.toMultiValueMap(mapOf("foo" to listOf(null)))
assertThat(request.queryParamOrNull("foo")).isEqualTo("")
verify { request.queryParams() }
}
@Test
fun `queryParamOrNull with null`() {
every { request.queryParams() } returns CollectionUtils.toMultiValueMap(mapOf("foo" to listOf("bar")))
assertThat(request.queryParamOrNull("baz")).isNull()
verify { request.queryParams() }
}
@Test
fun `contentLengthOrNull with value`() {
every { headers.contentLength() } returns OptionalLong.of(123)
assertThat(headers.contentLengthOrNull()).isEqualTo(123)
verify { headers.contentLength() }
}
@Test
fun `contentLengthOrNull with null`() {
every { headers.contentLength() } returns OptionalLong.empty()
assertThat(headers.contentLengthOrNull()).isNull()
verify { headers.contentLength() }
}
@Test
fun `contentTypeOrNull with value`() {
every { headers.contentType() } returns Optional.of(MediaType.APPLICATION_JSON)
assertThat(headers.contentTypeOrNull()).isEqualTo(MediaType.APPLICATION_JSON)
verify { headers.contentType() }
}
@Test
fun `contentTypeOrNull with null`() {
every { headers.contentType() } returns Optional.empty()
assertThat(headers.contentTypeOrNull()).isNull()
verify { headers.contentType() }
}
class Foo class Foo
} }

Loading…
Cancel
Save