diff --git a/spring-web/src/main/java/org/springframework/web/util/UriComponentsBuilder.java b/spring-web/src/main/java/org/springframework/web/util/UriComponentsBuilder.java index 7a0fbf45435..b7d376f3d39 100644 --- a/spring-web/src/main/java/org/springframework/web/util/UriComponentsBuilder.java +++ b/spring-web/src/main/java/org/springframework/web/util/UriComponentsBuilder.java @@ -52,6 +52,7 @@ import org.springframework.web.util.HierarchicalUriComponents.PathComponent; * @author Phillip Webb * @author Oliver Gierke * @author Brian Clozel + * @author Sebastien Deleuze * @since 3.1 * @see #newInstance() * @see #fromPath(String) @@ -675,16 +676,23 @@ public class UriComponentsBuilder implements Cloneable { String forwardedHeader = headers.getFirst("Forwarded"); if (StringUtils.hasText(forwardedHeader)) { String forwardedToUse = StringUtils.tokenizeToStringArray(forwardedHeader, ",")[0]; - Matcher matcher = FORWARDED_HOST_PATTERN.matcher(forwardedToUse); + Matcher matcher = FORWARDED_PROTO_PATTERN.matcher(forwardedToUse); if (matcher.find()) { - adaptForwardedHost(matcher.group(1).trim()); + scheme(matcher.group(1).trim()); + port(null); } - matcher = FORWARDED_PROTO_PATTERN.matcher(forwardedToUse); + matcher = FORWARDED_HOST_PATTERN.matcher(forwardedToUse); if (matcher.find()) { - scheme(matcher.group(1).trim()); + adaptForwardedHost(matcher.group(1).trim()); } } else { + String protocolHeader = headers.getFirst("X-Forwarded-Proto"); + if (StringUtils.hasText(protocolHeader)) { + scheme(StringUtils.tokenizeToStringArray(protocolHeader, ",")[0]); + port(null); + } + String hostHeader = headers.getFirst("X-Forwarded-Host"); if (StringUtils.hasText(hostHeader)) { adaptForwardedHost(StringUtils.tokenizeToStringArray(hostHeader, ",")[0]); @@ -694,16 +702,11 @@ public class UriComponentsBuilder implements Cloneable { if (StringUtils.hasText(portHeader)) { port(Integer.parseInt(StringUtils.tokenizeToStringArray(portHeader, ",")[0])); } - - String protocolHeader = headers.getFirst("X-Forwarded-Proto"); - if (StringUtils.hasText(protocolHeader)) { - scheme(StringUtils.tokenizeToStringArray(protocolHeader, ",")[0]); - } } - if ((this.scheme.equals("http") && "80".equals(this.port)) || - (this.scheme.equals("https") && "443".equals(this.port))) { - this.port = null; + if ((this.scheme != null) && ((this.scheme.equals("http") && "80".equals(this.port)) || + (this.scheme.equals("https") && "443".equals(this.port)))) { + port(null); } return this; diff --git a/spring-web/src/main/java/org/springframework/web/util/WebUtils.java b/spring-web/src/main/java/org/springframework/web/util/WebUtils.java index f7d0b295d18..62275e94dbf 100644 --- a/spring-web/src/main/java/org/springframework/web/util/WebUtils.java +++ b/spring-web/src/main/java/org/springframework/web/util/WebUtils.java @@ -812,7 +812,8 @@ public abstract class WebUtils { /** * Check if the request is a same-origin one, based on {@code Origin}, {@code Host}, - * {@code Forwarded} and {@code X-Forwarded-Host} headers. + * {@code Forwarded}, {@code X-Forwarded-Proto}, {@code X-Forwarded-Host} and + * @code X-Forwarded-Port} headers. * @return {@code true} if the request is a same-origin one, {@code false} in case * of cross-origin request * @since 4.2 diff --git a/spring-web/src/test/java/org/springframework/web/util/UriComponentsBuilderTests.java b/spring-web/src/test/java/org/springframework/web/util/UriComponentsBuilderTests.java index 946a755ba81..0ad771c4396 100644 --- a/spring-web/src/test/java/org/springframework/web/util/UriComponentsBuilderTests.java +++ b/spring-web/src/test/java/org/springframework/web/util/UriComponentsBuilderTests.java @@ -401,6 +401,22 @@ public class UriComponentsBuilderTests { assertEquals(-1, result.getPort()); } + @Test // SPR-16262 + public void fromHttpRequestWithForwardedProtoWithDefaultPort() { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setScheme("http"); + request.setServerName("example.org"); + request.setServerPort(10080); + request.addHeader("X-Forwarded-Proto", "https"); + + HttpRequest httpRequest = new ServletServerHttpRequest(request); + UriComponents result = UriComponentsBuilder.fromHttpRequest(httpRequest).build(); + + assertEquals("https", result.getScheme()); + assertEquals("example.org", result.getHost()); + assertEquals(-1, result.getPort()); + } + @Test public void fromHttpRequestWithForwardedHostWithForwardedScheme() { @@ -837,4 +853,23 @@ public class UriComponentsBuilderTests { assertEquals(-1, result.getPort()); assertEquals("https://84.198.58.199/rest/mobile/users/1", result.toUriString()); } + + @Test // SPR-16262 + public void fromHttpRequestForwardedHeaderWithProtoAndServerPort() throws Exception { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.addHeader("Forwarded", "proto=https"); + request.setScheme("http"); + request.setServerPort(8080); + request.setServerName("example.com"); + request.setRequestURI("/rest/mobile/users/1"); + + HttpRequest httpRequest = new ServletServerHttpRequest(request); + UriComponents result = UriComponentsBuilder.fromHttpRequest(httpRequest).build(); + + assertEquals("https", result.getScheme()); + assertEquals("example.com", result.getHost()); + assertEquals("/rest/mobile/users/1", result.getPath()); + assertEquals(-1, result.getPort()); + assertEquals("https://example.com/rest/mobile/users/1", result.toUriString()); + } } diff --git a/spring-web/src/test/java/org/springframework/web/util/WebUtilsTests.java b/spring-web/src/test/java/org/springframework/web/util/WebUtilsTests.java index 1ad98fd255b..7a9118e5dac 100644 --- a/spring-web/src/test/java/org/springframework/web/util/WebUtilsTests.java +++ b/spring-web/src/test/java/org/springframework/web/util/WebUtilsTests.java @@ -160,6 +160,26 @@ public class WebUtilsTests { assertFalse(checkSameOrigin("[::1]", 8080, "http://[2001:0db8:0000:85a3:0000:0000:ac1f:8001]:8080")); } + @Test // SPR-16262 + public void isSameOriginWithXForwardedHeaders() { + assertTrue(checkSameOriginWithXForwardedHeaders("mydomain1.com", -1, "https", null, -1, "https://mydomain1.com")); + assertTrue(checkSameOriginWithXForwardedHeaders("mydomain1.com", 123, "https", null, -1, "https://mydomain1.com")); + assertTrue(checkSameOriginWithXForwardedHeaders("mydomain1.com", -1, "https", "mydomain2.com", -1, "https://mydomain2.com")); + assertTrue(checkSameOriginWithXForwardedHeaders("mydomain1.com", 123, "https", "mydomain2.com", -1, "https://mydomain2.com")); + assertTrue(checkSameOriginWithXForwardedHeaders("mydomain1.com", -1, "https", "mydomain2.com", 456, "https://mydomain2.com:456")); + assertTrue(checkSameOriginWithXForwardedHeaders("mydomain1.com", 123, "https", "mydomain2.com", 456, "https://mydomain2.com:456")); + } + + @Test // SPR-16262 + public void isSameOriginWithForwardedHeader() { + assertTrue(checkSameOriginWithForwardedHeader("mydomain1.com", -1, "proto=https", "https://mydomain1.com")); + assertTrue(checkSameOriginWithForwardedHeader("mydomain1.com", 123, "proto=https", "https://mydomain1.com")); + assertTrue(checkSameOriginWithForwardedHeader("mydomain1.com", -1, "proto=https; host=mydomain2.com", "https://mydomain2.com")); + assertTrue(checkSameOriginWithForwardedHeader("mydomain1.com", 123, "proto=https; host=mydomain2.com", "https://mydomain2.com")); + assertTrue(checkSameOriginWithForwardedHeader("mydomain1.com", -1, "proto=https; host=mydomain2.com:456", "https://mydomain2.com:456")); + assertTrue(checkSameOriginWithForwardedHeader("mydomain1.com", 123, "proto=https; host=mydomain2.com:456", "https://mydomain2.com:456")); + } + private boolean checkValidOrigin(String serverName, int port, String originHeader, List allowed) { MockHttpServletRequest servletRequest = new MockHttpServletRequest(); @@ -183,4 +203,36 @@ public class WebUtilsTests { return WebUtils.isSameOrigin(request); } + private boolean checkSameOriginWithXForwardedHeaders(String serverName, int port, String forwardedProto, String forwardedHost, int forwardedPort, String originHeader) { + MockHttpServletRequest servletRequest = new MockHttpServletRequest(); + ServerHttpRequest request = new ServletServerHttpRequest(servletRequest); + servletRequest.setServerName(serverName); + if (port != -1) { + servletRequest.setServerPort(port); + } + if (forwardedProto != null) { + request.getHeaders().set("X-Forwarded-Proto", forwardedProto); + } + if (forwardedHost != null) { + request.getHeaders().set("X-Forwarded-Host", forwardedHost); + } + if (forwardedPort != -1) { + request.getHeaders().set("X-Forwarded-Port", String.valueOf(forwardedPort)); + } + request.getHeaders().set(HttpHeaders.ORIGIN, originHeader); + return WebUtils.isSameOrigin(request); + } + + private boolean checkSameOriginWithForwardedHeader(String serverName, int port, String forwardedHeader, String originHeader) { + MockHttpServletRequest servletRequest = new MockHttpServletRequest(); + ServerHttpRequest request = new ServletServerHttpRequest(servletRequest); + servletRequest.setServerName(serverName); + if (port != -1) { + servletRequest.setServerPort(port); + } + request.getHeaders().set("Forwarded", forwardedHeader); + request.getHeaders().set(HttpHeaders.ORIGIN, originHeader); + return WebUtils.isSameOrigin(request); + } + }