diff --git a/spring-web/src/main/java/org/springframework/web/util/UriComponentsBuilder.java b/spring-web/src/main/java/org/springframework/web/util/UriComponentsBuilder.java index 31c0162e100..7dea422f5aa 100644 --- a/spring-web/src/main/java/org/springframework/web/util/UriComponentsBuilder.java +++ b/spring-web/src/main/java/org/springframework/web/util/UriComponentsBuilder.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2015 the original author or authors. + * Copyright 2002-2016 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,6 +24,7 @@ import java.util.Map; import java.util.regex.Matcher; import java.util.regex.Pattern; +import org.springframework.http.HttpHeaders; import org.springframework.http.HttpRequest; import org.springframework.util.Assert; import org.springframework.util.LinkedMultiValueMap; @@ -130,7 +131,7 @@ public class UriComponentsBuilder implements Cloneable { this.userInfo = other.userInfo; this.host = other.host; this.port = other.port; - this.pathBuilder = (CompositePathComponentBuilder) other.pathBuilder.clone(); + this.pathBuilder = other.pathBuilder.cloneBuilder(); this.queryParams.putAll(other.queryParams); this.fragment = other.fragment; } @@ -271,72 +272,17 @@ public class UriComponentsBuilder implements Cloneable { /** * Create a new {@code UriComponents} object from the URI associated with * the given HttpRequest while also overlaying with values from the headers - * "Forwarded" (RFC 7239, or - * "X-Forwarded-Host", "X-Forwarded-Port", and "X-Forwarded-Proto" if "Forwarded" is - * not found. + * "Forwarded" (RFC 7239, + * or "X-Forwarded-Host", "X-Forwarded-Port", and "X-Forwarded-Proto" if + * "Forwarded" is not found. * @param request the source request * @return the URI components of the URI * @since 4.1.5 */ public static UriComponentsBuilder fromHttpRequest(HttpRequest request) { - URI uri = request.getURI(); - UriComponentsBuilder builder = UriComponentsBuilder.fromUri(uri); - - String scheme = uri.getScheme(); - String host = uri.getHost(); - int port = uri.getPort(); - - String forwardedHeader = request.getHeaders().getFirst("Forwarded"); - if (StringUtils.hasText(forwardedHeader)) { - String forwardedToUse = StringUtils.commaDelimitedListToStringArray(forwardedHeader)[0]; - Matcher m = FORWARDED_HOST_PATTERN.matcher(forwardedToUse); - if (m.find()) { - host = m.group(1).trim(); - } - m = FORWARDED_PROTO_PATTERN.matcher(forwardedToUse); - if (m.find()) { - scheme = m.group(1).trim(); - } - } - else { - String hostHeader = request.getHeaders().getFirst("X-Forwarded-Host"); - if (StringUtils.hasText(hostHeader)) { - String[] hosts = StringUtils.commaDelimitedListToStringArray(hostHeader); - String hostToUse = hosts[0]; - if (hostToUse.contains(":")) { - String[] hostAndPort = StringUtils.split(hostToUse, ":"); - host = hostAndPort[0]; - port = Integer.parseInt(hostAndPort[1]); - } - else { - host = hostToUse; - port = -1; - } - } - - String portHeader = request.getHeaders().getFirst("X-Forwarded-Port"); - if (StringUtils.hasText(portHeader)) { - String[] ports = StringUtils.commaDelimitedListToStringArray(portHeader); - port = Integer.parseInt(ports[0]); - } - - String protocolHeader = request.getHeaders().getFirst("X-Forwarded-Proto"); - if (StringUtils.hasText(protocolHeader)) { - String[] protocols = StringUtils.commaDelimitedListToStringArray(protocolHeader); - scheme = protocols[0]; - } - } - - builder.scheme(scheme); - builder.host(host); - builder.port(null); - if (scheme.equals("http") && port != 80 || scheme.equals("https") && port != 443) { - builder.port(port); - } - return builder; + return fromUri(request.getURI()).adaptFromForwardedHeaders(request.getHeaders()); } - /** * Create an instance by parsing the "Origin" header of an HTTP request. * @see RFC 6454 @@ -463,18 +409,6 @@ public class UriComponentsBuilder implements Cloneable { return this; } - private void resetHierarchicalComponents() { - this.userInfo = null; - this.host = null; - this.port = null; - this.pathBuilder = new CompositePathComponentBuilder(); - this.queryParams.clear(); - } - - private void resetSchemeSpecificPart() { - this.ssp = null; - } - /** * Set the URI scheme. The given scheme may contain URI template variables, * and may also be {@code null} to clear the scheme of this builder. @@ -724,17 +658,103 @@ public class UriComponentsBuilder implements Cloneable { return this; } + /** + * Adapt this builder's scheme+host+port from the given headers, specifically + * "Forwarded" (RFC 7239, + * or "X-Forwarded-Host", "X-Forwarded-Port", and "X-Forwarded-Proto" if + * "Forwarded" is not found. + * @param headers the HTTP headers to consider + * @return this UriComponentsBuilder + * @since 4.3 + */ + UriComponentsBuilder adaptFromForwardedHeaders(HttpHeaders headers) { + String forwardedHeader = headers.getFirst("Forwarded"); + if (StringUtils.hasText(forwardedHeader)) { + String forwardedToUse = StringUtils.commaDelimitedListToStringArray(forwardedHeader)[0]; + Matcher matcher = FORWARDED_HOST_PATTERN.matcher(forwardedToUse); + if (matcher.find()) { + host(matcher.group(1).trim()); + } + matcher = FORWARDED_PROTO_PATTERN.matcher(forwardedToUse); + if (matcher.find()) { + scheme(matcher.group(1).trim()); + } + } + else { + String hostHeader = headers.getFirst("X-Forwarded-Host"); + if (StringUtils.hasText(hostHeader)) { + String[] hosts = StringUtils.commaDelimitedListToStringArray(hostHeader); + String hostToUse = hosts[0]; + if (hostToUse.contains(":")) { + String[] hostAndPort = StringUtils.split(hostToUse, ":"); + host(hostAndPort[0]); + port(Integer.parseInt(hostAndPort[1])); + } + else { + host(hostToUse); + port(null); + } + } + + String portHeader = headers.getFirst("X-Forwarded-Port"); + if (StringUtils.hasText(portHeader)) { + String[] ports = StringUtils.commaDelimitedListToStringArray(portHeader); + port(Integer.parseInt(ports[0])); + } + + String protocolHeader = headers.getFirst("X-Forwarded-Proto"); + if (StringUtils.hasText(protocolHeader)) { + String[] protocols = StringUtils.commaDelimitedListToStringArray(protocolHeader); + scheme(protocols[0]); + } + } + + if ((this.scheme.equals("http") && "80".equals(this.port)) || + (this.scheme.equals("https") && "443".equals(this.port))) { + this.port = null; + } + + return this; + } + + private void resetHierarchicalComponents() { + this.userInfo = null; + this.host = null; + this.port = null; + this.pathBuilder = new CompositePathComponentBuilder(); + this.queryParams.clear(); + } + + private void resetSchemeSpecificPart() { + this.ssp = null; + } + + + /** + * Public declaration of Object's {@code clone()} method. + * Delegates to {@link #cloneBuilder()}. + * @see Object#clone() + */ @Override public Object clone() { + return cloneBuilder(); + } + + /** + * Clone this {@code UriComponentsBuilder}. + * @return the cloned {@code UriComponentsBuilder} object + * @since 4.3 + */ + public UriComponentsBuilder cloneBuilder() { return new UriComponentsBuilder(this); } - private interface PathComponentBuilder extends Cloneable { + private interface PathComponentBuilder { PathComponent build(); - Object clone(); + PathComponentBuilder cloneBuilder(); } @@ -810,10 +830,10 @@ public class UriComponentsBuilder implements Cloneable { } @Override - public Object clone() { + public CompositePathComponentBuilder cloneBuilder() { CompositePathComponentBuilder compositeBuilder = new CompositePathComponentBuilder(); for (PathComponentBuilder builder : this.builders) { - compositeBuilder.builders.add((PathComponentBuilder) builder.clone()); + compositeBuilder.builders.add(builder.cloneBuilder()); } return compositeBuilder; } @@ -852,7 +872,7 @@ public class UriComponentsBuilder implements Cloneable { } @Override - public Object clone() { + public FullPathComponentBuilder cloneBuilder() { FullPathComponentBuilder builder = new FullPathComponentBuilder(); builder.append(this.path.toString()); return builder; @@ -879,7 +899,7 @@ public class UriComponentsBuilder implements Cloneable { } @Override - public Object clone() { + public PathSegmentComponentBuilder cloneBuilder() { PathSegmentComponentBuilder builder = new PathSegmentComponentBuilder(); builder.pathSegments.addAll(this.pathSegments); return builder; diff --git a/spring-web/src/main/java/org/springframework/web/util/WebUtils.java b/spring-web/src/main/java/org/springframework/web/util/WebUtils.java index 76c5dd042e2..6d8062fbdc7 100644 --- a/spring-web/src/main/java/org/springframework/web/util/WebUtils.java +++ b/spring-web/src/main/java/org/springframework/web/util/WebUtils.java @@ -34,6 +34,7 @@ import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpSession; import org.springframework.http.HttpRequest; +import org.springframework.http.server.ServletServerHttpRequest; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.util.LinkedMultiValueMap; @@ -811,18 +812,31 @@ public abstract class WebUtils { if (origin == null) { return true; } - UriComponents actualUrl = UriComponentsBuilder.fromHttpRequest(request).build(); + UriComponentsBuilder urlBuilder; + if (request instanceof ServletServerHttpRequest) { + // Build more efficiently if we can: we only need scheme, host, port for origin comparison + HttpServletRequest servletRequest = ((ServletServerHttpRequest) request).getServletRequest(); + urlBuilder = new UriComponentsBuilder(). + scheme(servletRequest.getScheme()). + host(servletRequest.getServerName()). + port(servletRequest.getServerPort()). + adaptFromForwardedHeaders(request.getHeaders()); + } + else { + urlBuilder = UriComponentsBuilder.fromHttpRequest(request); + } + UriComponents actualUrl = urlBuilder.build(); UriComponents originUrl = UriComponentsBuilder.fromOriginHeader(origin).build(); return (actualUrl.getHost().equals(originUrl.getHost()) && getPort(actualUrl) == getPort(originUrl)); } - private static int getPort(UriComponents component) { - int port = component.getPort(); + private static int getPort(UriComponents uri) { + int port = uri.getPort(); if (port == -1) { - if ("http".equals(component.getScheme()) || "ws".equals(component.getScheme())) { + if ("http".equals(uri.getScheme()) || "ws".equals(uri.getScheme())) { port = 80; } - else if ("https".equals(component.getScheme()) || "wss".equals(component.getScheme())) { + else if ("https".equals(uri.getScheme()) || "wss".equals(uri.getScheme())) { port = 443; } } diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/MvcUriComponentsBuilder.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/MvcUriComponentsBuilder.java index e8ff7ee93ad..25018ff9e63 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/MvcUriComponentsBuilder.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/MvcUriComponentsBuilder.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2015 the original author or authors. + * Copyright 2002-2016 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -406,7 +406,7 @@ public class MvcUriComponentsBuilder { private static UriComponentsBuilder getBaseUrlToUse(UriComponentsBuilder baseUrl) { if (baseUrl != null) { - return (UriComponentsBuilder) baseUrl.clone(); + return baseUrl.cloneBuilder(); } else { return ServletUriComponentsBuilder.fromCurrentServletMapping(); diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/support/ServletUriComponentsBuilder.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/support/ServletUriComponentsBuilder.java index 74c804f40fc..07abc64b453 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/support/ServletUriComponentsBuilder.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/support/ServletUriComponentsBuilder.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2015 the original author or authors. + * Copyright 2002-2016 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -223,7 +223,7 @@ public class ServletUriComponentsBuilder extends UriComponentsBuilder { } @Override - public Object clone() { + public ServletUriComponentsBuilder cloneBuilder() { return new ServletUriComponentsBuilder(this); }