diff --git a/spring-web/src/main/java/org/springframework/web/cors/reactive/CorsUtils.java b/spring-web/src/main/java/org/springframework/web/cors/reactive/CorsUtils.java index f2a9289c4c3..5ddbf6d96c2 100644 --- a/spring-web/src/main/java/org/springframework/web/cors/reactive/CorsUtils.java +++ b/spring-web/src/main/java/org/springframework/web/cors/reactive/CorsUtils.java @@ -64,20 +64,19 @@ public abstract class CorsUtils { UriComponentsBuilder urlBuilder = UriComponentsBuilder.fromHttpRequest(request); UriComponents actualUrl = urlBuilder.build(); String actualHost = actualUrl.getHost(); - int actualPort = getPort(actualUrl); + int actualPort = getPort(actualUrl.getScheme(), actualUrl.getPort()); Assert.notNull(actualHost, "Actual request host must not be null"); Assert.isTrue(actualPort != -1, "Actual request port must not be undefined"); UriComponents originUrl = UriComponentsBuilder.fromOriginHeader(origin).build(); - return (actualHost.equals(originUrl.getHost()) && actualPort == getPort(originUrl)); + return (actualHost.equals(originUrl.getHost()) && actualPort == getPort(originUrl.getScheme(), originUrl.getPort())); } - private static int getPort(UriComponents uri) { - int port = uri.getPort(); + private static int getPort(String scheme, int port) { if (port == -1) { - if ("http".equals(uri.getScheme()) || "ws".equals(uri.getScheme())) { + if ("http".equals(scheme) || "ws".equals(scheme)) { port = 80; } - else if ("https".equals(uri.getScheme()) || "wss".equals(uri.getScheme())) { + else if ("https".equals(scheme) || "wss".equals(scheme)) { port = 443; } } diff --git a/spring-web/src/main/java/org/springframework/web/filter/ForwardedHeaderFilter.java b/spring-web/src/main/java/org/springframework/web/filter/ForwardedHeaderFilter.java index f228743fee0..531cfe68aa8 100644 --- a/spring-web/src/main/java/org/springframework/web/filter/ForwardedHeaderFilter.java +++ b/spring-web/src/main/java/org/springframework/web/filter/ForwardedHeaderFilter.java @@ -118,10 +118,8 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter { @Override protected boolean shouldNotFilter(HttpServletRequest request) throws ServletException { - Enumeration names = request.getHeaderNames(); - while (names.hasMoreElements()) { - String name = names.nextElement(); - if (FORWARDED_HEADER_NAMES.contains(name)) { + for (String headerName : FORWARDED_HEADER_NAMES) { + if (request.getHeader(headerName) != null) { return false; } } diff --git a/spring-web/src/main/java/org/springframework/web/filter/reactive/ForwardedHeaderFilter.java b/spring-web/src/main/java/org/springframework/web/filter/reactive/ForwardedHeaderFilter.java index 9fae417b128..d48484a1aae 100644 --- a/spring-web/src/main/java/org/springframework/web/filter/reactive/ForwardedHeaderFilter.java +++ b/spring-web/src/main/java/org/springframework/web/filter/reactive/ForwardedHeaderFilter.java @@ -17,8 +17,7 @@ package org.springframework.web.filter.reactive; import java.net.URI; -import java.util.Collections; -import java.util.Locale; +import java.util.LinkedHashSet; import java.util.Set; import reactor.core.publisher.Mono; @@ -26,7 +25,6 @@ import reactor.core.publisher.Mono; import org.springframework.http.HttpHeaders; import org.springframework.http.server.reactive.ServerHttpRequest; import org.springframework.lang.Nullable; -import org.springframework.util.LinkedCaseInsensitiveMap; import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.WebFilter; import org.springframework.web.server.WebFilterChain; @@ -47,8 +45,7 @@ import org.springframework.web.util.UriComponentsBuilder; */ public class ForwardedHeaderFilter implements WebFilter { - private static final Set FORWARDED_HEADER_NAMES = - Collections.newSetFromMap(new LinkedCaseInsensitiveMap<>(5, Locale.ENGLISH)); + private static final Set FORWARDED_HEADER_NAMES = new LinkedHashSet<>(5); static { FORWARDED_HEADER_NAMES.add("Forwarded"); @@ -104,8 +101,13 @@ public class ForwardedHeaderFilter implements WebFilter { } private boolean shouldNotFilter(ServerHttpRequest request) { - return request.getHeaders().keySet().stream() - .noneMatch(FORWARDED_HEADER_NAMES::contains); + HttpHeaders headers = request.getHeaders(); + for (String headerName : FORWARDED_HEADER_NAMES) { + if (headers.containsKey(headerName)) { + return false; + } + } + return true; } @Nullable diff --git a/spring-web/src/main/java/org/springframework/web/util/WebUtils.java b/spring-web/src/main/java/org/springframework/web/util/WebUtils.java index c597cedbbe8..1904c225bf5 100644 --- a/spring-web/src/main/java/org/springframework/web/util/WebUtils.java +++ b/spring-web/src/main/java/org/springframework/web/util/WebUtils.java @@ -20,7 +20,9 @@ import java.io.File; import java.io.FileNotFoundException; import java.util.Collection; import java.util.Enumeration; +import java.util.LinkedHashSet; import java.util.Map; +import java.util.Set; import java.util.StringTokenizer; import java.util.TreeMap; import javax.servlet.ServletContext; @@ -33,6 +35,7 @@ import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpSession; +import org.springframework.http.HttpHeaders; import org.springframework.http.HttpRequest; import org.springframework.http.server.ServletServerHttpRequest; import org.springframework.lang.Nullable; @@ -135,6 +138,16 @@ public abstract class WebUtils { /** Key for the mutex session attribute */ public static final String SESSION_MUTEX_ATTRIBUTE = WebUtils.class.getName() + ".MUTEX"; + private static final Set FORWARDED_HEADER_NAMES = new LinkedHashSet<>(5); + + static { + FORWARDED_HEADER_NAMES.add("Forwarded"); + FORWARDED_HEADER_NAMES.add("X-Forwarded-Host"); + FORWARDED_HEADER_NAMES.add("X-Forwarded-Port"); + FORWARDED_HEADER_NAMES.add("X-Forwarded-Proto"); + FORWARDED_HEADER_NAMES.add("X-Forwarded-Prefix"); + } + /** * Set a system property to the web application root directory. @@ -693,36 +706,60 @@ public abstract class WebUtils { * @since 4.2 */ public static boolean isSameOrigin(HttpRequest request) { - String origin = request.getHeaders().getOrigin(); + HttpHeaders headers = request.getHeaders(); + String origin = headers.getOrigin(); if (origin == null) { return true; } - UriComponentsBuilder urlBuilder; + String scheme; + String host; + int port; if (request instanceof ServletServerHttpRequest) { // Build more efficiently if we can: we only need scheme, host, port for origin comparison HttpServletRequest servletRequest = ((ServletServerHttpRequest) request).getServletRequest(); - urlBuilder = new UriComponentsBuilder(). - scheme(servletRequest.getScheme()). - host(servletRequest.getServerName()). - port(servletRequest.getServerPort()). - adaptFromForwardedHeaders(request.getHeaders()); + scheme = servletRequest.getScheme(); + host = servletRequest.getServerName(); + port = servletRequest.getServerPort(); + + if(containsForwardedHeaders(servletRequest)) { + UriComponents actualUrl = new UriComponentsBuilder() + .scheme(scheme) + .host(host) + .port(port) + .adaptFromForwardedHeaders(headers) + .build(); + scheme = actualUrl.getScheme(); + host = actualUrl.getHost(); + port = actualUrl.getPort(); + } } else { - urlBuilder = UriComponentsBuilder.fromHttpRequest(request); + UriComponents actualUrl = UriComponentsBuilder.fromHttpRequest(request).build(); + scheme = actualUrl.getScheme(); + host = actualUrl.getHost(); + port = actualUrl.getPort(); } - UriComponents actualUrl = urlBuilder.build(); + UriComponents originUrl = UriComponentsBuilder.fromOriginHeader(origin).build(); - return (ObjectUtils.nullSafeEquals(actualUrl.getHost(), originUrl.getHost()) && - getPort(actualUrl) == getPort(originUrl)); + return (ObjectUtils.nullSafeEquals(host, originUrl.getHost()) && + getPort(scheme, port) == getPort(originUrl.getScheme(), originUrl.getPort())); + } + + private static boolean containsForwardedHeaders(HttpServletRequest request) { + for (String headerName : FORWARDED_HEADER_NAMES) { + if (request.getHeader(headerName) != null) { + return true; + } + } + return false; } - private static int getPort(UriComponents uri) { - int port = uri.getPort(); + private static int getPort(String scheme, int port) { if (port == -1) { - if ("http".equals(uri.getScheme()) || "ws".equals(uri.getScheme())) { + if ("http".equals(scheme) || "ws".equals(scheme)) { port = 80; } - else if ("https".equals(uri.getScheme()) || "wss".equals(uri.getScheme())) { + else if ("https".equals(scheme) || "wss".equals(scheme)) { port = 443; } } diff --git a/spring-web/src/test/java/org/springframework/web/util/WebUtilsTests.java b/spring-web/src/test/java/org/springframework/web/util/WebUtilsTests.java index 77619e95879..ad92f4e9538 100644 --- a/spring-web/src/test/java/org/springframework/web/util/WebUtilsTests.java +++ b/spring-web/src/test/java/org/springframework/web/util/WebUtilsTests.java @@ -168,7 +168,7 @@ public class WebUtilsTests { if (port != -1) { servletRequest.setServerPort(port); } - request.getHeaders().set(HttpHeaders.ORIGIN, originHeader); + servletRequest.addHeader(HttpHeaders.ORIGIN, originHeader); return WebUtils.isValidOrigin(request, allowed); } @@ -179,7 +179,7 @@ public class WebUtilsTests { if (port != -1) { servletRequest.setServerPort(port); } - request.getHeaders().set(HttpHeaders.ORIGIN, originHeader); + servletRequest.addHeader(HttpHeaders.ORIGIN, originHeader); return WebUtils.isSameOrigin(request); } @@ -191,15 +191,15 @@ public class WebUtilsTests { servletRequest.setServerPort(port); } if (forwardedProto != null) { - request.getHeaders().set("X-Forwarded-Proto", forwardedProto); + servletRequest.addHeader("X-Forwarded-Proto", forwardedProto); } if (forwardedHost != null) { - request.getHeaders().set("X-Forwarded-Host", forwardedHost); + servletRequest.addHeader("X-Forwarded-Host", forwardedHost); } if (forwardedPort != -1) { - request.getHeaders().set("X-Forwarded-Port", String.valueOf(forwardedPort)); + servletRequest.addHeader("X-Forwarded-Port", String.valueOf(forwardedPort)); } - request.getHeaders().set(HttpHeaders.ORIGIN, originHeader); + servletRequest.addHeader(HttpHeaders.ORIGIN, originHeader); return WebUtils.isSameOrigin(request); } @@ -210,8 +210,8 @@ public class WebUtilsTests { if (port != -1) { servletRequest.setServerPort(port); } - request.getHeaders().set("Forwarded", forwardedHeader); - request.getHeaders().set(HttpHeaders.ORIGIN, originHeader); + servletRequest.addHeader("Forwarded", forwardedHeader); + servletRequest.addHeader(HttpHeaders.ORIGIN, originHeader); return WebUtils.isSameOrigin(request); }