From 36e2dd90a7840fe82d30016ec5e7adc0c054eb7a Mon Sep 17 00:00:00 2001 From: Rossen Stoyanchev Date: Wed, 2 Mar 2016 18:37:22 -0500 Subject: [PATCH] Support contextPath override in ForwardedHeaderFilter Issue: SPR-13614 --- .../web/filter/ForwardedHeaderFilter.java | 94 ++++++++++++- .../filter/ForwardedHeaderFilterTests.java | 126 ++++++++++++++++-- 2 files changed, 203 insertions(+), 17 deletions(-) diff --git a/spring-web/src/main/java/org/springframework/web/filter/ForwardedHeaderFilter.java b/spring-web/src/main/java/org/springframework/web/filter/ForwardedHeaderFilter.java index e0af31c56d6..5881c36fbe3 100644 --- a/spring-web/src/main/java/org/springframework/web/filter/ForwardedHeaderFilter.java +++ b/spring-web/src/main/java/org/springframework/web/filter/ForwardedHeaderFilter.java @@ -31,9 +31,11 @@ import javax.servlet.http.HttpServletResponse; import org.springframework.http.HttpRequest; import org.springframework.http.server.ServletServerHttpRequest; +import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.web.util.UriComponents; import org.springframework.web.util.UriComponentsBuilder; +import org.springframework.web.util.UrlPathHelper; /** @@ -61,6 +63,28 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter { } + private ContextPathHelper contextPathHelper; + + + + /** + * Configure a contextPath value that will replace the contextPath of + * proxy-forwarded requests. + * + *

This is useful when external clients are not aware of the application + * context path. However a proxy forwards the request to a URL that includes + * a contextPath. + * + * @param contextPath the context path; the given value will be sanitized to + * ensure it starts with a '/' but does not end with one, or if the context + * path is empty (default, root context) it is left as-is. + */ + public void setContextPath(String contextPath) { + Assert.notNull(contextPath, "'contextPath' must not be null"); + this.contextPathHelper = new ContextPathHelper(contextPath); + } + + @Override protected boolean shouldNotFilter(HttpServletRequest request) throws ServletException { Enumeration headerNames = request.getHeaderNames(); @@ -87,7 +111,7 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter { protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { - filterChain.doFilter(new ForwardedHeaderRequestWrapper(request), response); + filterChain.doFilter(new ForwardedHeaderRequestWrapper(request, this.contextPathHelper), response); } @@ -105,12 +129,16 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter { private final int port; + private final String contextPath; + + private final String requestUri; + private final StringBuffer requestUrl; private final Map> headers; - public ForwardedHeaderRequestWrapper(HttpServletRequest request) { + public ForwardedHeaderRequestWrapper(HttpServletRequest request, ContextPathHelper pathHelper) { super(request); HttpRequest httpRequest = new ServletServerHttpRequest(request); @@ -121,7 +149,11 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter { this.secure = "https".equals(scheme); this.host = uriComponents.getHost(); this.port = (port == -1 ? (this.secure ? 443 : 80) : port); - this.requestUrl = initRequestUrl(this.scheme, this.host, port, request.getRequestURI()); + + this.contextPath = (pathHelper != null ? pathHelper.getContextPath(request) : request.getContextPath()); + this.requestUri = (pathHelper != null ? pathHelper.getRequestUri(request) : request.getRequestURI()); + this.requestUrl = initRequestUrl(this.scheme, this.host, port, this.requestUri); + this.headers = initHeaders(request); } @@ -170,6 +202,16 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter { return this.secure; } + @Override + public String getContextPath() { + return this.contextPath; + } + + @Override + public String getRequestURI() { + return this.requestUri; + } + @Override public StringBuffer getRequestURL() { return this.requestUrl; @@ -195,4 +237,50 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter { } } + + private static class ContextPathHelper { + + private final String contextPath; + + private final UrlPathHelper urlPathHelper; + + + public ContextPathHelper(String contextPath) { + Assert.notNull(contextPath); + this.contextPath = sanitizeContextPath(contextPath); + this.urlPathHelper = new UrlPathHelper(); + this.urlPathHelper.setUrlDecode(false); + this.urlPathHelper.setRemoveSemicolonContent(false); + } + + private static String sanitizeContextPath(String contextPath) { + contextPath = contextPath.trim(); + if (contextPath.isEmpty()) { + return contextPath; + } + if (contextPath.equals("/")) { + return "/"; + } + if (contextPath.charAt(0) != '/') { + contextPath = "/" + contextPath; + } + while (contextPath.endsWith("/")) { + contextPath = contextPath.substring(0, contextPath.length() -1); + } + return contextPath; + } + + public String getContextPath(HttpServletRequest request) { + return this.contextPath; + } + + public String getRequestUri(HttpServletRequest request) { + String pathWithinApplication = this.urlPathHelper.getPathWithinApplication(request); + if (this.contextPath.equals("/") && pathWithinApplication.startsWith("/")) { + return pathWithinApplication; + } + return this.contextPath + pathWithinApplication; + } + } + } diff --git a/spring-web/src/test/java/org/springframework/web/filter/ForwardedHeaderFilterTests.java b/spring-web/src/test/java/org/springframework/web/filter/ForwardedHeaderFilterTests.java index 62f20c71bb9..7907c66bd41 100644 --- a/spring-web/src/test/java/org/springframework/web/filter/ForwardedHeaderFilterTests.java +++ b/spring-web/src/test/java/org/springframework/web/filter/ForwardedHeaderFilterTests.java @@ -15,10 +15,12 @@ */ package org.springframework.web.filter; +import java.io.IOException; import javax.servlet.ServletException; import javax.servlet.http.HttpServlet; import javax.servlet.http.HttpServletRequest; +import org.junit.Before; import org.junit.Test; import org.springframework.mock.web.test.MockFilterChain; @@ -38,6 +40,98 @@ public class ForwardedHeaderFilterTests { private final ForwardedHeaderFilter filter = new ForwardedHeaderFilter(); + private MockHttpServletRequest request; + + private MockFilterChain filterChain; + + + @Before + public void setUp() throws Exception { + this.request = new MockHttpServletRequest(); + this.request.setScheme("http"); + this.request.setServerName("localhost"); + this.request.setServerPort(80); + this.filterChain = new MockFilterChain(new HttpServlet() {}); + } + + + @Test(expected = IllegalArgumentException.class) + public void contextPathNull() { + this.filter.setContextPath(null); + } + + @Test + public void contextPathEmpty() throws Exception { + this.filter.setContextPath(""); + assertEquals("", filterAndGetContextPath()); + } + + @Test + public void contextPathWithExtraSpaces() throws Exception { + this.filter.setContextPath(" /foo "); + assertEquals("/foo", filterAndGetContextPath()); + } + + @Test + public void contextPathWithNoLeadingSlash() throws Exception { + this.filter.setContextPath("foo"); + assertEquals("/foo", filterAndGetContextPath()); + } + + @Test + public void contextPathWithTrailingSlash() throws Exception { + this.filter.setContextPath("/foo/bar/"); + assertEquals("/foo/bar", filterAndGetContextPath()); + } + + @Test + public void contextPathWithTrailingSlashes() throws Exception { + this.filter.setContextPath("/foo/bar/baz///"); + assertEquals("/foo/bar/baz", filterAndGetContextPath()); + } + + @Test + public void requestUri() throws Exception { + this.filter.setContextPath("/"); + this.request.setContextPath("/app"); + this.request.setRequestURI("/app/path"); + HttpServletRequest actual = filterAndGetWrappedRequest(); + + assertEquals("/", actual.getContextPath()); + assertEquals("/path", actual.getRequestURI()); + } + + @Test + public void requestUriWithTrailingSlash() throws Exception { + this.filter.setContextPath("/"); + this.request.setContextPath("/app"); + this.request.setRequestURI("/app/path/"); + HttpServletRequest actual = filterAndGetWrappedRequest(); + + assertEquals("/", actual.getContextPath()); + assertEquals("/path/", actual.getRequestURI()); + } + @Test + public void requestUriEqualsContextPath() throws Exception { + this.filter.setContextPath("/"); + this.request.setContextPath("/app"); + this.request.setRequestURI("/app"); + HttpServletRequest actual = filterAndGetWrappedRequest(); + + assertEquals("/", actual.getContextPath()); + assertEquals("/", actual.getRequestURI()); + } + + @Test + public void requestUriRootUrl() throws Exception { + this.filter.setContextPath("/"); + this.request.setContextPath("/app"); + this.request.setRequestURI("/app/"); + HttpServletRequest actual = filterAndGetWrappedRequest(); + + assertEquals("/", actual.getContextPath()); + assertEquals("/", actual.getRequestURI()); + } @Test public void shouldFilter() throws Exception { @@ -54,19 +148,14 @@ public class ForwardedHeaderFilterTests { @Test public void forwardedRequest() throws Exception { - MockHttpServletRequest request = new MockHttpServletRequest(); - request.setScheme("http"); - request.setServerName("localhost"); - request.setServerPort(80); - request.setRequestURI("/mvc-showcase"); - request.addHeader("X-Forwarded-Proto", "https"); - request.addHeader("X-Forwarded-Host", "84.198.58.199"); - request.addHeader("X-Forwarded-Port", "443"); - request.addHeader("foo", "bar"); - - MockFilterChain chain = new MockFilterChain(new HttpServlet() {}); - this.filter.doFilter(request, new MockHttpServletResponse(), chain); - HttpServletRequest actual = (HttpServletRequest) chain.getRequest(); + this.request.setRequestURI("/mvc-showcase"); + this.request.addHeader("X-Forwarded-Proto", "https"); + this.request.addHeader("X-Forwarded-Host", "84.198.58.199"); + this.request.addHeader("X-Forwarded-Port", "443"); + this.request.addHeader("foo", "bar"); + + this.filter.doFilter(this.request, new MockHttpServletResponse(), this.filterChain); + HttpServletRequest actual = (HttpServletRequest) this.filterChain.getRequest(); assertEquals("https://84.198.58.199/mvc-showcase", actual.getRequestURL().toString()); assertEquals("https", actual.getScheme()); @@ -81,11 +170,20 @@ public class ForwardedHeaderFilterTests { } + private String filterAndGetContextPath() throws ServletException, IOException { + return filterAndGetWrappedRequest().getContextPath(); + } + + private HttpServletRequest filterAndGetWrappedRequest() throws ServletException, IOException { + MockHttpServletResponse response = new MockHttpServletResponse(); + this.filter.doFilterInternal(this.request, response, this.filterChain); + return (HttpServletRequest) this.filterChain.getRequest(); + } + private void testShouldFilter(String headerName) throws ServletException { MockHttpServletRequest request = new MockHttpServletRequest(); request.addHeader(headerName, "1"); assertFalse(this.filter.shouldNotFilter(request)); } - }