diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/support/ServletUriComponentsBuilder.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/support/ServletUriComponentsBuilder.java index 037b7382db7..5ad1d24715d 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/support/ServletUriComponentsBuilder.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/support/ServletUriComponentsBuilder.java @@ -35,7 +35,7 @@ import org.springframework.web.util.WebUtils; */ public class ServletUriComponentsBuilder extends UriComponentsBuilder { - private String servletRequestURI; + private String originalPath; /** @@ -86,7 +86,6 @@ public class ServletUriComponentsBuilder extends UriComponentsBuilder { */ public static ServletUriComponentsBuilder fromRequestUri(HttpServletRequest request) { ServletUriComponentsBuilder builder = fromRequest(request); - builder.pathFromRequest(request); builder.replaceQuery(null); return builder; } @@ -99,6 +98,7 @@ public class ServletUriComponentsBuilder extends UriComponentsBuilder { String scheme = request.getScheme(); String host = request.getServerName(); int port = request.getServerPort(); + String path = request.getRequestURI(); String hostHeader = request.getHeader("X-Forwarded-Host"); if (StringUtils.hasText(hostHeader)) { @@ -125,13 +125,18 @@ public class ServletUriComponentsBuilder extends UriComponentsBuilder { scheme = protocolHeader; } + String prefix = request.getHeader("X-Forwarded-Prefix"); + if (StringUtils.hasText(prefix)) { + path = prefix + path; + } + ServletUriComponentsBuilder builder = new ServletUriComponentsBuilder(); builder.scheme(scheme); builder.host(host); if (scheme.equals("http") && port != 80 || scheme.equals("https") && port != 443) { builder.port(port); } - builder.pathFromRequest(request); + builder.initPath(path); builder.query(request.getQueryString()); return builder; } @@ -180,9 +185,9 @@ public class ServletUriComponentsBuilder extends UriComponentsBuilder { return servletRequest; } - private void pathFromRequest(HttpServletRequest request) { - this.servletRequestURI = request.getRequestURI(); - replacePath(request.getRequestURI()); + private void initPath(String path) { + this.originalPath = path; + replacePath(path); } /** @@ -190,27 +195,27 @@ public class ServletUriComponentsBuilder extends UriComponentsBuilder { * requestURI}. This method must be invoked before any calls to {@link #path(String)} * or {@link #pathSegment(String...)}. *
-	 * 	// GET http://foo.com/rest/books/6.json
 	 *
-	 *	ServletUriComponentsBuilder builder = ServletUriComponentsBuilder.fromRequestUri(this.request);
-	 *	String ext = builder.removePathExtension();
-	 *	String uri = builder.path("/pages/1.{ext}").buildAndExpand(ext).toUriString();
+	 * GET http://foo.com/rest/books/6.json
 	 *
-	 * 	assertEquals("http://foo.com/rest/books/6/pages/1.json", result);
+	 * ServletUriComponentsBuilder builder = ServletUriComponentsBuilder.fromRequestUri(this.request);
+	 * String ext = builder.removePathExtension();
+	 * String uri = builder.path("/pages/1.{ext}").buildAndExpand(ext).toUriString();
+	 * assertEquals("http://foo.com/rest/books/6/pages/1.json", result);
 	 * 
* @return the removed path extension for possible re-use, or {@code null} * @since 4.0 */ public String removePathExtension() { String extension = null; - if (this.servletRequestURI != null) { - String filename = WebUtils.extractFullFilenameFromUrlPath(this.servletRequestURI); + if (this.originalPath != null) { + String filename = WebUtils.extractFullFilenameFromUrlPath(this.originalPath); extension = StringUtils.getFilenameExtension(filename); if (!StringUtils.isEmpty(extension)) { - int end = this.servletRequestURI.length() - (extension.length() + 1); - replacePath(this.servletRequestURI.substring(0, end)); + int end = this.originalPath.length() - (extension.length() + 1); + replacePath(this.originalPath.substring(0, end)); } - this.servletRequestURI = null; + this.originalPath = null; } return extension; } diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/support/ServletUriComponentsBuilderTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/support/ServletUriComponentsBuilderTests.java index b3da135eff2..9953bc79e50 100644 --- a/spring-webmvc/src/test/java/org/springframework/web/servlet/support/ServletUriComponentsBuilderTests.java +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/support/ServletUriComponentsBuilderTests.java @@ -150,6 +150,24 @@ public class ServletUriComponentsBuilderTests { assertEquals("should have used the default port of the forwarded request", -1, result.getPort()); } + @Test + public void fromRequestWithForwardedPrefix() { + this.request.setRequestURI("/bar"); + this.request.addHeader("X-Forwarded-Prefix", "/foo"); + UriComponents result = ServletUriComponentsBuilder.fromRequest(request).build(); + + assertEquals("http://localhost/foo/bar", result.toUriString()); + } + + @Test + public void fromRequestWithForwardedPrefixTrailingSlash() { + this.request.setRequestURI("/bar"); + this.request.addHeader("X-Forwarded-Prefix", "/foo/"); + UriComponents result = ServletUriComponentsBuilder.fromRequest(request).build(); + + assertEquals("http://localhost/foo/bar", result.toUriString()); + } + @Test public void fromContextPath() { request.setRequestURI("/mvc-showcase/data/param");