diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/function/DefaultServerRequestBuilder.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/function/DefaultServerRequestBuilder.java index 34ef921c7e4..9598c018004 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/function/DefaultServerRequestBuilder.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/function/DefaultServerRequestBuilder.java @@ -24,6 +24,7 @@ import java.net.InetSocketAddress; import java.net.URI; import java.nio.charset.StandardCharsets; import java.security.Principal; +import java.util.ArrayList; import java.util.Collections; import java.util.LinkedHashMap; import java.util.List; @@ -47,6 +48,7 @@ import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; import org.springframework.http.converter.GenericHttpMessageConverter; import org.springframework.http.converter.HttpMessageConverter; +import org.springframework.lang.Nullable; import org.springframework.util.Assert; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; @@ -76,18 +78,25 @@ class DefaultServerRequestBuilder implements ServerRequest.Builder { private final Map attributes = new LinkedHashMap<>(); + private final MultiValueMap params = new LinkedMultiValueMap<>(); + + @Nullable + private InetSocketAddress remoteAddress; + private byte[] body = new byte[0]; public DefaultServerRequestBuilder(ServerRequest other) { Assert.notNull(other, "ServerRequest must not be null"); this.servletRequest = other.servletRequest(); - this.messageConverters = other.messageConverters(); + this.messageConverters = new ArrayList<>(other.messageConverters()); this.methodName = other.methodName(); this.uri = other.uri(); headers(headers -> headers.addAll(other.headers().asHttpHeaders())); cookies(cookies -> cookies.addAll(other.cookies())); attributes(attributes -> attributes.putAll(other.attributes())); + params(params -> params.addAll(other.params())); + this.remoteAddress = other.remoteAddress().orElse(null); } @Override @@ -156,10 +165,31 @@ class DefaultServerRequestBuilder implements ServerRequest.Builder { return this; } + @Override + public ServerRequest.Builder param(String name, String... values) { + for (String value : values) { + this.params.add(name, value); + } + return this; + } + + @Override + public ServerRequest.Builder params(Consumer> paramsConsumer) { + paramsConsumer.accept(this.params); + return this; + } + + @Override + public ServerRequest.Builder remoteAddress(InetSocketAddress remoteAddress) { + this.remoteAddress = remoteAddress; + return this; + } + + @Override public ServerRequest build() { - return new BuiltServerRequest(this.servletRequest, this.methodName, this.uri, - this.headers, this.cookies, this.attributes, this.body, this.messageConverters); + return new BuiltServerRequest(this.servletRequest, this.methodName, this.uri, this.headers, this.cookies, + this.attributes, this.params, this.remoteAddress, this.body, this.messageConverters); } @@ -181,17 +211,24 @@ class DefaultServerRequestBuilder implements ServerRequest.Builder { private final List> messageConverters; + private final MultiValueMap params; + + @Nullable + private final InetSocketAddress remoteAddress; + public BuiltServerRequest(HttpServletRequest servletRequest, String methodName, URI uri, HttpHeaders headers, MultiValueMap cookies, - Map attributes, byte[] body, - List> messageConverters) { + Map attributes, MultiValueMap params, + @Nullable InetSocketAddress remoteAddress, byte[] body, List> messageConverters) { this.servletRequest = servletRequest; this.methodName = methodName; this.uri = uri; - this.headers = headers; - this.cookies = cookies; - this.attributes = attributes; + this.headers = new HttpHeaders(headers); + this.cookies = new LinkedMultiValueMap<>(cookies); + this.attributes = new LinkedHashMap<>(attributes); + this.params = new LinkedMultiValueMap<>(params); + this.remoteAddress = remoteAddress; this.body = body; this.messageConverters = messageConverters; } @@ -231,7 +268,7 @@ class DefaultServerRequestBuilder implements ServerRequest.Builder { @Override public Optional remoteAddress() { - return Optional.empty(); + return Optional.ofNullable(this.remoteAddress); } @Override @@ -280,7 +317,7 @@ class DefaultServerRequestBuilder implements ServerRequest.Builder { @Override public MultiValueMap params() { - return new LinkedMultiValueMap<>(); + return this.params; } @Override diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/function/ServerRequest.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/function/ServerRequest.java index 00beac0420c..6bbb3a43391 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/function/ServerRequest.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/function/ServerRequest.java @@ -553,6 +553,32 @@ public interface ServerRequest { */ Builder attributes(Consumer> attributesConsumer); + /** + * Add a parameter with the given name and value. + * @param name the parameter name + * @param values the parameter value(s) + * @return this builder + */ + Builder param(String name, String... values); + + /** + * Manipulate this request's parameters with the given consumer. + *

The map provided to the consumer is "live", so that the consumer can be used to + * {@linkplain MultiValueMap#set(Object, Object) overwrite} existing cookies, + * {@linkplain MultiValueMap#remove(Object) remove} cookies, or use any of the other + * {@link MultiValueMap} methods. + * @param paramsConsumer a function that consumes the parameters map + * @return this builder + */ + Builder params(Consumer> paramsConsumer); + + /** + * Set the remote address of the request. + * @param remoteAddress the remote address + * @return this builder + */ + Builder remoteAddress(InetSocketAddress remoteAddress); + /** * Build the request. * @return the built request diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/function/DefaultServerRequestBuilderTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/function/DefaultServerRequestBuilderTests.java index cc20f5c98ff..254f84c67ab 100644 --- a/spring-webmvc/src/test/java/org/springframework/web/servlet/function/DefaultServerRequestBuilderTests.java +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/function/DefaultServerRequestBuilderTests.java @@ -17,6 +17,7 @@ package org.springframework.web.servlet.function; import java.io.IOException; +import java.net.InetSocketAddress; import java.util.Collections; import java.util.List; @@ -29,6 +30,7 @@ import org.springframework.http.HttpMethod; import org.springframework.http.converter.HttpMessageConverter; import org.springframework.http.converter.StringHttpMessageConverter; import org.springframework.web.servlet.handler.PathPatternsTestUtils; +import org.springframework.web.testfixture.servlet.MockCookie; import org.springframework.web.testfixture.servlet.MockHttpServletRequest; import static org.assertj.core.api.Assertions.assertThat; @@ -46,30 +48,45 @@ class DefaultServerRequestBuilderTests { void from() throws ServletException, IOException { MockHttpServletRequest request = PathPatternsTestUtils.initRequest("POST", "https://example.com", true); request.addHeader("foo", "bar"); + request.setCookies(new MockCookie("foo", "bar")); + request.setAttribute("foo", "bar"); + request.addParameter("foo", "bar"); + request.setRemoteHost("127.0.0.1"); + request.setRemotePort(80); ServerRequest other = ServerRequest.create(request, messageConverters); ServerRequest result = ServerRequest.from(other) .method(HttpMethod.HEAD) - .header("foo", "bar") - .headers(httpHeaders -> httpHeaders.set("baz", "qux")) - .cookie("foo", "bar") - .cookies(cookies -> cookies.set("baz", new Cookie("baz", "qux"))) - .attribute("foo", "bar") - .attributes(attributes -> attributes.put("baz", "qux")) + .header("baz", "qux") + .headers(httpHeaders -> httpHeaders.set("quux", "quuz")) + .cookie("baz", "qux") + .cookies(cookies -> cookies.set("quux", new Cookie("quux", "quuz"))) + .attribute("baz", "qux") + .attributes(attributes -> attributes.put("quux", "quuz")) + .param("baz", "qux") + .params(params -> params.set("quux", "quuz")) .body("baz") .build(); assertThat(result.method()).isEqualTo(HttpMethod.HEAD); - assertThat(result.headers().asHttpHeaders().size()).isEqualTo(2); assertThat(result.headers().asHttpHeaders().getFirst("foo")).isEqualTo("bar"); assertThat(result.headers().asHttpHeaders().getFirst("baz")).isEqualTo("qux"); - assertThat(result.cookies().size()).isEqualTo(2); + assertThat(result.headers().asHttpHeaders().getFirst("quux")).isEqualTo("quuz"); + assertThat(result.cookies().getFirst("foo").getValue()).isEqualTo("bar"); assertThat(result.cookies().getFirst("baz").getValue()).isEqualTo("qux"); - assertThat(result.attributes().size()).isEqualTo(other.attributes().size() + 2); + assertThat(result.cookies().getFirst("quux").getValue()).isEqualTo("quuz"); + assertThat(result.attributes().get("foo")).isEqualTo("bar"); assertThat(result.attributes().get("baz")).isEqualTo("qux"); + assertThat(result.attributes().get("quux")).isEqualTo("quuz"); + + assertThat(result.params().getFirst("foo")).isEqualTo("bar"); + assertThat(result.params().getFirst("baz")).isEqualTo("qux"); + assertThat(result.params().getFirst("quux")).isEqualTo("quuz"); + + assertThat(result.remoteAddress()).contains(new InetSocketAddress("127.0.0.1", 80)); String body = result.body(String.class); assertThat(body).isEqualTo("baz");