diff --git a/web/src/main/java/org/springframework/security/web/FilterChainProxy.java b/web/src/main/java/org/springframework/security/web/FilterChainProxy.java index 5204f33a39..594812d860 100644 --- a/web/src/main/java/org/springframework/security/web/FilterChainProxy.java +++ b/web/src/main/java/org/springframework/security/web/FilterChainProxy.java @@ -40,6 +40,7 @@ import org.springframework.security.web.firewall.HttpStatusRequestRejectedHandle import org.springframework.security.web.firewall.RequestRejectedException; import org.springframework.security.web.firewall.RequestRejectedHandler; import org.springframework.security.web.firewall.StrictHttpFirewall; +import org.springframework.security.web.util.ThrowableAnalyzer; import org.springframework.security.web.util.UrlUtils; import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.util.Assert; @@ -157,6 +158,8 @@ public class FilterChainProxy extends GenericFilterBean { private RequestRejectedHandler requestRejectedHandler = new HttpStatusRequestRejectedHandler(); + private ThrowableAnalyzer throwableAnalyzer = new ThrowableAnalyzer(); + public FilterChainProxy() { } @@ -185,8 +188,15 @@ public class FilterChainProxy extends GenericFilterBean { request.setAttribute(FILTER_APPLIED, Boolean.TRUE); doFilterInternal(request, response, chain); } - catch (RequestRejectedException ex) { - this.requestRejectedHandler.handle((HttpServletRequest) request, (HttpServletResponse) response, ex); + catch (Exception ex) { + Throwable[] causeChain = this.throwableAnalyzer.determineCauseChain(ex); + Throwable requestRejectedException = this.throwableAnalyzer + .getFirstThrowableOfType(RequestRejectedException.class, causeChain); + if (!(requestRejectedException instanceof RequestRejectedException)) { + throw ex; + } + this.requestRejectedHandler.handle((HttpServletRequest) request, (HttpServletResponse) response, + (RequestRejectedException) requestRejectedException); } finally { this.securityContextHolderStrategy.clearContext(); diff --git a/web/src/test/java/org/springframework/security/web/FilterChainProxyTests.java b/web/src/test/java/org/springframework/security/web/FilterChainProxyTests.java index cd272305cb..fe0b77e581 100644 --- a/web/src/test/java/org/springframework/security/web/FilterChainProxyTests.java +++ b/web/src/test/java/org/springframework/security/web/FilterChainProxyTests.java @@ -49,6 +49,7 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.willAnswer; +import static org.mockito.BDDMockito.willThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyZeroInteractions; @@ -261,4 +262,18 @@ public class FilterChainProxyTests { verify(rjh).handle(eq(this.request), eq(this.response), eq((requestRejectedException))); } + @Test + public void requestRejectedHandlerIsCalledIfFirewallThrowsWrappedRequestRejectedException() throws Exception { + HttpFirewall fw = mock(HttpFirewall.class); + RequestRejectedHandler rjh = mock(RequestRejectedHandler.class); + this.fcp.setFirewall(fw); + this.fcp.setRequestRejectedHandler(rjh); + RequestRejectedException requestRejectedException = new RequestRejectedException("Contains illegal chars"); + ServletException servletException = new ServletException(requestRejectedException); + given(fw.getFirewalledRequest(this.request)).willReturn(mock(FirewalledRequest.class)); + willThrow(servletException).given(this.chain).doFilter(any(), any()); + this.fcp.doFilter(this.request, this.response, this.chain); + verify(rjh).handle(eq(this.request), eq(this.response), eq((requestRejectedException))); + } + }