diff --git a/spring-test/src/main/java/org/springframework/mock/web/MockHttpServletRequest.java b/spring-test/src/main/java/org/springframework/mock/web/MockHttpServletRequest.java index 4a6d50d7f84..cb03c3f179c 100644 --- a/spring-test/src/main/java/org/springframework/mock/web/MockHttpServletRequest.java +++ b/spring-test/src/main/java/org/springframework/mock/web/MockHttpServletRequest.java @@ -1199,13 +1199,18 @@ public class MockHttpServletRequest implements HttpServletRequest { @Override public StringBuffer getRequestURL() { - StringBuffer url = new StringBuffer(this.scheme).append("://").append(this.serverName); - if (this.serverPort > 0 && ((HTTP.equalsIgnoreCase(this.scheme) && this.serverPort != 80) || - (HTTPS.equalsIgnoreCase(this.scheme) && this.serverPort != 443))) { - url.append(':').append(this.serverPort); + String scheme = getScheme(); + String server = getServerName(); + int port = getServerPort(); + String uri = getRequestURI(); + + StringBuffer url = new StringBuffer(scheme).append("://").append(server); + if (port > 0 && ((HTTP.equalsIgnoreCase(scheme) && port != 80) || + (HTTPS.equalsIgnoreCase(scheme) && port != 443))) { + url.append(':').append(port); } - if (StringUtils.hasText(getRequestURI())) { - url.append(getRequestURI()); + if (StringUtils.hasText(uri)) { + url.append(uri); } return url; } @@ -1318,12 +1323,12 @@ public class MockHttpServletRequest implements HttpServletRequest { @Override @Nullable - public Part getPart(String name) throws IOException, IllegalStateException, ServletException { + public Part getPart(String name) throws IOException, ServletException { return this.parts.getFirst(name); } @Override - public Collection getParts() throws IOException, IllegalStateException, ServletException { + public Collection getParts() throws IOException, ServletException { List result = new LinkedList<>(); for (List list : this.parts.values()) { result.addAll(list); diff --git a/spring-test/src/test/java/org/springframework/mock/web/MockHttpServletRequestTests.java b/spring-test/src/test/java/org/springframework/mock/web/MockHttpServletRequestTests.java index 11173426f1e..64f7cab0c77 100644 --- a/spring-test/src/test/java/org/springframework/mock/web/MockHttpServletRequestTests.java +++ b/spring-test/src/test/java/org/springframework/mock/web/MockHttpServletRequestTests.java @@ -82,7 +82,7 @@ public class MockHttpServletRequestTests { } @Test - public void setContentAndGetContentAsByteArray() throws IOException { + public void setContentAndGetContentAsByteArray() { byte[] bytes = "request body".getBytes(); request.setContent(bytes); assertEquals(bytes.length, request.getContentLength()); @@ -152,9 +152,7 @@ public class MockHttpServletRequestTests { assertEquals("UTF-8", request.getCharacterEncoding()); } - // SPR-12677 - - @Test + @Test // SPR-12677 public void setContentTypeHeaderWithMoreComplexCharsetSyntax() { String contentType = "test/plain;charset=\"utf-8\";foo=\"charset=bar\";foocharset=bar;foo=bar"; request.addHeader("Content-Type", contentType); @@ -182,7 +180,7 @@ public class MockHttpServletRequestTests { } @Test - public void httpHeaderNameCasingIsPreserved() throws Exception { + public void httpHeaderNameCasingIsPreserved() { String headerName = "Header1"; request.addHeader(headerName, "value1"); Enumeration requestHeaders = request.getHeaderNames(); @@ -402,6 +400,22 @@ public class MockHttpServletRequestTests { assertEquals("http://localhost", requestURL.toString()); } + @Test // SPR-16138 + public void getRequestURLWithHostHeader() { + String testServer = "test.server"; + request.addHeader(HOST, testServer); + StringBuffer requestURL = request.getRequestURL(); + assertEquals("http://" + testServer, requestURL.toString()); + } + + @Test // SPR-16138 + public void getRequestURLWithHostHeaderAndPort() { + String testServer = "test.server:9999"; + request.addHeader(HOST, testServer); + StringBuffer requestURL = request.getRequestURL(); + assertEquals("http://" + testServer, requestURL.toString()); + } + @Test public void getRequestURLWithNullRequestUri() { request.setRequestURI(null); @@ -457,39 +471,39 @@ public class MockHttpServletRequestTests { } @Test - public void httpHeaderDate() throws Exception { + public void httpHeaderDate() { Date date = new Date(); request.addHeader(IF_MODIFIED_SINCE, date); assertEquals(date.getTime(), request.getDateHeader(IF_MODIFIED_SINCE)); } @Test - public void httpHeaderTimestamp() throws Exception { + public void httpHeaderTimestamp() { long timestamp = new Date().getTime(); request.addHeader(IF_MODIFIED_SINCE, timestamp); assertEquals(timestamp, request.getDateHeader(IF_MODIFIED_SINCE)); } @Test - public void httpHeaderRfcFormatedDate() throws Exception { + public void httpHeaderRfcFormatedDate() { request.addHeader(IF_MODIFIED_SINCE, "Tue, 21 Jul 2015 10:00:00 GMT"); assertEquals(1437472800000L, request.getDateHeader(IF_MODIFIED_SINCE)); } @Test - public void httpHeaderFirstVariantFormatedDate() throws Exception { + public void httpHeaderFirstVariantFormatedDate() { request.addHeader(IF_MODIFIED_SINCE, "Tue, 21-Jul-15 10:00:00 GMT"); assertEquals(1437472800000L, request.getDateHeader(IF_MODIFIED_SINCE)); } @Test - public void httpHeaderSecondVariantFormatedDate() throws Exception { + public void httpHeaderSecondVariantFormatedDate() { request.addHeader(IF_MODIFIED_SINCE, "Tue Jul 21 10:00:00 2015"); assertEquals(1437472800000L, request.getDateHeader(IF_MODIFIED_SINCE)); } @Test(expected = IllegalArgumentException.class) - public void httpHeaderFormatedDateError() throws Exception { + public void httpHeaderFormatedDateError() { request.addHeader(IF_MODIFIED_SINCE, "This is not a date"); request.getDateHeader(IF_MODIFIED_SINCE); } diff --git a/spring-web/src/test/java/org/springframework/mock/web/test/MockHttpServletRequest.java b/spring-web/src/test/java/org/springframework/mock/web/test/MockHttpServletRequest.java index 6ebc70fce6c..8e1eff4cd5b 100644 --- a/spring-web/src/test/java/org/springframework/mock/web/test/MockHttpServletRequest.java +++ b/spring-web/src/test/java/org/springframework/mock/web/test/MockHttpServletRequest.java @@ -62,6 +62,7 @@ import org.springframework.util.Assert; import org.springframework.util.LinkedCaseInsensitiveMap; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; +import org.springframework.util.ObjectUtils; import org.springframework.util.StreamUtils; import org.springframework.util.StringUtils; @@ -404,7 +405,7 @@ public class MockHttpServletRequest implements HttpServletRequest { /** * Get the content of the request body as a byte array. - * @return the content as a byte array, potentially {@code null} + * @return the content as a byte array (potentially {@code null}) * @since 5.0 * @see #setContent(byte[]) * @see #getContentAsString() @@ -588,7 +589,8 @@ public class MockHttpServletRequest implements HttpServletRequest { @Override public String getParameter(String name) { - String[] arr = (name != null ? this.parameters.get(name) : null); + Assert.notNull(name, "Parameter name must not be null"); + String[] arr = this.parameters.get(name); return (arr != null && arr.length > 0 ? arr[0] : null); } @@ -599,7 +601,8 @@ public class MockHttpServletRequest implements HttpServletRequest { @Override public String[] getParameterValues(String name) { - return (name != null ? this.parameters.get(name) : null); + Assert.notNull(name, "Parameter name must not be null"); + return this.parameters.get(name); } @Override @@ -927,10 +930,10 @@ public class MockHttpServletRequest implements HttpServletRequest { } public void setCookies(Cookie... cookies) { - this.cookies = cookies; + this.cookies = (ObjectUtils.isEmpty(cookies) ? null : cookies); this.headers.remove(HttpHeaders.COOKIE); - if (cookies != null) { - Arrays.stream(cookies) + if (this.cookies != null) { + Arrays.stream(this.cookies) .map(c -> c.getName() + '=' + (c.getValue() == null ? "" : c.getValue())) .forEach(value -> doAddHeaderValue(HttpHeaders.COOKIE, value, false)); } @@ -1164,13 +1167,18 @@ public class MockHttpServletRequest implements HttpServletRequest { @Override public StringBuffer getRequestURL() { - StringBuffer url = new StringBuffer(this.scheme).append("://").append(this.serverName); - if (this.serverPort > 0 && ((HTTP.equalsIgnoreCase(this.scheme) && this.serverPort != 80) || - (HTTPS.equalsIgnoreCase(this.scheme) && this.serverPort != 443))) { - url.append(':').append(this.serverPort); + String scheme = getScheme(); + String server = getServerName(); + int port = getServerPort(); + String uri = getRequestURI(); + + StringBuffer url = new StringBuffer(scheme).append("://").append(server); + if (port > 0 && ((HTTP.equalsIgnoreCase(scheme) && port != 80) || + (HTTPS.equalsIgnoreCase(scheme) && port != 443))) { + url.append(':').append(port); } - if (StringUtils.hasText(getRequestURI())) { - url.append(getRequestURI()); + if (StringUtils.hasText(uri)) { + url.append(uri); } return url; } @@ -1280,12 +1288,12 @@ public class MockHttpServletRequest implements HttpServletRequest { } @Override - public Part getPart(String name) throws IOException, IllegalStateException, ServletException { + public Part getPart(String name) throws IOException, ServletException { return this.parts.getFirst(name); } @Override - public Collection getParts() throws IOException, IllegalStateException, ServletException { + public Collection getParts() throws IOException, ServletException { List result = new LinkedList<>(); for (List list : this.parts.values()) { result.addAll(list);