diff --git a/spring-web/src/main/java/org/springframework/web/filter/ForwardedHeaderFilter.java b/spring-web/src/main/java/org/springframework/web/filter/ForwardedHeaderFilter.java
index e0af31c56d6..5881c36fbe3 100644
--- a/spring-web/src/main/java/org/springframework/web/filter/ForwardedHeaderFilter.java
+++ b/spring-web/src/main/java/org/springframework/web/filter/ForwardedHeaderFilter.java
@@ -31,9 +31,11 @@ import javax.servlet.http.HttpServletResponse;
import org.springframework.http.HttpRequest;
import org.springframework.http.server.ServletServerHttpRequest;
+import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.web.util.UriComponents;
import org.springframework.web.util.UriComponentsBuilder;
+import org.springframework.web.util.UrlPathHelper;
/**
@@ -61,6 +63,28 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter {
}
+ private ContextPathHelper contextPathHelper;
+
+
+
+ /**
+ * Configure a contextPath value that will replace the contextPath of
+ * proxy-forwarded requests.
+ *
+ *
This is useful when external clients are not aware of the application
+ * context path. However a proxy forwards the request to a URL that includes
+ * a contextPath.
+ *
+ * @param contextPath the context path; the given value will be sanitized to
+ * ensure it starts with a '/' but does not end with one, or if the context
+ * path is empty (default, root context) it is left as-is.
+ */
+ public void setContextPath(String contextPath) {
+ Assert.notNull(contextPath, "'contextPath' must not be null");
+ this.contextPathHelper = new ContextPathHelper(contextPath);
+ }
+
+
@Override
protected boolean shouldNotFilter(HttpServletRequest request) throws ServletException {
Enumeration headerNames = request.getHeaderNames();
@@ -87,7 +111,7 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter {
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response,
FilterChain filterChain) throws ServletException, IOException {
- filterChain.doFilter(new ForwardedHeaderRequestWrapper(request), response);
+ filterChain.doFilter(new ForwardedHeaderRequestWrapper(request, this.contextPathHelper), response);
}
@@ -105,12 +129,16 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter {
private final int port;
+ private final String contextPath;
+
+ private final String requestUri;
+
private final StringBuffer requestUrl;
private final Map> headers;
- public ForwardedHeaderRequestWrapper(HttpServletRequest request) {
+ public ForwardedHeaderRequestWrapper(HttpServletRequest request, ContextPathHelper pathHelper) {
super(request);
HttpRequest httpRequest = new ServletServerHttpRequest(request);
@@ -121,7 +149,11 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter {
this.secure = "https".equals(scheme);
this.host = uriComponents.getHost();
this.port = (port == -1 ? (this.secure ? 443 : 80) : port);
- this.requestUrl = initRequestUrl(this.scheme, this.host, port, request.getRequestURI());
+
+ this.contextPath = (pathHelper != null ? pathHelper.getContextPath(request) : request.getContextPath());
+ this.requestUri = (pathHelper != null ? pathHelper.getRequestUri(request) : request.getRequestURI());
+ this.requestUrl = initRequestUrl(this.scheme, this.host, port, this.requestUri);
+
this.headers = initHeaders(request);
}
@@ -170,6 +202,16 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter {
return this.secure;
}
+ @Override
+ public String getContextPath() {
+ return this.contextPath;
+ }
+
+ @Override
+ public String getRequestURI() {
+ return this.requestUri;
+ }
+
@Override
public StringBuffer getRequestURL() {
return this.requestUrl;
@@ -195,4 +237,50 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter {
}
}
+
+ private static class ContextPathHelper {
+
+ private final String contextPath;
+
+ private final UrlPathHelper urlPathHelper;
+
+
+ public ContextPathHelper(String contextPath) {
+ Assert.notNull(contextPath);
+ this.contextPath = sanitizeContextPath(contextPath);
+ this.urlPathHelper = new UrlPathHelper();
+ this.urlPathHelper.setUrlDecode(false);
+ this.urlPathHelper.setRemoveSemicolonContent(false);
+ }
+
+ private static String sanitizeContextPath(String contextPath) {
+ contextPath = contextPath.trim();
+ if (contextPath.isEmpty()) {
+ return contextPath;
+ }
+ if (contextPath.equals("/")) {
+ return "/";
+ }
+ if (contextPath.charAt(0) != '/') {
+ contextPath = "/" + contextPath;
+ }
+ while (contextPath.endsWith("/")) {
+ contextPath = contextPath.substring(0, contextPath.length() -1);
+ }
+ return contextPath;
+ }
+
+ public String getContextPath(HttpServletRequest request) {
+ return this.contextPath;
+ }
+
+ public String getRequestUri(HttpServletRequest request) {
+ String pathWithinApplication = this.urlPathHelper.getPathWithinApplication(request);
+ if (this.contextPath.equals("/") && pathWithinApplication.startsWith("/")) {
+ return pathWithinApplication;
+ }
+ return this.contextPath + pathWithinApplication;
+ }
+ }
+
}
diff --git a/spring-web/src/test/java/org/springframework/web/filter/ForwardedHeaderFilterTests.java b/spring-web/src/test/java/org/springframework/web/filter/ForwardedHeaderFilterTests.java
index 62f20c71bb9..7907c66bd41 100644
--- a/spring-web/src/test/java/org/springframework/web/filter/ForwardedHeaderFilterTests.java
+++ b/spring-web/src/test/java/org/springframework/web/filter/ForwardedHeaderFilterTests.java
@@ -15,10 +15,12 @@
*/
package org.springframework.web.filter;
+import java.io.IOException;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
+import org.junit.Before;
import org.junit.Test;
import org.springframework.mock.web.test.MockFilterChain;
@@ -38,6 +40,98 @@ public class ForwardedHeaderFilterTests {
private final ForwardedHeaderFilter filter = new ForwardedHeaderFilter();
+ private MockHttpServletRequest request;
+
+ private MockFilterChain filterChain;
+
+
+ @Before
+ public void setUp() throws Exception {
+ this.request = new MockHttpServletRequest();
+ this.request.setScheme("http");
+ this.request.setServerName("localhost");
+ this.request.setServerPort(80);
+ this.filterChain = new MockFilterChain(new HttpServlet() {});
+ }
+
+
+ @Test(expected = IllegalArgumentException.class)
+ public void contextPathNull() {
+ this.filter.setContextPath(null);
+ }
+
+ @Test
+ public void contextPathEmpty() throws Exception {
+ this.filter.setContextPath("");
+ assertEquals("", filterAndGetContextPath());
+ }
+
+ @Test
+ public void contextPathWithExtraSpaces() throws Exception {
+ this.filter.setContextPath(" /foo ");
+ assertEquals("/foo", filterAndGetContextPath());
+ }
+
+ @Test
+ public void contextPathWithNoLeadingSlash() throws Exception {
+ this.filter.setContextPath("foo");
+ assertEquals("/foo", filterAndGetContextPath());
+ }
+
+ @Test
+ public void contextPathWithTrailingSlash() throws Exception {
+ this.filter.setContextPath("/foo/bar/");
+ assertEquals("/foo/bar", filterAndGetContextPath());
+ }
+
+ @Test
+ public void contextPathWithTrailingSlashes() throws Exception {
+ this.filter.setContextPath("/foo/bar/baz///");
+ assertEquals("/foo/bar/baz", filterAndGetContextPath());
+ }
+
+ @Test
+ public void requestUri() throws Exception {
+ this.filter.setContextPath("/");
+ this.request.setContextPath("/app");
+ this.request.setRequestURI("/app/path");
+ HttpServletRequest actual = filterAndGetWrappedRequest();
+
+ assertEquals("/", actual.getContextPath());
+ assertEquals("/path", actual.getRequestURI());
+ }
+
+ @Test
+ public void requestUriWithTrailingSlash() throws Exception {
+ this.filter.setContextPath("/");
+ this.request.setContextPath("/app");
+ this.request.setRequestURI("/app/path/");
+ HttpServletRequest actual = filterAndGetWrappedRequest();
+
+ assertEquals("/", actual.getContextPath());
+ assertEquals("/path/", actual.getRequestURI());
+ }
+ @Test
+ public void requestUriEqualsContextPath() throws Exception {
+ this.filter.setContextPath("/");
+ this.request.setContextPath("/app");
+ this.request.setRequestURI("/app");
+ HttpServletRequest actual = filterAndGetWrappedRequest();
+
+ assertEquals("/", actual.getContextPath());
+ assertEquals("/", actual.getRequestURI());
+ }
+
+ @Test
+ public void requestUriRootUrl() throws Exception {
+ this.filter.setContextPath("/");
+ this.request.setContextPath("/app");
+ this.request.setRequestURI("/app/");
+ HttpServletRequest actual = filterAndGetWrappedRequest();
+
+ assertEquals("/", actual.getContextPath());
+ assertEquals("/", actual.getRequestURI());
+ }
@Test
public void shouldFilter() throws Exception {
@@ -54,19 +148,14 @@ public class ForwardedHeaderFilterTests {
@Test
public void forwardedRequest() throws Exception {
- MockHttpServletRequest request = new MockHttpServletRequest();
- request.setScheme("http");
- request.setServerName("localhost");
- request.setServerPort(80);
- request.setRequestURI("/mvc-showcase");
- request.addHeader("X-Forwarded-Proto", "https");
- request.addHeader("X-Forwarded-Host", "84.198.58.199");
- request.addHeader("X-Forwarded-Port", "443");
- request.addHeader("foo", "bar");
-
- MockFilterChain chain = new MockFilterChain(new HttpServlet() {});
- this.filter.doFilter(request, new MockHttpServletResponse(), chain);
- HttpServletRequest actual = (HttpServletRequest) chain.getRequest();
+ this.request.setRequestURI("/mvc-showcase");
+ this.request.addHeader("X-Forwarded-Proto", "https");
+ this.request.addHeader("X-Forwarded-Host", "84.198.58.199");
+ this.request.addHeader("X-Forwarded-Port", "443");
+ this.request.addHeader("foo", "bar");
+
+ this.filter.doFilter(this.request, new MockHttpServletResponse(), this.filterChain);
+ HttpServletRequest actual = (HttpServletRequest) this.filterChain.getRequest();
assertEquals("https://84.198.58.199/mvc-showcase", actual.getRequestURL().toString());
assertEquals("https", actual.getScheme());
@@ -81,11 +170,20 @@ public class ForwardedHeaderFilterTests {
}
+ private String filterAndGetContextPath() throws ServletException, IOException {
+ return filterAndGetWrappedRequest().getContextPath();
+ }
+
+ private HttpServletRequest filterAndGetWrappedRequest() throws ServletException, IOException {
+ MockHttpServletResponse response = new MockHttpServletResponse();
+ this.filter.doFilterInternal(this.request, response, this.filterChain);
+ return (HttpServletRequest) this.filterChain.getRequest();
+ }
+
private void testShouldFilter(String headerName) throws ServletException {
MockHttpServletRequest request = new MockHttpServletRequest();
request.addHeader(headerName, "1");
assertFalse(this.filter.shouldNotFilter(request));
}
-
}