diff --git a/spring-test/src/main/java/org/springframework/mock/web/MockHttpServletRequest.java b/spring-test/src/main/java/org/springframework/mock/web/MockHttpServletRequest.java index 29810935a6c..c68b5135887 100644 --- a/spring-test/src/main/java/org/springframework/mock/web/MockHttpServletRequest.java +++ b/spring-test/src/main/java/org/springframework/mock/web/MockHttpServletRequest.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -654,11 +654,14 @@ public class MockHttpServletRequest implements HttpServletRequest { @Override public String getServerName() { - String host = getHeader(HttpHeaders.HOST); + String rawHostHeader = getHeader(HttpHeaders.HOST); + String host = rawHostHeader; if (host != null) { host = host.trim(); if (host.startsWith("[")) { - host = host.substring(1, host.indexOf(']')); + int indexOfClosingBracket = host.indexOf(']'); + Assert.state(indexOfClosingBracket > -1, () -> "Invalid Host header: " + rawHostHeader); + host = host.substring(0, indexOfClosingBracket + 1); } else if (host.contains(":")) { host = host.substring(0, host.indexOf(':')); @@ -676,12 +679,15 @@ public class MockHttpServletRequest implements HttpServletRequest { @Override public int getServerPort() { - String host = getHeader(HttpHeaders.HOST); + String rawHostHeader = getHeader(HttpHeaders.HOST); + String host = rawHostHeader; if (host != null) { host = host.trim(); int idx; if (host.startsWith("[")) { - idx = host.indexOf(':', host.indexOf(']')); + int indexOfClosingBracket = host.indexOf(']'); + Assert.state(indexOfClosingBracket > -1, () -> "Invalid Host header: " + rawHostHeader); + idx = host.indexOf(':', indexOfClosingBracket); } else { idx = host.indexOf(':'); diff --git a/spring-test/src/test/java/org/springframework/mock/web/MockHttpServletRequestTests.java b/spring-test/src/test/java/org/springframework/mock/web/MockHttpServletRequestTests.java index 502a807592c..15b1c43e085 100644 --- a/spring-test/src/test/java/org/springframework/mock/web/MockHttpServletRequestTests.java +++ b/spring-test/src/test/java/org/springframework/mock/web/MockHttpServletRequestTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,6 +17,7 @@ package org.springframework.mock.web; import java.io.IOException; +import java.net.URL; import java.nio.charset.Charset; import java.util.ArrayList; import java.util.Arrays; @@ -37,6 +38,7 @@ import org.springframework.http.HttpHeaders; import org.springframework.util.FileCopyUtils; import org.springframework.util.StreamUtils; +import static org.hamcrest.CoreMatchers.startsWith; import static org.junit.Assert.*; /** @@ -347,16 +349,23 @@ public class MockHttpServletRequestTests { @Test public void getServerNameViaHostHeaderAsIpv6AddressWithoutPort() { - String ipv6Address = "[2001:db8:0:1]"; - request.addHeader(HOST, ipv6Address); - assertEquals("2001:db8:0:1", request.getServerName()); + String host = "[2001:db8:0:1]"; + request.addHeader(HOST, host); + assertEquals(host, request.getServerName()); } @Test public void getServerNameViaHostHeaderAsIpv6AddressWithPort() { - String ipv6Address = "[2001:db8:0:1]:8081"; - request.addHeader(HOST, ipv6Address); - assertEquals("2001:db8:0:1", request.getServerName()); + request.addHeader(HOST, "[2001:db8:0:1]:8081"); + assertEquals("[2001:db8:0:1]", request.getServerName()); + } + + @Test + public void getServerNameWithInvalidIpv6AddressViaHostHeader() { + request.addHeader(HOST, "[::ffff:abcd:abcd"); // missing closing bracket + exception.expect(IllegalStateException.class); + exception.expectMessage(startsWith("Invalid Host header: ")); + request.getServerName(); } @Test @@ -370,6 +379,22 @@ public class MockHttpServletRequestTests { assertEquals(8080, request.getServerPort()); } + @Test + public void getServerPortWithInvalidIpv6AddressViaHostHeader() { + request.addHeader(HOST, "[::ffff:abcd:abcd:8080"); // missing closing bracket + exception.expect(IllegalStateException.class); + exception.expectMessage(startsWith("Invalid Host header: ")); + request.getServerPort(); + } + + @Test + public void getServerPortWithIpv6AddressAndInvalidPortViaHostHeader() { + request.addHeader(HOST, "[::ffff:abcd:abcd]:bogus"); // "bogus" is not a port number + exception.expect(NumberFormatException.class); + exception.expectMessage("bogus"); + request.getServerPort(); + } + @Test public void getServerPortViaHostHeaderAsIpv6AddressWithoutPort() { String testServer = "[2001:db8:0:1]"; @@ -434,6 +459,43 @@ public class MockHttpServletRequestTests { assertEquals("http://" + testServer, requestURL.toString()); } + @Test + public void getRequestURLWithIpv6AddressViaServerNameWithoutPort() throws Exception { + request.setServerName("[::ffff:abcd:abcd]"); + URL url = new java.net.URL(request.getRequestURL().toString()); + assertEquals("http://[::ffff:abcd:abcd]", url.toString()); + } + + @Test + public void getRequestURLWithIpv6AddressViaServerNameWithPort() throws Exception { + request.setServerName("[::ffff:abcd:abcd]"); + request.setServerPort(9999); + URL url = new java.net.URL(request.getRequestURL().toString()); + assertEquals("http://[::ffff:abcd:abcd]:9999", url.toString()); + } + + @Test + public void getRequestURLWithInvalidIpv6AddressViaHostHeader() { + request.addHeader(HOST, "[::ffff:abcd:abcd"); // missing closing bracket + exception.expect(IllegalStateException.class); + exception.expectMessage(startsWith("Invalid Host header: ")); + request.getRequestURL(); + } + + @Test + public void getRequestURLWithIpv6AddressViaHostHeaderWithoutPort() throws Exception { + request.addHeader(HOST, "[::ffff:abcd:abcd]"); + URL url = new java.net.URL(request.getRequestURL().toString()); + assertEquals("http://[::ffff:abcd:abcd]", url.toString()); + } + + @Test + public void getRequestURLWithIpv6AddressViaHostHeaderWithPort() throws Exception { + request.addHeader(HOST, "[::ffff:abcd:abcd]:9999"); + URL url = new java.net.URL(request.getRequestURL().toString()); + assertEquals("http://[::ffff:abcd:abcd]:9999", url.toString()); + } + @Test public void getRequestURLWithNullRequestUri() { request.setRequestURI(null); diff --git a/spring-web/src/test/java/org/springframework/mock/web/test/MockHttpServletRequest.java b/spring-web/src/test/java/org/springframework/mock/web/test/MockHttpServletRequest.java index bdfb905190d..bfffe31ac4a 100644 --- a/spring-web/src/test/java/org/springframework/mock/web/test/MockHttpServletRequest.java +++ b/spring-web/src/test/java/org/springframework/mock/web/test/MockHttpServletRequest.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -654,11 +654,14 @@ public class MockHttpServletRequest implements HttpServletRequest { @Override public String getServerName() { - String host = getHeader(HttpHeaders.HOST); + String rawHostHeader = getHeader(HttpHeaders.HOST); + String host = rawHostHeader; if (host != null) { host = host.trim(); if (host.startsWith("[")) { - host = host.substring(1, host.indexOf(']')); + int indexOfClosingBracket = host.indexOf(']'); + Assert.state(indexOfClosingBracket > -1, () -> "Invalid Host header: " + rawHostHeader); + host = host.substring(0, indexOfClosingBracket + 1); } else if (host.contains(":")) { host = host.substring(0, host.indexOf(':')); @@ -676,12 +679,15 @@ public class MockHttpServletRequest implements HttpServletRequest { @Override public int getServerPort() { - String host = getHeader(HttpHeaders.HOST); + String rawHostHeader = getHeader(HttpHeaders.HOST); + String host = rawHostHeader; if (host != null) { host = host.trim(); int idx; if (host.startsWith("[")) { - idx = host.indexOf(':', host.indexOf(']')); + int indexOfClosingBracket = host.indexOf(']'); + Assert.state(indexOfClosingBracket > -1, () -> "Invalid Host header: " + rawHostHeader); + idx = host.indexOf(':', indexOfClosingBracket); } else { idx = host.indexOf(':');