From 9c7de232b844816dd550130da720587cd9d23b6b Mon Sep 17 00:00:00 2001 From: sdeleuze Date: Tue, 9 Jan 2018 12:40:34 +0100 Subject: [PATCH] Polishing Optimize same origin check when the request is an instance of ServletServerHttpRequest and when there is no forwarded headers. This commit also optimizes the getPort methods and ForwardedHeaderFilter forwarded headers checks. Issue: SPR-16262 --- .../web/cors/reactive/CorsUtils.java | 11 ++- .../web/filter/ForwardedHeaderFilter.java | 6 +- .../reactive/ForwardedHeaderFilter.java | 16 +++-- .../springframework/web/util/WebUtils.java | 67 ++++++++++++++----- .../web/util/WebUtilsTests.java | 16 ++--- 5 files changed, 76 insertions(+), 40 deletions(-) diff --git a/spring-web/src/main/java/org/springframework/web/cors/reactive/CorsUtils.java b/spring-web/src/main/java/org/springframework/web/cors/reactive/CorsUtils.java index f2a9289c4c3..5ddbf6d96c2 100644 --- a/spring-web/src/main/java/org/springframework/web/cors/reactive/CorsUtils.java +++ b/spring-web/src/main/java/org/springframework/web/cors/reactive/CorsUtils.java @@ -64,20 +64,19 @@ public abstract class CorsUtils { UriComponentsBuilder urlBuilder = UriComponentsBuilder.fromHttpRequest(request); UriComponents actualUrl = urlBuilder.build(); String actualHost = actualUrl.getHost(); - int actualPort = getPort(actualUrl); + int actualPort = getPort(actualUrl.getScheme(), actualUrl.getPort()); Assert.notNull(actualHost, "Actual request host must not be null"); Assert.isTrue(actualPort != -1, "Actual request port must not be undefined"); UriComponents originUrl = UriComponentsBuilder.fromOriginHeader(origin).build(); - return (actualHost.equals(originUrl.getHost()) && actualPort == getPort(originUrl)); + return (actualHost.equals(originUrl.getHost()) && actualPort == getPort(originUrl.getScheme(), originUrl.getPort())); } - private static int getPort(UriComponents uri) { - int port = uri.getPort(); + private static int getPort(String scheme, int port) { if (port == -1) { - if ("http".equals(uri.getScheme()) || "ws".equals(uri.getScheme())) { + if ("http".equals(scheme) || "ws".equals(scheme)) { port = 80; } - else if ("https".equals(uri.getScheme()) || "wss".equals(uri.getScheme())) { + else if ("https".equals(scheme) || "wss".equals(scheme)) { port = 443; } } diff --git a/spring-web/src/main/java/org/springframework/web/filter/ForwardedHeaderFilter.java b/spring-web/src/main/java/org/springframework/web/filter/ForwardedHeaderFilter.java index f228743fee0..531cfe68aa8 100644 --- a/spring-web/src/main/java/org/springframework/web/filter/ForwardedHeaderFilter.java +++ b/spring-web/src/main/java/org/springframework/web/filter/ForwardedHeaderFilter.java @@ -118,10 +118,8 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter { @Override protected boolean shouldNotFilter(HttpServletRequest request) throws ServletException { - Enumeration names = request.getHeaderNames(); - while (names.hasMoreElements()) { - String name = names.nextElement(); - if (FORWARDED_HEADER_NAMES.contains(name)) { + for (String headerName : FORWARDED_HEADER_NAMES) { + if (request.getHeader(headerName) != null) { return false; } } diff --git a/spring-web/src/main/java/org/springframework/web/filter/reactive/ForwardedHeaderFilter.java b/spring-web/src/main/java/org/springframework/web/filter/reactive/ForwardedHeaderFilter.java index 9fae417b128..d48484a1aae 100644 --- a/spring-web/src/main/java/org/springframework/web/filter/reactive/ForwardedHeaderFilter.java +++ b/spring-web/src/main/java/org/springframework/web/filter/reactive/ForwardedHeaderFilter.java @@ -17,8 +17,7 @@ package org.springframework.web.filter.reactive; import java.net.URI; -import java.util.Collections; -import java.util.Locale; +import java.util.LinkedHashSet; import java.util.Set; import reactor.core.publisher.Mono; @@ -26,7 +25,6 @@ import reactor.core.publisher.Mono; import org.springframework.http.HttpHeaders; import org.springframework.http.server.reactive.ServerHttpRequest; import org.springframework.lang.Nullable; -import org.springframework.util.LinkedCaseInsensitiveMap; import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.WebFilter; import org.springframework.web.server.WebFilterChain; @@ -47,8 +45,7 @@ import org.springframework.web.util.UriComponentsBuilder; */ public class ForwardedHeaderFilter implements WebFilter { - private static final Set FORWARDED_HEADER_NAMES = - Collections.newSetFromMap(new LinkedCaseInsensitiveMap<>(5, Locale.ENGLISH)); + private static final Set FORWARDED_HEADER_NAMES = new LinkedHashSet<>(5); static { FORWARDED_HEADER_NAMES.add("Forwarded"); @@ -104,8 +101,13 @@ public class ForwardedHeaderFilter implements WebFilter { } private boolean shouldNotFilter(ServerHttpRequest request) { - return request.getHeaders().keySet().stream() - .noneMatch(FORWARDED_HEADER_NAMES::contains); + HttpHeaders headers = request.getHeaders(); + for (String headerName : FORWARDED_HEADER_NAMES) { + if (headers.containsKey(headerName)) { + return false; + } + } + return true; } @Nullable diff --git a/spring-web/src/main/java/org/springframework/web/util/WebUtils.java b/spring-web/src/main/java/org/springframework/web/util/WebUtils.java index c597cedbbe8..1904c225bf5 100644 --- a/spring-web/src/main/java/org/springframework/web/util/WebUtils.java +++ b/spring-web/src/main/java/org/springframework/web/util/WebUtils.java @@ -20,7 +20,9 @@ import java.io.File; import java.io.FileNotFoundException; import java.util.Collection; import java.util.Enumeration; +import java.util.LinkedHashSet; import java.util.Map; +import java.util.Set; import java.util.StringTokenizer; import java.util.TreeMap; import javax.servlet.ServletContext; @@ -33,6 +35,7 @@ import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpSession; +import org.springframework.http.HttpHeaders; import org.springframework.http.HttpRequest; import org.springframework.http.server.ServletServerHttpRequest; import org.springframework.lang.Nullable; @@ -135,6 +138,16 @@ public abstract class WebUtils { /** Key for the mutex session attribute */ public static final String SESSION_MUTEX_ATTRIBUTE = WebUtils.class.getName() + ".MUTEX"; + private static final Set FORWARDED_HEADER_NAMES = new LinkedHashSet<>(5); + + static { + FORWARDED_HEADER_NAMES.add("Forwarded"); + FORWARDED_HEADER_NAMES.add("X-Forwarded-Host"); + FORWARDED_HEADER_NAMES.add("X-Forwarded-Port"); + FORWARDED_HEADER_NAMES.add("X-Forwarded-Proto"); + FORWARDED_HEADER_NAMES.add("X-Forwarded-Prefix"); + } + /** * Set a system property to the web application root directory. @@ -693,36 +706,60 @@ public abstract class WebUtils { * @since 4.2 */ public static boolean isSameOrigin(HttpRequest request) { - String origin = request.getHeaders().getOrigin(); + HttpHeaders headers = request.getHeaders(); + String origin = headers.getOrigin(); if (origin == null) { return true; } - UriComponentsBuilder urlBuilder; + String scheme; + String host; + int port; if (request instanceof ServletServerHttpRequest) { // Build more efficiently if we can: we only need scheme, host, port for origin comparison HttpServletRequest servletRequest = ((ServletServerHttpRequest) request).getServletRequest(); - urlBuilder = new UriComponentsBuilder(). - scheme(servletRequest.getScheme()). - host(servletRequest.getServerName()). - port(servletRequest.getServerPort()). - adaptFromForwardedHeaders(request.getHeaders()); + scheme = servletRequest.getScheme(); + host = servletRequest.getServerName(); + port = servletRequest.getServerPort(); + + if(containsForwardedHeaders(servletRequest)) { + UriComponents actualUrl = new UriComponentsBuilder() + .scheme(scheme) + .host(host) + .port(port) + .adaptFromForwardedHeaders(headers) + .build(); + scheme = actualUrl.getScheme(); + host = actualUrl.getHost(); + port = actualUrl.getPort(); + } } else { - urlBuilder = UriComponentsBuilder.fromHttpRequest(request); + UriComponents actualUrl = UriComponentsBuilder.fromHttpRequest(request).build(); + scheme = actualUrl.getScheme(); + host = actualUrl.getHost(); + port = actualUrl.getPort(); } - UriComponents actualUrl = urlBuilder.build(); + UriComponents originUrl = UriComponentsBuilder.fromOriginHeader(origin).build(); - return (ObjectUtils.nullSafeEquals(actualUrl.getHost(), originUrl.getHost()) && - getPort(actualUrl) == getPort(originUrl)); + return (ObjectUtils.nullSafeEquals(host, originUrl.getHost()) && + getPort(scheme, port) == getPort(originUrl.getScheme(), originUrl.getPort())); + } + + private static boolean containsForwardedHeaders(HttpServletRequest request) { + for (String headerName : FORWARDED_HEADER_NAMES) { + if (request.getHeader(headerName) != null) { + return true; + } + } + return false; } - private static int getPort(UriComponents uri) { - int port = uri.getPort(); + private static int getPort(String scheme, int port) { if (port == -1) { - if ("http".equals(uri.getScheme()) || "ws".equals(uri.getScheme())) { + if ("http".equals(scheme) || "ws".equals(scheme)) { port = 80; } - else if ("https".equals(uri.getScheme()) || "wss".equals(uri.getScheme())) { + else if ("https".equals(scheme) || "wss".equals(scheme)) { port = 443; } } diff --git a/spring-web/src/test/java/org/springframework/web/util/WebUtilsTests.java b/spring-web/src/test/java/org/springframework/web/util/WebUtilsTests.java index 77619e95879..ad92f4e9538 100644 --- a/spring-web/src/test/java/org/springframework/web/util/WebUtilsTests.java +++ b/spring-web/src/test/java/org/springframework/web/util/WebUtilsTests.java @@ -168,7 +168,7 @@ public class WebUtilsTests { if (port != -1) { servletRequest.setServerPort(port); } - request.getHeaders().set(HttpHeaders.ORIGIN, originHeader); + servletRequest.addHeader(HttpHeaders.ORIGIN, originHeader); return WebUtils.isValidOrigin(request, allowed); } @@ -179,7 +179,7 @@ public class WebUtilsTests { if (port != -1) { servletRequest.setServerPort(port); } - request.getHeaders().set(HttpHeaders.ORIGIN, originHeader); + servletRequest.addHeader(HttpHeaders.ORIGIN, originHeader); return WebUtils.isSameOrigin(request); } @@ -191,15 +191,15 @@ public class WebUtilsTests { servletRequest.setServerPort(port); } if (forwardedProto != null) { - request.getHeaders().set("X-Forwarded-Proto", forwardedProto); + servletRequest.addHeader("X-Forwarded-Proto", forwardedProto); } if (forwardedHost != null) { - request.getHeaders().set("X-Forwarded-Host", forwardedHost); + servletRequest.addHeader("X-Forwarded-Host", forwardedHost); } if (forwardedPort != -1) { - request.getHeaders().set("X-Forwarded-Port", String.valueOf(forwardedPort)); + servletRequest.addHeader("X-Forwarded-Port", String.valueOf(forwardedPort)); } - request.getHeaders().set(HttpHeaders.ORIGIN, originHeader); + servletRequest.addHeader(HttpHeaders.ORIGIN, originHeader); return WebUtils.isSameOrigin(request); } @@ -210,8 +210,8 @@ public class WebUtilsTests { if (port != -1) { servletRequest.setServerPort(port); } - request.getHeaders().set("Forwarded", forwardedHeader); - request.getHeaders().set(HttpHeaders.ORIGIN, originHeader); + servletRequest.addHeader("Forwarded", forwardedHeader); + servletRequest.addHeader(HttpHeaders.ORIGIN, originHeader); return WebUtils.isSameOrigin(request); }