diff --git a/spring-web/src/main/java/org/springframework/web/cors/reactive/CorsUtils.java b/spring-web/src/main/java/org/springframework/web/cors/reactive/CorsUtils.java index da8569d78c0..663f0f06348 100644 --- a/spring-web/src/main/java/org/springframework/web/cors/reactive/CorsUtils.java +++ b/spring-web/src/main/java/org/springframework/web/cors/reactive/CorsUtils.java @@ -69,13 +69,16 @@ public abstract class CorsUtils { } URI uri = request.getURI(); + String actualScheme = uri.getScheme(); String actualHost = uri.getHost(); int actualPort = getPort(uri.getScheme(), uri.getPort()); + Assert.notNull(actualScheme, "Actual request scheme must not be null"); Assert.notNull(actualHost, "Actual request host must not be null"); Assert.isTrue(actualPort != -1, "Actual request port must not be undefined"); UriComponents originUrl = UriComponentsBuilder.fromOriginHeader(origin).build(); - return (actualHost.equals(originUrl.getHost()) && + return (actualScheme.equals(originUrl.getScheme()) && + actualHost.equals(originUrl.getHost()) && actualPort == getPort(originUrl.getScheme(), originUrl.getPort())); } diff --git a/spring-web/src/main/java/org/springframework/web/util/WebUtils.java b/spring-web/src/main/java/org/springframework/web/util/WebUtils.java index 4139acf94ac..5face07f544 100644 --- a/spring-web/src/main/java/org/springframework/web/util/WebUtils.java +++ b/spring-web/src/main/java/org/springframework/web/util/WebUtils.java @@ -813,7 +813,8 @@ public abstract class WebUtils { } UriComponents originUrl = UriComponentsBuilder.fromOriginHeader(origin).build(); - return (ObjectUtils.nullSafeEquals(host, originUrl.getHost()) && + return (ObjectUtils.nullSafeEquals(scheme, originUrl.getScheme()) && + ObjectUtils.nullSafeEquals(host, originUrl.getHost()) && getPort(scheme, port) == getPort(originUrl.getScheme(), originUrl.getPort())); } diff --git a/spring-web/src/test/java/org/springframework/web/cors/reactive/CorsUtilsTests.java b/spring-web/src/test/java/org/springframework/web/cors/reactive/CorsUtilsTests.java index 96ee57a298e..617bf825461 100644 --- a/spring-web/src/test/java/org/springframework/web/cors/reactive/CorsUtilsTests.java +++ b/spring-web/src/test/java/org/springframework/web/cors/reactive/CorsUtilsTests.java @@ -92,6 +92,15 @@ public class CorsUtilsTests { testWithForwardedHeader(server, 123, "proto=https; host=mydomain2.com:456", "https://mydomain2.com:456"); } + @Test // SPR-16362 + public void isSameOriginWithDifferentSchemes() { + MockServerHttpRequest request = MockServerHttpRequest + .get("http://mydomain1.com") + .header(HttpHeaders.ORIGIN, "https://mydomain1.com") + .build(); + assertFalse(CorsUtils.isSameOrigin(request)); + } + private void testWithXForwardedHeaders(String serverName, int port, String forwardedProto, String forwardedHost, int forwardedPort, String originHeader) { diff --git a/spring-web/src/test/java/org/springframework/web/util/WebUtilsTests.java b/spring-web/src/test/java/org/springframework/web/util/WebUtilsTests.java index 1e384878418..99a6736f1ae 100644 --- a/spring-web/src/test/java/org/springframework/web/util/WebUtilsTests.java +++ b/spring-web/src/test/java/org/springframework/web/util/WebUtilsTests.java @@ -105,39 +105,40 @@ public class WebUtilsTests { @Test public void isSameOrigin() { - assertTrue(checkSameOrigin("mydomain1.com", -1, "http://mydomain1.com")); - assertTrue(checkSameOrigin("mydomain1.com", -1, "http://mydomain1.com:80")); - assertTrue(checkSameOrigin("mydomain1.com", 443, "https://mydomain1.com")); - assertTrue(checkSameOrigin("mydomain1.com", 443, "https://mydomain1.com:443")); - assertTrue(checkSameOrigin("mydomain1.com", 123, "http://mydomain1.com:123")); - assertTrue(checkSameOrigin("mydomain1.com", -1, "ws://mydomain1.com")); - assertTrue(checkSameOrigin("mydomain1.com", 443, "wss://mydomain1.com")); - - assertFalse(checkSameOrigin("mydomain1.com", -1, "http://mydomain2.com")); - assertFalse(checkSameOrigin("mydomain1.com", -1, "https://mydomain1.com")); - assertFalse(checkSameOrigin("mydomain1.com", -1, "invalid-origin")); + assertTrue(checkSameOrigin("http", "mydomain1.com", -1, "http://mydomain1.com")); + assertTrue(checkSameOrigin("http", "mydomain1.com", -1, "http://mydomain1.com:80")); + assertTrue(checkSameOrigin("https", "mydomain1.com", 443, "https://mydomain1.com")); + assertTrue(checkSameOrigin("https", "mydomain1.com", 443, "https://mydomain1.com:443")); + assertTrue(checkSameOrigin("http", "mydomain1.com", 123, "http://mydomain1.com:123")); + assertTrue(checkSameOrigin("ws", "mydomain1.com", -1, "ws://mydomain1.com")); + assertTrue(checkSameOrigin("wss", "mydomain1.com", 443, "wss://mydomain1.com")); + + assertFalse(checkSameOrigin("http", "mydomain1.com", -1, "http://mydomain2.com")); + assertFalse(checkSameOrigin("http", "mydomain1.com", -1, "https://mydomain1.com")); + assertFalse(checkSameOrigin("http", "mydomain1.com", -1, "invalid-origin")); + assertFalse(checkSameOrigin("https", "mydomain1.com", -1, "http://mydomain1.com")); // Handling of invalid origins as described in SPR-13478 - assertTrue(checkSameOrigin("mydomain1.com", -1, "http://mydomain1.com/")); - assertTrue(checkSameOrigin("mydomain1.com", -1, "http://mydomain1.com:80/")); - assertTrue(checkSameOrigin("mydomain1.com", -1, "http://mydomain1.com/path")); - assertTrue(checkSameOrigin("mydomain1.com", -1, "http://mydomain1.com:80/path")); - assertFalse(checkSameOrigin("mydomain2.com", -1, "http://mydomain1.com/")); - assertFalse(checkSameOrigin("mydomain2.com", -1, "http://mydomain1.com:80/")); - assertFalse(checkSameOrigin("mydomain2.com", -1, "http://mydomain1.com/path")); - assertFalse(checkSameOrigin("mydomain2.com", -1, "http://mydomain1.com:80/path")); + assertTrue(checkSameOrigin("http", "mydomain1.com", -1, "http://mydomain1.com/")); + assertTrue(checkSameOrigin("http", "mydomain1.com", -1, "http://mydomain1.com:80/")); + assertTrue(checkSameOrigin("http", "mydomain1.com", -1, "http://mydomain1.com/path")); + assertTrue(checkSameOrigin("http", "mydomain1.com", -1, "http://mydomain1.com:80/path")); + assertFalse(checkSameOrigin("http", "mydomain2.com", -1, "http://mydomain1.com/")); + assertFalse(checkSameOrigin("http", "mydomain2.com", -1, "http://mydomain1.com:80/")); + assertFalse(checkSameOrigin("http", "mydomain2.com", -1, "http://mydomain1.com/path")); + assertFalse(checkSameOrigin("http", "mydomain2.com", -1, "http://mydomain1.com:80/path")); // Handling of IPv6 hosts as described in SPR-13525 - assertTrue(checkSameOrigin("[::1]", -1, "http://[::1]")); - assertTrue(checkSameOrigin("[::1]", 8080, "http://[::1]:8080")); - assertTrue(checkSameOrigin( + assertTrue(checkSameOrigin("http", "[::1]", -1, "http://[::1]")); + assertTrue(checkSameOrigin("http", "[::1]", 8080, "http://[::1]:8080")); + assertTrue(checkSameOrigin("http", "[2001:0db8:0000:85a3:0000:0000:ac1f:8001]", -1, "http://[2001:0db8:0000:85a3:0000:0000:ac1f:8001]")); - assertTrue(checkSameOrigin( + assertTrue(checkSameOrigin("http", "[2001:0db8:0000:85a3:0000:0000:ac1f:8001]", 8080, "http://[2001:0db8:0000:85a3:0000:0000:ac1f:8001]:8080")); - assertFalse(checkSameOrigin("[::1]", -1, "http://[::1]:8080")); - assertFalse(checkSameOrigin("[::1]", 8080, + assertFalse(checkSameOrigin("http", "[::1]", -1, "http://[::1]:8080")); + assertFalse(checkSameOrigin("http", "[::1]", 8080, "http://[2001:0db8:0000:85a3:0000:0000:ac1f:8001]:8080")); } @@ -175,9 +176,10 @@ public class WebUtilsTests { return WebUtils.isValidOrigin(request, allowed); } - private boolean checkSameOrigin(String serverName, int port, String originHeader) { + private boolean checkSameOrigin(String scheme, String serverName, int port, String originHeader) { MockHttpServletRequest servletRequest = new MockHttpServletRequest(); ServerHttpRequest request = new ServletServerHttpRequest(servletRequest); + servletRequest.setScheme(scheme); servletRequest.setServerName(serverName); if (port != -1) { servletRequest.setServerPort(port);