diff --git a/spring-web/src/main/java/org/springframework/web/filter/ForwardedHeaderFilter.java b/spring-web/src/main/java/org/springframework/web/filter/ForwardedHeaderFilter.java
index 6c691b2f711..1bc1ef596a9 100644
--- a/spring-web/src/main/java/org/springframework/web/filter/ForwardedHeaderFilter.java
+++ b/spring-web/src/main/java/org/springframework/web/filter/ForwardedHeaderFilter.java
@@ -40,20 +40,25 @@ import org.springframework.web.util.UriComponentsBuilder;
import org.springframework.web.util.UrlPathHelper;
/**
- * Filter that wraps the request and response in order to override its
+ * Extract values from "Forwarded" and "X-Forwarded-*" headers in order to wrap
+ * and override the following from the request and response:
* {@link HttpServletRequest#getServerName() getServerName()},
* {@link HttpServletRequest#getServerPort() getServerPort()},
* {@link HttpServletRequest#getScheme() getScheme()},
- * {@link HttpServletRequest#isSecure() isSecure()},
- * {@link HttpServletResponse#sendRedirect(String) sendRedirect(String)},
- * methods with values derived from "Forwarded" or "X-Forwarded-*"
- * headers. In effect the wrapped request and response reflects the
- * client-originated protocol and address.
+ * {@link HttpServletRequest#isSecure() isSecure()}, and
+ * {@link HttpServletResponse#sendRedirect(String) sendRedirect(String)}.
+ * In effect the wrapped request and response reflect the client-originated
+ * protocol and address.
+ *
+ *
Note: This filter can also be used in a
+ * {@link #setRemoveOnly removeOnly} mode where "Forwarded" and "X-Forwarded-*"
+ * headers are only eliminated without being used.
*
* @author Rossen Stoyanchev
* @author EddĂș MelĂ©ndez
* @author Rob Winch
* @since 4.3
+ * @see https://tools.ietf.org/html/rfc7239
*/
public class ForwardedHeaderFilter extends OncePerRequestFilter {
@@ -71,6 +76,8 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter {
private final UrlPathHelper pathHelper;
+ private boolean removeOnly;
+
public ForwardedHeaderFilter() {
this.pathHelper = new UrlPathHelper();
@@ -79,6 +86,17 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter {
}
+ /**
+ * Enables mode in which any "Forwarded" or "X-Forwarded-*" headers are
+ * removed only and the information in them ignored.
+ * @param removeOnly whether to discard and ingore forwarded headers
+ * @since 4.3.9
+ */
+ public void setRemoveOnly(boolean removeOnly) {
+ this.removeOnly = removeOnly;
+ }
+
+
@Override
protected boolean shouldNotFilter(HttpServletRequest request) throws ServletException {
Enumeration names = request.getHeaderNames();
@@ -105,13 +123,67 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter {
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response,
FilterChain filterChain) throws ServletException, IOException {
- ForwardedHeaderRequestWrapper wrappedRequest = new ForwardedHeaderRequestWrapper(request, this.pathHelper);
- ForwardedHeaderResponseWrapper wrappedResponse = new ForwardedHeaderResponseWrapper(response, wrappedRequest);
- filterChain.doFilter(wrappedRequest, wrappedResponse);
+ if (this.removeOnly) {
+ ForwardedHeaderRemovingRequest theRequest = new ForwardedHeaderRemovingRequest(request);
+ filterChain.doFilter(theRequest, response);
+ }
+ else {
+ HttpServletRequest theRequest = new ForwardedHeaderExtractingRequest(request, this.pathHelper);
+ HttpServletResponse theResponse = new ForwardedHeaderExtractingResponse(response, theRequest);
+ filterChain.doFilter(theRequest, theResponse);
+ }
}
- private static class ForwardedHeaderRequestWrapper extends HttpServletRequestWrapper {
+ /**
+ * Hide "Forwarded" or "X-Forwarded-*" headers.
+ */
+ private static class ForwardedHeaderRemovingRequest extends HttpServletRequestWrapper {
+
+ private final Map> headers;
+
+
+ public ForwardedHeaderRemovingRequest(HttpServletRequest request) {
+ super(request);
+ this.headers = initHeaders(request);
+ }
+
+ private static Map> initHeaders(HttpServletRequest request) {
+ Map> headers = new LinkedCaseInsensitiveMap<>(Locale.ENGLISH);
+ Enumeration names = request.getHeaderNames();
+ while (names.hasMoreElements()) {
+ String name = names.nextElement();
+ if (!FORWARDED_HEADER_NAMES.contains(name)) {
+ headers.put(name, Collections.list(request.getHeaders(name)));
+ }
+ }
+ return headers;
+ }
+
+ // Override header accessors to not expose forwarded headers
+
+ @Override
+ public String getHeader(String name) {
+ List value = this.headers.get(name);
+ return (CollectionUtils.isEmpty(value) ? null : value.get(0));
+ }
+
+ @Override
+ public Enumeration getHeaders(String name) {
+ List value = this.headers.get(name);
+ return (Collections.enumeration(value != null ? value : Collections.emptySet()));
+ }
+
+ @Override
+ public Enumeration getHeaderNames() {
+ return Collections.enumeration(this.headers.keySet());
+ }
+ }
+
+ /**
+ * Extract and use "Forwarded" or "X-Forwarded-*" headers.
+ */
+ private static class ForwardedHeaderExtractingRequest extends ForwardedHeaderRemovingRequest {
private final String scheme;
@@ -127,9 +199,8 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter {
private final String requestUrl;
- private final Map> headers;
- public ForwardedHeaderRequestWrapper(HttpServletRequest request, UrlPathHelper pathHelper) {
+ public ForwardedHeaderExtractingRequest(HttpServletRequest request, UrlPathHelper pathHelper) {
super(request);
HttpRequest httpRequest = new ServletServerHttpRequest(request);
@@ -145,7 +216,6 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter {
this.contextPath = (prefix != null ? prefix : request.getContextPath());
this.requestUri = this.contextPath + pathHelper.getPathWithinApplication(request);
this.requestUrl = this.scheme + "://" + this.host + (port == -1 ? "" : ":" + port) + this.requestUri;
- this.headers = initHeaders(request);
}
private static String getForwardedPrefix(HttpServletRequest request) {
@@ -165,21 +235,6 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter {
return prefix;
}
- /**
- * Copy the headers excluding any {@link #FORWARDED_HEADER_NAMES}.
- */
- private static Map> initHeaders(HttpServletRequest request) {
- Map> headers = new LinkedCaseInsensitiveMap<>(Locale.ENGLISH);
- Enumeration names = request.getHeaderNames();
- while (names.hasMoreElements()) {
- String name = names.nextElement();
- if (!FORWARDED_HEADER_NAMES.contains(name)) {
- headers.put(name, Collections.list(request.getHeaders(name)));
- }
- }
- return headers;
- }
-
@Override
public String getScheme() {
return this.scheme;
@@ -214,35 +269,18 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter {
public StringBuffer getRequestURL() {
return new StringBuffer(this.requestUrl);
}
-
- // Override header accessors to not expose forwarded headers
-
- @Override
- public String getHeader(String name) {
- List value = this.headers.get(name);
- return (CollectionUtils.isEmpty(value) ? null : value.get(0));
- }
-
- @Override
- public Enumeration getHeaders(String name) {
- List value = this.headers.get(name);
- return (Collections.enumeration(value != null ? value : Collections.emptySet()));
- }
-
- @Override
- public Enumeration getHeaderNames() {
- return Collections.enumeration(this.headers.keySet());
- }
}
- private static class ForwardedHeaderResponseWrapper extends HttpServletResponseWrapper {
+ private static class ForwardedHeaderExtractingResponse extends HttpServletResponseWrapper {
private static final String FOLDER_SEPARATOR = "/";
+
private final HttpServletRequest request;
- public ForwardedHeaderResponseWrapper(HttpServletResponse response, HttpServletRequest request) {
+
+ public ForwardedHeaderExtractingResponse(HttpServletResponse response, HttpServletRequest request) {
super(response);
this.request = request;
}
diff --git a/spring-web/src/test/java/org/springframework/web/filter/ForwardedHeaderFilterTests.java b/spring-web/src/test/java/org/springframework/web/filter/ForwardedHeaderFilterTests.java
index 6fe0ee2d55a..2af2f108e37 100644
--- a/spring-web/src/test/java/org/springframework/web/filter/ForwardedHeaderFilterTests.java
+++ b/spring-web/src/test/java/org/springframework/web/filter/ForwardedHeaderFilterTests.java
@@ -32,7 +32,10 @@ import org.springframework.mock.web.test.MockFilterChain;
import org.springframework.mock.web.test.MockHttpServletRequest;
import org.springframework.mock.web.test.MockHttpServletResponse;
-import static org.junit.Assert.*;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
/**
* Unit tests for {@link ForwardedHeaderFilter}.
@@ -239,6 +242,30 @@ public class ForwardedHeaderFilterTests {
assertEquals("bar", actual.getHeader("foo"));
}
+ @Test
+ public void forwardedRequestInRemoveOnlyMode() throws Exception {
+ this.request.setRequestURI("/mvc-showcase");
+ this.request.addHeader(X_FORWARDED_PROTO, "https");
+ this.request.addHeader(X_FORWARDED_HOST, "84.198.58.199");
+ this.request.addHeader(X_FORWARDED_PORT, "443");
+ this.request.addHeader("foo", "bar");
+
+ this.filter.setRemoveOnly(true);
+ this.filter.doFilter(this.request, new MockHttpServletResponse(), this.filterChain);
+ HttpServletRequest actual = (HttpServletRequest) this.filterChain.getRequest();
+
+ assertEquals("http://localhost/mvc-showcase", actual.getRequestURL().toString());
+ assertEquals("http", actual.getScheme());
+ assertEquals("localhost", actual.getServerName());
+ assertEquals(80, actual.getServerPort());
+ assertFalse(actual.isSecure());
+
+ assertNull(actual.getHeader(X_FORWARDED_PROTO));
+ assertNull(actual.getHeader(X_FORWARDED_HOST));
+ assertNull(actual.getHeader(X_FORWARDED_PORT));
+ assertEquals("bar", actual.getHeader("foo"));
+ }
+
@Test
public void requestUriWithForwardedPrefix() throws Exception {
this.request.addHeader(X_FORWARDED_PREFIX, "/prefix");