From feeec344e530aeb32c0d9d36588b7fb821ac0e9e Mon Sep 17 00:00:00 2001 From: Rossen Stoyanchev Date: Tue, 23 Oct 2018 11:52:09 -0400 Subject: [PATCH] ForwardedHeaderFilter works with Servlet FORWARD Issue: SPR-16983 --- .../web/filter/ForwardedHeaderFilter.java | 140 ++++++++++++++---- .../filter/ForwardedHeaderFilterTests.java | 24 +++ 2 files changed, 135 insertions(+), 29 deletions(-) diff --git a/spring-web/src/main/java/org/springframework/web/filter/ForwardedHeaderFilter.java b/spring-web/src/main/java/org/springframework/web/filter/ForwardedHeaderFilter.java index 0af5bb3670a..a3b4f9ec191 100644 --- a/spring-web/src/main/java/org/springframework/web/filter/ForwardedHeaderFilter.java +++ b/spring-web/src/main/java/org/springframework/web/filter/ForwardedHeaderFilter.java @@ -23,6 +23,7 @@ import java.util.List; import java.util.Locale; import java.util.Map; import java.util.Set; +import java.util.function.Supplier; import javax.servlet.FilterChain; import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; @@ -219,11 +220,7 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter { private final int port; - private final String contextPath; - - private final String requestUri; - - private final String requestUrl; + private final ForwardedPrefixExtractor forwardedPrefixExtractor; ForwardedHeaderExtractingRequest(HttpServletRequest request, UrlPathHelper pathHelper) { @@ -238,28 +235,9 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter { this.host = uriComponents.getHost(); this.port = (port == -1 ? (this.secure ? 443 : 80) : port); - String prefix = getForwardedPrefix(request); - this.contextPath = (prefix != null ? prefix : request.getContextPath()); - this.requestUri = this.contextPath + pathHelper.getPathWithinApplication(request); - this.requestUrl = this.scheme + "://" + this.host + (port == -1 ? "" : ":" + port) + this.requestUri; - } - - @Nullable - private static String getForwardedPrefix(HttpServletRequest request) { - String prefix = null; - Enumeration names = request.getHeaderNames(); - while (names.hasMoreElements()) { - String name = names.nextElement(); - if ("X-Forwarded-Prefix".equalsIgnoreCase(name)) { - prefix = request.getHeader(name); - } - } - if (prefix != null) { - while (prefix.endsWith("/")) { - prefix = prefix.substring(0, prefix.length() - 1); - } - } - return prefix; + String baseUrl = this.scheme + "://" + this.host + (port == -1 ? "" : ":" + port); + Supplier delegateRequest = () -> (HttpServletRequest) getRequest(); + this.forwardedPrefixExtractor = new ForwardedPrefixExtractor(delegateRequest, pathHelper, baseUrl); } @@ -287,18 +265,122 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter { @Override public String getContextPath() { - return this.contextPath; + return this.forwardedPrefixExtractor.getContextPath(); } @Override public String getRequestURI() { - return this.requestUri; + return this.forwardedPrefixExtractor.getRequestUri(); } @Override public StringBuffer getRequestURL() { + return this.forwardedPrefixExtractor.getRequestUrl(); + } + } + + + /** + * Responsible for the contextPath, requestURI, and requestURL with forwarded + * headers in mind, and also taking into account changes to the path of the + * underlying delegate request (e.g. on a Servlet FORWARD). + */ + private static class ForwardedPrefixExtractor { + + private final Supplier delegate; + + private final UrlPathHelper pathHelper; + + private final String baseUrl; + + private String actualRequestUri; + + @Nullable + private final String forwardedPrefix; + + @Nullable + private String requestUri; + + private String requestUrl; + + + /** + * Constructor with required information. + * @param delegateRequest supplier for the current + * {@link HttpServletRequestWrapper#getRequest() delegate request} which + * may change during a forward (e.g. Tocat. + * @param pathHelper the path helper instance + * @param baseUrl the host, scheme, and port based on forwarded headers + */ + public ForwardedPrefixExtractor( + Supplier delegateRequest, UrlPathHelper pathHelper, String baseUrl) { + + this.delegate = delegateRequest; + this.pathHelper = pathHelper; + this.baseUrl = baseUrl; + this.actualRequestUri = delegateRequest.get().getRequestURI(); + + this.forwardedPrefix = initForwardedPrefix(delegateRequest.get()); + this.requestUri = initRequestUri(); + this.requestUrl = initRequestUrl(); // Keep the order: depends on requestUri + } + + @Nullable + private static String initForwardedPrefix(HttpServletRequest request) { + String result = null; + Enumeration names = request.getHeaderNames(); + while (names.hasMoreElements()) { + String name = names.nextElement(); + if ("X-Forwarded-Prefix".equalsIgnoreCase(name)) { + result = request.getHeader(name); + } + } + if (result != null) { + while (result.endsWith("/")) { + result = result.substring(0, result.length() - 1); + } + } + return result; + } + + @Nullable + private String initRequestUri() { + if (this.forwardedPrefix != null) { + return this.forwardedPrefix + this.pathHelper.getPathWithinApplication(this.delegate.get()); + } + return null; + } + + private String initRequestUrl() { + return this.baseUrl + (this.requestUri != null ? this.requestUri : this.delegate.get().getRequestURI()); + } + + + public String getContextPath() { + return this.forwardedPrefix == null ? this.delegate.get().getContextPath() : this.forwardedPrefix; + } + + public String getRequestUri() { + if (this.requestUri == null) { + return this.delegate.get().getRequestURI(); + } + recalculatePathsIfNecesary(); + return this.requestUri; + } + + public StringBuffer getRequestUrl() { + recalculatePathsIfNecesary(); return new StringBuffer(this.requestUrl); } + + private void recalculatePathsIfNecesary() { + if (!this.actualRequestUri.equals(this.delegate.get().getRequestURI())) { + // Underlying path change (e.g. Servlet FORWARD). + this.actualRequestUri = this.delegate.get().getRequestURI(); + this.requestUri = initRequestUri(); + this.requestUrl = initRequestUrl(); // Keep the order: depends on requestUri + } + } } diff --git a/spring-web/src/test/java/org/springframework/web/filter/ForwardedHeaderFilterTests.java b/spring-web/src/test/java/org/springframework/web/filter/ForwardedHeaderFilterTests.java index dcaa5446b7f..5fd0370d355 100644 --- a/spring-web/src/test/java/org/springframework/web/filter/ForwardedHeaderFilterTests.java +++ b/spring-web/src/test/java/org/springframework/web/filter/ForwardedHeaderFilterTests.java @@ -17,7 +17,9 @@ package org.springframework.web.filter; import java.io.IOException; +import java.net.URI; import java.util.Enumeration; +import javax.servlet.DispatcherType; import javax.servlet.Filter; import javax.servlet.FilterChain; import javax.servlet.ServletException; @@ -308,6 +310,28 @@ public class ForwardedHeaderFilterTests { assertEquals("bar", actual.getHeader("foo")); } + @Test // SPR-16983 + public void forwardedRequestWithServletForward() throws Exception { + this.request.setRequestURI("/foo"); + this.request.addHeader(X_FORWARDED_PROTO, "https"); + this.request.addHeader(X_FORWARDED_HOST, "www.mycompany.com"); + this.request.addHeader(X_FORWARDED_PORT, "443"); + + this.filter.doFilter(this.request, new MockHttpServletResponse(), this.filterChain); + HttpServletRequest wrappedRequest = (HttpServletRequest) this.filterChain.getRequest(); + + this.request.setDispatcherType(DispatcherType.FORWARD); + this.request.setRequestURI("/bar"); + this.filterChain.reset(); + + this.filter.doFilter(wrappedRequest, new MockHttpServletResponse(), this.filterChain); + HttpServletRequest actual = (HttpServletRequest) this.filterChain.getRequest(); + + assertNotNull(actual); + assertEquals("/bar", actual.getRequestURI()); + assertEquals("https://www.mycompany.com/bar", actual.getRequestURL().toString()); + } + @Test public void requestUriWithForwardedPrefix() throws Exception { this.request.addHeader(X_FORWARDED_PREFIX, "/prefix");