diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/ServerRequest.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/ServerRequest.java index 7eae733b4d5..cef6d945750 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/ServerRequest.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/ServerRequest.java @@ -129,7 +129,16 @@ public interface ServerRequest { */ default Optional queryParam(String name) { List queryParams = this.queryParams(name); - return (!queryParams.isEmpty() ? Optional.of(queryParams.get(0)) : Optional.empty()); + if (queryParams.isEmpty()) { + return Optional.empty(); + } + else { + String value = queryParams.get(0); + if (value == null) { + value = ""; + } + return Optional.of(value); + } } /** diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/DefaultServerRequestTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/DefaultServerRequestTests.java index b6efb6a73b4..303cb6dc029 100644 --- a/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/DefaultServerRequestTests.java +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/DefaultServerRequestTests.java @@ -43,16 +43,11 @@ import org.springframework.http.HttpRange; import org.springframework.http.MediaType; import org.springframework.http.codec.DecoderHttpMessageReader; import org.springframework.http.codec.HttpMessageReader; -import org.springframework.http.server.reactive.ServerHttpRequest; -import org.springframework.http.server.reactive.ServerHttpResponse; -import org.springframework.util.LinkedMultiValueMap; -import org.springframework.util.MultiValueMap; -import org.springframework.web.server.ServerWebExchange; +import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest; +import org.springframework.mock.http.server.reactive.test.MockServerWebExchange; import org.springframework.web.server.UnsupportedMediaTypeStatusException; -import org.springframework.web.server.WebSession; import static org.junit.Assert.*; -import static org.mockito.Mockito.*; import static org.springframework.web.reactive.function.BodyExtractors.toMono; /** @@ -60,92 +55,96 @@ import static org.springframework.web.reactive.function.BodyExtractors.toMono; */ public class DefaultServerRequestTests { - private ServerHttpRequest mockRequest; - - private ServerWebExchange mockExchange; - List> messageReaders; - private DefaultServerRequest defaultRequest; - @Before public void createMocks() { - mockRequest = mock(ServerHttpRequest.class); - ServerHttpResponse mockResponse = mock(ServerHttpResponse.class); - - mockExchange = mock(ServerWebExchange.class); - when(mockExchange.getRequest()).thenReturn(mockRequest); - when(mockExchange.getResponse()).thenReturn(mockResponse); - this.messageReaders = Collections.>singletonList(new DecoderHttpMessageReader<>(StringDecoder.allMimeTypes(true))); - - defaultRequest = new DefaultServerRequest(mockExchange, messageReaders); } @Test public void method() throws Exception { HttpMethod method = HttpMethod.HEAD; - when(mockRequest.getMethod()).thenReturn(method); + MockServerHttpRequest mockRequest = MockServerHttpRequest.method(method, "http://example.com").build(); + DefaultServerRequest request = new DefaultServerRequest(mockRequest.toExchange(), messageReaders); - assertEquals(method, defaultRequest.method()); + assertEquals(method, request.method()); } @Test public void uri() throws Exception { URI uri = URI.create("https://example.com"); - when(mockRequest.getURI()).thenReturn(uri); - assertEquals(uri, defaultRequest.uri()); + MockServerHttpRequest mockRequest = MockServerHttpRequest.method(HttpMethod.GET, uri).build(); + DefaultServerRequest request = new DefaultServerRequest(mockRequest.toExchange(), messageReaders); + + assertEquals(uri, request.uri()); } @Test public void attribute() throws Exception { - when(mockExchange.getAttribute("foo")).thenReturn(Optional.of("bar")); + MockServerHttpRequest mockRequest = MockServerHttpRequest.method(HttpMethod.GET, "http://example.com").build(); + MockServerWebExchange exchange = new MockServerWebExchange(mockRequest); + exchange.getAttributes().put("foo", "bar"); + + DefaultServerRequest request = new DefaultServerRequest(exchange, messageReaders); - assertEquals(Optional.of("bar"), defaultRequest.attribute("foo")); + assertEquals(Optional.of("bar"), request.attribute("foo")); } @Test public void queryParams() throws Exception { - MultiValueMap queryParams = new LinkedMultiValueMap<>(); - queryParams.set("foo", "bar"); - when(mockRequest.getQueryParams()).thenReturn(queryParams); + MockServerHttpRequest mockRequest = MockServerHttpRequest.method(HttpMethod.GET, "http://example.com?foo=bar").build(); + DefaultServerRequest request = new DefaultServerRequest(mockRequest.toExchange(), messageReaders); - assertEquals(Optional.of("bar"), defaultRequest.queryParam("foo")); + assertEquals(Optional.of("bar"), request.queryParam("foo")); + } + + @Test + public void emptyQueryParam() throws Exception { + MockServerHttpRequest mockRequest = MockServerHttpRequest.method(HttpMethod.GET, "http://example.com?foo").build(); + DefaultServerRequest request = new DefaultServerRequest(mockRequest.toExchange(), messageReaders); + + assertEquals(Optional.of(""), request.queryParam("foo")); } @Test public void pathVariable() throws Exception { + MockServerHttpRequest mockRequest = MockServerHttpRequest.method(HttpMethod.GET, "http://example.com").build(); + MockServerWebExchange exchange = new MockServerWebExchange(mockRequest); Map pathVariables = Collections.singletonMap("foo", "bar"); - when(mockExchange.getAttribute(RouterFunctions.URI_TEMPLATE_VARIABLES_ATTRIBUTE)).thenReturn(Optional.of(pathVariables)); + exchange.getAttributes().put(RouterFunctions.URI_TEMPLATE_VARIABLES_ATTRIBUTE, pathVariables); + + DefaultServerRequest request = new DefaultServerRequest(exchange, messageReaders); - assertEquals("bar", defaultRequest.pathVariable("foo")); + assertEquals("bar", request.pathVariable("foo")); } + @Test(expected = IllegalArgumentException.class) public void pathVariableNotFound() throws Exception { + MockServerHttpRequest mockRequest = MockServerHttpRequest.method(HttpMethod.GET, "http://example.com").build(); + MockServerWebExchange exchange = new MockServerWebExchange(mockRequest); Map pathVariables = Collections.singletonMap("foo", "bar"); - when(mockExchange.getAttribute(RouterFunctions.URI_TEMPLATE_VARIABLES_ATTRIBUTE)).thenReturn(Optional.of(pathVariables)); + exchange.getAttributes().put(RouterFunctions.URI_TEMPLATE_VARIABLES_ATTRIBUTE, pathVariables); - assertEquals("bar", defaultRequest.pathVariable("baz")); + DefaultServerRequest request = new DefaultServerRequest(exchange, messageReaders); + + request.pathVariable("baz"); } @Test public void pathVariables() throws Exception { + MockServerHttpRequest mockRequest = MockServerHttpRequest.method(HttpMethod.GET, "http://example.com").build(); + MockServerWebExchange exchange = new MockServerWebExchange(mockRequest); Map pathVariables = Collections.singletonMap("foo", "bar"); - when(mockExchange.getAttribute(RouterFunctions.URI_TEMPLATE_VARIABLES_ATTRIBUTE)).thenReturn(Optional.of(pathVariables)); - - assertEquals(pathVariables, defaultRequest.pathVariables()); - } + exchange.getAttributes().put(RouterFunctions.URI_TEMPLATE_VARIABLES_ATTRIBUTE, pathVariables); - @Test - public void session() throws Exception { - WebSession session = mock(WebSession.class); - when(mockExchange.getSession()).thenReturn(Mono.just(session)); + DefaultServerRequest request = new DefaultServerRequest(exchange, messageReaders); - assertEquals(session, defaultRequest.session().block()); + assertEquals(pathVariables, request.pathVariables()); } @Test @@ -165,9 +164,11 @@ public class DefaultServerRequestTests { List range = Collections.singletonList(HttpRange.createByteRange(0, 42)); httpHeaders.setRange(range); - when(mockRequest.getHeaders()).thenReturn(httpHeaders); + MockServerHttpRequest mockRequest = MockServerHttpRequest.method(HttpMethod.GET, "http://example.com?foo=bar"). + headers(httpHeaders).build(); + DefaultServerRequest request = new DefaultServerRequest(mockRequest.toExchange(), messageReaders); - ServerRequest.Headers headers = defaultRequest.headers(); + ServerRequest.Headers headers = request.headers(); assertEquals(accept, headers.accept()); assertEquals(acceptCharset, headers.acceptCharset()); assertEquals(OptionalLong.of(contentLength), headers.contentLength()); @@ -184,10 +185,12 @@ public class DefaultServerRequestTests { HttpHeaders httpHeaders = new HttpHeaders(); httpHeaders.setContentType(MediaType.TEXT_PLAIN); - when(mockRequest.getHeaders()).thenReturn(httpHeaders); - when(mockRequest.getBody()).thenReturn(body); - Mono resultMono = defaultRequest.body(toMono(String.class)); + MockServerHttpRequest mockRequest = MockServerHttpRequest.method(HttpMethod.GET, "http://example.com?foo=bar"). + headers(httpHeaders).body(body); + DefaultServerRequest request = new DefaultServerRequest(mockRequest.toExchange(), messageReaders); + + Mono resultMono = request.body(toMono(String.class)); assertEquals("foo", resultMono.block()); } @@ -200,10 +203,11 @@ public class DefaultServerRequestTests { HttpHeaders httpHeaders = new HttpHeaders(); httpHeaders.setContentType(MediaType.TEXT_PLAIN); - when(mockRequest.getHeaders()).thenReturn(httpHeaders); - when(mockRequest.getBody()).thenReturn(body); + MockServerHttpRequest mockRequest = MockServerHttpRequest.method(HttpMethod.GET, "http://example.com?foo=bar"). + headers(httpHeaders).body(body); + DefaultServerRequest request = new DefaultServerRequest(mockRequest.toExchange(), messageReaders); - Mono resultMono = defaultRequest.bodyToMono(String.class); + Mono resultMono = request.bodyToMono(String.class); assertEquals("foo", resultMono.block()); } @@ -216,12 +220,12 @@ public class DefaultServerRequestTests { HttpHeaders httpHeaders = new HttpHeaders(); httpHeaders.setContentType(MediaType.TEXT_PLAIN); - when(mockRequest.getHeaders()).thenReturn(httpHeaders); - when(mockRequest.getBody()).thenReturn(body); + MockServerHttpRequest mockRequest = MockServerHttpRequest.method(HttpMethod.GET, "http://example.com?foo=bar"). + headers(httpHeaders).body(body); + DefaultServerRequest request = new DefaultServerRequest(mockRequest.toExchange(), messageReaders); - Flux resultFlux = defaultRequest.bodyToFlux(String.class); - Mono> result = resultFlux.collectList(); - assertEquals(Collections.singletonList("foo"), result.block()); + Flux resultFlux = request.bodyToFlux(String.class); + assertEquals(Collections.singletonList("foo"), resultFlux.collectList().block()); } @Test @@ -233,16 +237,14 @@ public class DefaultServerRequestTests { HttpHeaders httpHeaders = new HttpHeaders(); httpHeaders.setContentType(MediaType.TEXT_PLAIN); - when(mockRequest.getHeaders()).thenReturn(httpHeaders); - when(mockRequest.getBody()).thenReturn(body); - + MockServerHttpRequest mockRequest = MockServerHttpRequest.method(HttpMethod.GET, "http://example.com?foo=bar"). + headers(httpHeaders).body(body); this.messageReaders = Collections.emptyList(); - this.defaultRequest = new DefaultServerRequest(mockExchange, messageReaders); + DefaultServerRequest request = new DefaultServerRequest(mockRequest.toExchange(), messageReaders); - Flux resultFlux = defaultRequest.bodyToFlux(String.class); + Flux resultFlux = request.bodyToFlux(String.class); StepVerifier.create(resultFlux) .expectError(UnsupportedMediaTypeStatusException.class) .verify(); } - }