From 8090a52f5ceb201b22798b3a2136c4e064b26647 Mon Sep 17 00:00:00 2001 From: rstoyanchev Date: Thu, 30 Nov 2023 13:10:06 +0000 Subject: [PATCH] ForwardedHeaderFilter supports ERROR requestUri attribute Closes gh-30828 --- .../web/filter/ForwardedHeaderFilter.java | 38 +++++++++++++++++++ .../filter/ForwardedHeaderFilterTests.java | 23 ++++++++++- 2 files changed, 60 insertions(+), 1 deletion(-) diff --git a/spring-web/src/main/java/org/springframework/web/filter/ForwardedHeaderFilter.java b/spring-web/src/main/java/org/springframework/web/filter/ForwardedHeaderFilter.java index 3b4bc005d60..a03c188cac4 100644 --- a/spring-web/src/main/java/org/springframework/web/filter/ForwardedHeaderFilter.java +++ b/spring-web/src/main/java/org/springframework/web/filter/ForwardedHeaderFilter.java @@ -27,6 +27,7 @@ import java.util.function.Supplier; import jakarta.servlet.FilterChain; import jakarta.servlet.ServletException; +import jakarta.servlet.ServletRequest; import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletRequestWrapper; import jakarta.servlet.http.HttpServletResponse; @@ -37,12 +38,14 @@ import org.springframework.http.HttpStatus; import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServletServerHttpRequest; import org.springframework.lang.Nullable; +import org.springframework.util.Assert; import org.springframework.util.LinkedCaseInsensitiveMap; import org.springframework.util.StringUtils; import org.springframework.web.util.ForwardedHeaderUtils; import org.springframework.web.util.UriComponents; import org.springframework.web.util.UriComponentsBuilder; import org.springframework.web.util.UrlPathHelper; +import org.springframework.web.util.WebUtils; /** * Extract values from "Forwarded" and "X-Forwarded-*" headers, wrap the request @@ -312,6 +315,15 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter { public int getRemotePort() { return (this.remoteAddress != null ? this.remoteAddress.getPort() : super.getRemotePort()); } + + @SuppressWarnings("DataFlowIssue") + @Override + public Object getAttribute(String name) { + if (name.equals(WebUtils.ERROR_REQUEST_URI_ATTRIBUTE)) { + return this.forwardedPrefixExtractor.getErrorRequestUri(); + } + return super.getAttribute(name); + } } @@ -419,6 +431,17 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter { this.requestUrl = initRequestUrl(); } } + + @Nullable + public String getErrorRequestUri() { + HttpServletRequest request = this.delegate.get(); + String requestUri = (String) request.getAttribute(WebUtils.ERROR_REQUEST_URI_ATTRIBUTE); + if (this.forwardedPrefix == null || requestUri == null) { + return requestUri; + } + ErrorPathRequest errorRequest = new ErrorPathRequest(request); + return this.forwardedPrefix + UrlPathHelper.rawPathInstance.getPathWithinApplication(errorRequest); + } } @@ -473,4 +496,19 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter { } } + + private static class ErrorPathRequest extends HttpServletRequestWrapper { + + ErrorPathRequest(ServletRequest request) { + super((HttpServletRequest) request); + } + + @Override + public String getRequestURI() { + String requestUri = (String) getAttribute(WebUtils.ERROR_REQUEST_URI_ATTRIBUTE); + Assert.isTrue(requestUri != null, "Expected ERROR requestUri attribute"); + return requestUri; + } + } + } diff --git a/spring-web/src/test/java/org/springframework/web/filter/ForwardedHeaderFilterTests.java b/spring-web/src/test/java/org/springframework/web/filter/ForwardedHeaderFilterTests.java index bf4a4c722ff..b421d0c9c61 100644 --- a/spring-web/src/test/java/org/springframework/web/filter/ForwardedHeaderFilterTests.java +++ b/spring-web/src/test/java/org/springframework/web/filter/ForwardedHeaderFilterTests.java @@ -35,6 +35,7 @@ import org.junit.jupiter.params.provider.ValueSource; import org.springframework.web.testfixture.servlet.MockFilterChain; import org.springframework.web.testfixture.servlet.MockHttpServletRequest; import org.springframework.web.testfixture.servlet.MockHttpServletResponse; +import org.springframework.web.util.WebUtils; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; @@ -187,7 +188,7 @@ public class ForwardedHeaderFilterTests { } @Test // SPR-16983 - public void forwardedRequestWithServletForward() throws Exception { + public void forwardedRequestWithForwardDispatch() throws Exception { this.request.setRequestURI("/foo"); this.request.addHeader(X_FORWARDED_PROTO, "https"); this.request.addHeader(X_FORWARDED_HOST, "www.mycompany.example"); @@ -208,6 +209,26 @@ public class ForwardedHeaderFilterTests { assertThat(actual.getRequestURL().toString()).isEqualTo("https://www.mycompany.example/bar"); } + @Test // gh-30828 + public void forwardedRequestWithErrorDispatch() throws Exception { + this.request.setRequestURI("/foo"); + this.request.setDispatcherType(DispatcherType.ERROR); + this.request.addHeader(X_FORWARDED_PROTO, "https"); + this.request.addHeader(X_FORWARDED_HOST, "www.mycompany.example"); + this.request.addHeader(X_FORWARDED_PORT, "443"); + this.request.addHeader(X_FORWARDED_PREFIX, "/app"); + this.request.setAttribute(WebUtils.ERROR_REQUEST_URI_ATTRIBUTE, "/foo"); + + this.filter.doFilter(this.request, new MockHttpServletResponse(), this.filterChain); + + HttpServletRequest wrappedRequest = (HttpServletRequest) this.filterChain.getRequest(); + + assertThat(wrappedRequest).isNotNull(); + assertThat(wrappedRequest.getRequestURI()).isEqualTo("/app/foo"); + assertThat(wrappedRequest.getRequestURL().toString()).isEqualTo("https://www.mycompany.example/app/foo"); + assertThat(wrappedRequest.getAttribute(WebUtils.ERROR_REQUEST_URI_ATTRIBUTE)).isEqualTo("/app/foo"); + } + @Nested class ForwardedPrefix {