diff --git a/spring-boot/src/main/java/org/springframework/boot/context/web/ErrorPageFilter.java b/spring-boot/src/main/java/org/springframework/boot/context/web/ErrorPageFilter.java index efa21c1d9a0..1d74ec93779 100644 --- a/spring-boot/src/main/java/org/springframework/boot/context/web/ErrorPageFilter.java +++ b/spring-boot/src/main/java/org/springframework/boot/context/web/ErrorPageFilter.java @@ -56,7 +56,7 @@ import org.springframework.web.filter.OncePerRequestFilter; @Component @Order(Ordered.HIGHEST_PRECEDENCE) class ErrorPageFilter extends AbstractConfigurableEmbeddedServletContainer implements -Filter, NonEmbeddedServletContainerFactory { + Filter, NonEmbeddedServletContainerFactory { private static Log logger = LogFactory.getLog(ErrorPageFilter.class); @@ -123,19 +123,20 @@ Filter, NonEmbeddedServletContainerFactory { private void handleErrorStatus(HttpServletRequest request, HttpServletResponse response, int status, String message) - throws ServletException, IOException { + throws ServletException, IOException { String errorPath = getErrorPath(this.statuses, status); if (errorPath == null) { response.sendError(status, message); return; } + response.setStatus(status); setErrorAttributes(request, status, message); request.getRequestDispatcher(errorPath).forward(request, response); } private void handleException(HttpServletRequest request, HttpServletResponse response, ErrorWrapperResponse wrapped, Throwable ex) - throws IOException, ServletException { + throws IOException, ServletException { Class type = ex.getClass(); String errorPath = getErrorPath(type); if (errorPath == null) { diff --git a/spring-boot/src/test/java/org/springframework/boot/context/web/ErrorPageFilterTests.java b/spring-boot/src/test/java/org/springframework/boot/context/web/ErrorPageFilterTests.java index d29127ffffd..32075143940 100644 --- a/spring-boot/src/test/java/org/springframework/boot/context/web/ErrorPageFilterTests.java +++ b/spring-boot/src/test/java/org/springframework/boot/context/web/ErrorPageFilterTests.java @@ -63,6 +63,28 @@ public class ErrorPageFilterTests { assertTrue(this.response.isCommitted()); } + @Test + public void unauthorizedWithErrorPath() throws Exception { + this.filter.addErrorPages(new ErrorPage("/error")); + this.chain = new MockFilterChain() { + @Override + public void doFilter(ServletRequest request, ServletResponse response) + throws IOException, ServletException { + ((HttpServletResponse) response).sendError(401, "UNAUTHORIZED"); + super.doFilter(request, response); + } + }; + this.filter.doFilter(this.request, this.response, this.chain); + assertThat(this.chain.getRequest(), equalTo((ServletRequest) this.request)); + HttpServletResponseWrapper wrapper = (HttpServletResponseWrapper) this.chain + .getResponse(); + assertThat(wrapper.getResponse(), equalTo((ServletResponse) this.response)); + assertTrue(this.response.isCommitted()); + assertThat(wrapper.getStatus(), equalTo(401)); + // The real response has to be 401 as well... + assertThat(this.response.getStatus(), equalTo(401)); + } + @Test public void responseCommitted() throws Exception { this.filter.addErrorPages(new ErrorPage("/error"));