diff --git a/spring-web/src/main/java/org/springframework/web/filter/ForwardedHeaderFilter.java b/spring-web/src/main/java/org/springframework/web/filter/ForwardedHeaderFilter.java index 6c691b2f711..1bc1ef596a9 100644 --- a/spring-web/src/main/java/org/springframework/web/filter/ForwardedHeaderFilter.java +++ b/spring-web/src/main/java/org/springframework/web/filter/ForwardedHeaderFilter.java @@ -40,20 +40,25 @@ import org.springframework.web.util.UriComponentsBuilder; import org.springframework.web.util.UrlPathHelper; /** - * Filter that wraps the request and response in order to override its + * Extract values from "Forwarded" and "X-Forwarded-*" headers in order to wrap + * and override the following from the request and response: * {@link HttpServletRequest#getServerName() getServerName()}, * {@link HttpServletRequest#getServerPort() getServerPort()}, * {@link HttpServletRequest#getScheme() getScheme()}, - * {@link HttpServletRequest#isSecure() isSecure()}, - * {@link HttpServletResponse#sendRedirect(String) sendRedirect(String)}, - * methods with values derived from "Forwarded" or "X-Forwarded-*" - * headers. In effect the wrapped request and response reflects the - * client-originated protocol and address. + * {@link HttpServletRequest#isSecure() isSecure()}, and + * {@link HttpServletResponse#sendRedirect(String) sendRedirect(String)}. + * In effect the wrapped request and response reflect the client-originated + * protocol and address. + * + *

Note: This filter can also be used in a + * {@link #setRemoveOnly removeOnly} mode where "Forwarded" and "X-Forwarded-*" + * headers are only eliminated without being used. * * @author Rossen Stoyanchev * @author EddĂș MelĂ©ndez * @author Rob Winch * @since 4.3 + * @see https://tools.ietf.org/html/rfc7239 */ public class ForwardedHeaderFilter extends OncePerRequestFilter { @@ -71,6 +76,8 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter { private final UrlPathHelper pathHelper; + private boolean removeOnly; + public ForwardedHeaderFilter() { this.pathHelper = new UrlPathHelper(); @@ -79,6 +86,17 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter { } + /** + * Enables mode in which any "Forwarded" or "X-Forwarded-*" headers are + * removed only and the information in them ignored. + * @param removeOnly whether to discard and ingore forwarded headers + * @since 4.3.9 + */ + public void setRemoveOnly(boolean removeOnly) { + this.removeOnly = removeOnly; + } + + @Override protected boolean shouldNotFilter(HttpServletRequest request) throws ServletException { Enumeration names = request.getHeaderNames(); @@ -105,13 +123,67 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter { protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { - ForwardedHeaderRequestWrapper wrappedRequest = new ForwardedHeaderRequestWrapper(request, this.pathHelper); - ForwardedHeaderResponseWrapper wrappedResponse = new ForwardedHeaderResponseWrapper(response, wrappedRequest); - filterChain.doFilter(wrappedRequest, wrappedResponse); + if (this.removeOnly) { + ForwardedHeaderRemovingRequest theRequest = new ForwardedHeaderRemovingRequest(request); + filterChain.doFilter(theRequest, response); + } + else { + HttpServletRequest theRequest = new ForwardedHeaderExtractingRequest(request, this.pathHelper); + HttpServletResponse theResponse = new ForwardedHeaderExtractingResponse(response, theRequest); + filterChain.doFilter(theRequest, theResponse); + } } - private static class ForwardedHeaderRequestWrapper extends HttpServletRequestWrapper { + /** + * Hide "Forwarded" or "X-Forwarded-*" headers. + */ + private static class ForwardedHeaderRemovingRequest extends HttpServletRequestWrapper { + + private final Map> headers; + + + public ForwardedHeaderRemovingRequest(HttpServletRequest request) { + super(request); + this.headers = initHeaders(request); + } + + private static Map> initHeaders(HttpServletRequest request) { + Map> headers = new LinkedCaseInsensitiveMap<>(Locale.ENGLISH); + Enumeration names = request.getHeaderNames(); + while (names.hasMoreElements()) { + String name = names.nextElement(); + if (!FORWARDED_HEADER_NAMES.contains(name)) { + headers.put(name, Collections.list(request.getHeaders(name))); + } + } + return headers; + } + + // Override header accessors to not expose forwarded headers + + @Override + public String getHeader(String name) { + List value = this.headers.get(name); + return (CollectionUtils.isEmpty(value) ? null : value.get(0)); + } + + @Override + public Enumeration getHeaders(String name) { + List value = this.headers.get(name); + return (Collections.enumeration(value != null ? value : Collections.emptySet())); + } + + @Override + public Enumeration getHeaderNames() { + return Collections.enumeration(this.headers.keySet()); + } + } + + /** + * Extract and use "Forwarded" or "X-Forwarded-*" headers. + */ + private static class ForwardedHeaderExtractingRequest extends ForwardedHeaderRemovingRequest { private final String scheme; @@ -127,9 +199,8 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter { private final String requestUrl; - private final Map> headers; - public ForwardedHeaderRequestWrapper(HttpServletRequest request, UrlPathHelper pathHelper) { + public ForwardedHeaderExtractingRequest(HttpServletRequest request, UrlPathHelper pathHelper) { super(request); HttpRequest httpRequest = new ServletServerHttpRequest(request); @@ -145,7 +216,6 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter { this.contextPath = (prefix != null ? prefix : request.getContextPath()); this.requestUri = this.contextPath + pathHelper.getPathWithinApplication(request); this.requestUrl = this.scheme + "://" + this.host + (port == -1 ? "" : ":" + port) + this.requestUri; - this.headers = initHeaders(request); } private static String getForwardedPrefix(HttpServletRequest request) { @@ -165,21 +235,6 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter { return prefix; } - /** - * Copy the headers excluding any {@link #FORWARDED_HEADER_NAMES}. - */ - private static Map> initHeaders(HttpServletRequest request) { - Map> headers = new LinkedCaseInsensitiveMap<>(Locale.ENGLISH); - Enumeration names = request.getHeaderNames(); - while (names.hasMoreElements()) { - String name = names.nextElement(); - if (!FORWARDED_HEADER_NAMES.contains(name)) { - headers.put(name, Collections.list(request.getHeaders(name))); - } - } - return headers; - } - @Override public String getScheme() { return this.scheme; @@ -214,35 +269,18 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter { public StringBuffer getRequestURL() { return new StringBuffer(this.requestUrl); } - - // Override header accessors to not expose forwarded headers - - @Override - public String getHeader(String name) { - List value = this.headers.get(name); - return (CollectionUtils.isEmpty(value) ? null : value.get(0)); - } - - @Override - public Enumeration getHeaders(String name) { - List value = this.headers.get(name); - return (Collections.enumeration(value != null ? value : Collections.emptySet())); - } - - @Override - public Enumeration getHeaderNames() { - return Collections.enumeration(this.headers.keySet()); - } } - private static class ForwardedHeaderResponseWrapper extends HttpServletResponseWrapper { + private static class ForwardedHeaderExtractingResponse extends HttpServletResponseWrapper { private static final String FOLDER_SEPARATOR = "/"; + private final HttpServletRequest request; - public ForwardedHeaderResponseWrapper(HttpServletResponse response, HttpServletRequest request) { + + public ForwardedHeaderExtractingResponse(HttpServletResponse response, HttpServletRequest request) { super(response); this.request = request; } diff --git a/spring-web/src/test/java/org/springframework/web/filter/ForwardedHeaderFilterTests.java b/spring-web/src/test/java/org/springframework/web/filter/ForwardedHeaderFilterTests.java index 6fe0ee2d55a..2af2f108e37 100644 --- a/spring-web/src/test/java/org/springframework/web/filter/ForwardedHeaderFilterTests.java +++ b/spring-web/src/test/java/org/springframework/web/filter/ForwardedHeaderFilterTests.java @@ -32,7 +32,10 @@ import org.springframework.mock.web.test.MockFilterChain; import org.springframework.mock.web.test.MockHttpServletRequest; import org.springframework.mock.web.test.MockHttpServletResponse; -import static org.junit.Assert.*; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; /** * Unit tests for {@link ForwardedHeaderFilter}. @@ -239,6 +242,30 @@ public class ForwardedHeaderFilterTests { assertEquals("bar", actual.getHeader("foo")); } + @Test + public void forwardedRequestInRemoveOnlyMode() throws Exception { + this.request.setRequestURI("/mvc-showcase"); + this.request.addHeader(X_FORWARDED_PROTO, "https"); + this.request.addHeader(X_FORWARDED_HOST, "84.198.58.199"); + this.request.addHeader(X_FORWARDED_PORT, "443"); + this.request.addHeader("foo", "bar"); + + this.filter.setRemoveOnly(true); + this.filter.doFilter(this.request, new MockHttpServletResponse(), this.filterChain); + HttpServletRequest actual = (HttpServletRequest) this.filterChain.getRequest(); + + assertEquals("http://localhost/mvc-showcase", actual.getRequestURL().toString()); + assertEquals("http", actual.getScheme()); + assertEquals("localhost", actual.getServerName()); + assertEquals(80, actual.getServerPort()); + assertFalse(actual.isSecure()); + + assertNull(actual.getHeader(X_FORWARDED_PROTO)); + assertNull(actual.getHeader(X_FORWARDED_HOST)); + assertNull(actual.getHeader(X_FORWARDED_PORT)); + assertEquals("bar", actual.getHeader("foo")); + } + @Test public void requestUriWithForwardedPrefix() throws Exception { this.request.addHeader(X_FORWARDED_PREFIX, "/prefix");