Browse Source

Polishing

Optimize same origin check when the request is an instance of
ServletServerHttpRequest and when there is no forwarded headers.

This commit also optimizes the getPort methods and ForwardedHeaderFilter
forwarded headers checks.

Issue: SPR-16262
pull/1637/merge
sdeleuze 8 years ago
parent
commit
9c7de232b8
  1. 11
      spring-web/src/main/java/org/springframework/web/cors/reactive/CorsUtils.java
  2. 6
      spring-web/src/main/java/org/springframework/web/filter/ForwardedHeaderFilter.java
  3. 16
      spring-web/src/main/java/org/springframework/web/filter/reactive/ForwardedHeaderFilter.java
  4. 67
      spring-web/src/main/java/org/springframework/web/util/WebUtils.java
  5. 16
      spring-web/src/test/java/org/springframework/web/util/WebUtilsTests.java

11
spring-web/src/main/java/org/springframework/web/cors/reactive/CorsUtils.java

@ -64,20 +64,19 @@ public abstract class CorsUtils { @@ -64,20 +64,19 @@ public abstract class CorsUtils {
UriComponentsBuilder urlBuilder = UriComponentsBuilder.fromHttpRequest(request);
UriComponents actualUrl = urlBuilder.build();
String actualHost = actualUrl.getHost();
int actualPort = getPort(actualUrl);
int actualPort = getPort(actualUrl.getScheme(), actualUrl.getPort());
Assert.notNull(actualHost, "Actual request host must not be null");
Assert.isTrue(actualPort != -1, "Actual request port must not be undefined");
UriComponents originUrl = UriComponentsBuilder.fromOriginHeader(origin).build();
return (actualHost.equals(originUrl.getHost()) && actualPort == getPort(originUrl));
return (actualHost.equals(originUrl.getHost()) && actualPort == getPort(originUrl.getScheme(), originUrl.getPort()));
}
private static int getPort(UriComponents uri) {
int port = uri.getPort();
private static int getPort(String scheme, int port) {
if (port == -1) {
if ("http".equals(uri.getScheme()) || "ws".equals(uri.getScheme())) {
if ("http".equals(scheme) || "ws".equals(scheme)) {
port = 80;
}
else if ("https".equals(uri.getScheme()) || "wss".equals(uri.getScheme())) {
else if ("https".equals(scheme) || "wss".equals(scheme)) {
port = 443;
}
}

6
spring-web/src/main/java/org/springframework/web/filter/ForwardedHeaderFilter.java

@ -118,10 +118,8 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter { @@ -118,10 +118,8 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter {
@Override
protected boolean shouldNotFilter(HttpServletRequest request) throws ServletException {
Enumeration<String> names = request.getHeaderNames();
while (names.hasMoreElements()) {
String name = names.nextElement();
if (FORWARDED_HEADER_NAMES.contains(name)) {
for (String headerName : FORWARDED_HEADER_NAMES) {
if (request.getHeader(headerName) != null) {
return false;
}
}

16
spring-web/src/main/java/org/springframework/web/filter/reactive/ForwardedHeaderFilter.java

@ -17,8 +17,7 @@ @@ -17,8 +17,7 @@
package org.springframework.web.filter.reactive;
import java.net.URI;
import java.util.Collections;
import java.util.Locale;
import java.util.LinkedHashSet;
import java.util.Set;
import reactor.core.publisher.Mono;
@ -26,7 +25,6 @@ import reactor.core.publisher.Mono; @@ -26,7 +25,6 @@ import reactor.core.publisher.Mono;
import org.springframework.http.HttpHeaders;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.lang.Nullable;
import org.springframework.util.LinkedCaseInsensitiveMap;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebFilter;
import org.springframework.web.server.WebFilterChain;
@ -47,8 +45,7 @@ import org.springframework.web.util.UriComponentsBuilder; @@ -47,8 +45,7 @@ import org.springframework.web.util.UriComponentsBuilder;
*/
public class ForwardedHeaderFilter implements WebFilter {
private static final Set<String> FORWARDED_HEADER_NAMES =
Collections.newSetFromMap(new LinkedCaseInsensitiveMap<>(5, Locale.ENGLISH));
private static final Set<String> FORWARDED_HEADER_NAMES = new LinkedHashSet<>(5);
static {
FORWARDED_HEADER_NAMES.add("Forwarded");
@ -104,8 +101,13 @@ public class ForwardedHeaderFilter implements WebFilter { @@ -104,8 +101,13 @@ public class ForwardedHeaderFilter implements WebFilter {
}
private boolean shouldNotFilter(ServerHttpRequest request) {
return request.getHeaders().keySet().stream()
.noneMatch(FORWARDED_HEADER_NAMES::contains);
HttpHeaders headers = request.getHeaders();
for (String headerName : FORWARDED_HEADER_NAMES) {
if (headers.containsKey(headerName)) {
return false;
}
}
return true;
}
@Nullable

67
spring-web/src/main/java/org/springframework/web/util/WebUtils.java

@ -20,7 +20,9 @@ import java.io.File; @@ -20,7 +20,9 @@ import java.io.File;
import java.io.FileNotFoundException;
import java.util.Collection;
import java.util.Enumeration;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;
import java.util.StringTokenizer;
import java.util.TreeMap;
import javax.servlet.ServletContext;
@ -33,6 +35,7 @@ import javax.servlet.http.HttpServletRequest; @@ -33,6 +35,7 @@ import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpSession;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpRequest;
import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.lang.Nullable;
@ -135,6 +138,16 @@ public abstract class WebUtils { @@ -135,6 +138,16 @@ public abstract class WebUtils {
/** Key for the mutex session attribute */
public static final String SESSION_MUTEX_ATTRIBUTE = WebUtils.class.getName() + ".MUTEX";
private static final Set<String> FORWARDED_HEADER_NAMES = new LinkedHashSet<>(5);
static {
FORWARDED_HEADER_NAMES.add("Forwarded");
FORWARDED_HEADER_NAMES.add("X-Forwarded-Host");
FORWARDED_HEADER_NAMES.add("X-Forwarded-Port");
FORWARDED_HEADER_NAMES.add("X-Forwarded-Proto");
FORWARDED_HEADER_NAMES.add("X-Forwarded-Prefix");
}
/**
* Set a system property to the web application root directory.
@ -693,36 +706,60 @@ public abstract class WebUtils { @@ -693,36 +706,60 @@ public abstract class WebUtils {
* @since 4.2
*/
public static boolean isSameOrigin(HttpRequest request) {
String origin = request.getHeaders().getOrigin();
HttpHeaders headers = request.getHeaders();
String origin = headers.getOrigin();
if (origin == null) {
return true;
}
UriComponentsBuilder urlBuilder;
String scheme;
String host;
int port;
if (request instanceof ServletServerHttpRequest) {
// Build more efficiently if we can: we only need scheme, host, port for origin comparison
HttpServletRequest servletRequest = ((ServletServerHttpRequest) request).getServletRequest();
urlBuilder = new UriComponentsBuilder().
scheme(servletRequest.getScheme()).
host(servletRequest.getServerName()).
port(servletRequest.getServerPort()).
adaptFromForwardedHeaders(request.getHeaders());
scheme = servletRequest.getScheme();
host = servletRequest.getServerName();
port = servletRequest.getServerPort();
if(containsForwardedHeaders(servletRequest)) {
UriComponents actualUrl = new UriComponentsBuilder()
.scheme(scheme)
.host(host)
.port(port)
.adaptFromForwardedHeaders(headers)
.build();
scheme = actualUrl.getScheme();
host = actualUrl.getHost();
port = actualUrl.getPort();
}
}
else {
urlBuilder = UriComponentsBuilder.fromHttpRequest(request);
UriComponents actualUrl = UriComponentsBuilder.fromHttpRequest(request).build();
scheme = actualUrl.getScheme();
host = actualUrl.getHost();
port = actualUrl.getPort();
}
UriComponents actualUrl = urlBuilder.build();
UriComponents originUrl = UriComponentsBuilder.fromOriginHeader(origin).build();
return (ObjectUtils.nullSafeEquals(actualUrl.getHost(), originUrl.getHost()) &&
getPort(actualUrl) == getPort(originUrl));
return (ObjectUtils.nullSafeEquals(host, originUrl.getHost()) &&
getPort(scheme, port) == getPort(originUrl.getScheme(), originUrl.getPort()));
}
private static boolean containsForwardedHeaders(HttpServletRequest request) {
for (String headerName : FORWARDED_HEADER_NAMES) {
if (request.getHeader(headerName) != null) {
return true;
}
}
return false;
}
private static int getPort(UriComponents uri) {
int port = uri.getPort();
private static int getPort(String scheme, int port) {
if (port == -1) {
if ("http".equals(uri.getScheme()) || "ws".equals(uri.getScheme())) {
if ("http".equals(scheme) || "ws".equals(scheme)) {
port = 80;
}
else if ("https".equals(uri.getScheme()) || "wss".equals(uri.getScheme())) {
else if ("https".equals(scheme) || "wss".equals(scheme)) {
port = 443;
}
}

16
spring-web/src/test/java/org/springframework/web/util/WebUtilsTests.java

@ -168,7 +168,7 @@ public class WebUtilsTests { @@ -168,7 +168,7 @@ public class WebUtilsTests {
if (port != -1) {
servletRequest.setServerPort(port);
}
request.getHeaders().set(HttpHeaders.ORIGIN, originHeader);
servletRequest.addHeader(HttpHeaders.ORIGIN, originHeader);
return WebUtils.isValidOrigin(request, allowed);
}
@ -179,7 +179,7 @@ public class WebUtilsTests { @@ -179,7 +179,7 @@ public class WebUtilsTests {
if (port != -1) {
servletRequest.setServerPort(port);
}
request.getHeaders().set(HttpHeaders.ORIGIN, originHeader);
servletRequest.addHeader(HttpHeaders.ORIGIN, originHeader);
return WebUtils.isSameOrigin(request);
}
@ -191,15 +191,15 @@ public class WebUtilsTests { @@ -191,15 +191,15 @@ public class WebUtilsTests {
servletRequest.setServerPort(port);
}
if (forwardedProto != null) {
request.getHeaders().set("X-Forwarded-Proto", forwardedProto);
servletRequest.addHeader("X-Forwarded-Proto", forwardedProto);
}
if (forwardedHost != null) {
request.getHeaders().set("X-Forwarded-Host", forwardedHost);
servletRequest.addHeader("X-Forwarded-Host", forwardedHost);
}
if (forwardedPort != -1) {
request.getHeaders().set("X-Forwarded-Port", String.valueOf(forwardedPort));
servletRequest.addHeader("X-Forwarded-Port", String.valueOf(forwardedPort));
}
request.getHeaders().set(HttpHeaders.ORIGIN, originHeader);
servletRequest.addHeader(HttpHeaders.ORIGIN, originHeader);
return WebUtils.isSameOrigin(request);
}
@ -210,8 +210,8 @@ public class WebUtilsTests { @@ -210,8 +210,8 @@ public class WebUtilsTests {
if (port != -1) {
servletRequest.setServerPort(port);
}
request.getHeaders().set("Forwarded", forwardedHeader);
request.getHeaders().set(HttpHeaders.ORIGIN, originHeader);
servletRequest.addHeader("Forwarded", forwardedHeader);
servletRequest.addHeader(HttpHeaders.ORIGIN, originHeader);
return WebUtils.isSameOrigin(request);
}

Loading…
Cancel
Save