From 13e040c06e18021e4ba081dcb0d0d68e58b03274 Mon Sep 17 00:00:00 2001 From: Dave Syer Date: Mon, 3 Mar 2014 09:50:52 +0000 Subject: [PATCH] Add ErrorWrapperEmbeddedServletContainerFactory for error pages in WARs Error pages are a feature of the servlet spec but there is no Java API for registering them in the spec. This filter works around that by accepting error page registrations from Spring Boot's EmbeddedServletContainerCustomizer (any beans of that type in the context will be applied to this container). In addition the ErrorController interface was enhanced to provide callers the option to suppress logging. Fixes gh-410 --- .../endpoint/mvc/ManagementErrorEndpoint.java | 2 +- .../actuate/trace/WebRequestTraceFilter.java | 4 +- .../actuate/web/BasicErrorController.java | 14 +- .../boot/actuate/web/ErrorController.java | 4 +- ...rapperEmbeddedServletContainerFactory.java | 200 ++++++++++++++++++ .../web/SpringBootServletInitializer.java | 2 + ...rEmbeddedServletContainerFactoryTests.java | 107 ++++++++++ 7 files changed, 326 insertions(+), 7 deletions(-) create mode 100644 spring-boot/src/main/java/org/springframework/boot/context/web/ErrorWrapperEmbeddedServletContainerFactory.java create mode 100644 spring-boot/src/test/java/org/springframework/boot/context/web/ErrorWrapperEmbeddedServletContainerFactoryTests.java diff --git a/spring-boot-actuator/src/main/java/org/springframework/boot/actuate/endpoint/mvc/ManagementErrorEndpoint.java b/spring-boot-actuator/src/main/java/org/springframework/boot/actuate/endpoint/mvc/ManagementErrorEndpoint.java index f54eb450d73..bf6cce4edb7 100644 --- a/spring-boot-actuator/src/main/java/org/springframework/boot/actuate/endpoint/mvc/ManagementErrorEndpoint.java +++ b/spring-boot-actuator/src/main/java/org/springframework/boot/actuate/endpoint/mvc/ManagementErrorEndpoint.java @@ -51,7 +51,7 @@ public class ManagementErrorEndpoint implements MvcEndpoint { @ResponseBody public Map invoke() { RequestAttributes attributes = RequestContextHolder.currentRequestAttributes(); - return this.controller.extract(attributes, false); + return this.controller.extract(attributes, false, true); } @Override diff --git a/spring-boot-actuator/src/main/java/org/springframework/boot/actuate/trace/WebRequestTraceFilter.java b/spring-boot-actuator/src/main/java/org/springframework/boot/actuate/trace/WebRequestTraceFilter.java index 2fe115367ec..7ed773ce750 100644 --- a/spring-boot-actuator/src/main/java/org/springframework/boot/actuate/trace/WebRequestTraceFilter.java +++ b/spring-boot-actuator/src/main/java/org/springframework/boot/actuate/trace/WebRequestTraceFilter.java @@ -36,6 +36,7 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.springframework.boot.actuate.web.BasicErrorController; import org.springframework.core.Ordered; +import org.springframework.web.context.request.ServletRequestAttributes; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; @@ -159,7 +160,8 @@ public class WebRequestTraceFilter implements Filter, Ordered { .getAttribute("javax.servlet.error.exception"); if (error != null) { if (this.errorController != null) { - trace.put("error", this.errorController.error(request)); + trace.put("error", this.errorController.extract( + new ServletRequestAttributes(request), true, false)); } } return trace; diff --git a/spring-boot-actuator/src/main/java/org/springframework/boot/actuate/web/BasicErrorController.java b/spring-boot-actuator/src/main/java/org/springframework/boot/actuate/web/BasicErrorController.java index 8fb01de0d37..aee6e96d529 100644 --- a/spring-boot-actuator/src/main/java/org/springframework/boot/actuate/web/BasicErrorController.java +++ b/spring-boot-actuator/src/main/java/org/springframework/boot/actuate/web/BasicErrorController.java @@ -72,11 +72,13 @@ public class BasicErrorController implements ErrorController { public Map error(HttpServletRequest request) { ServletRequestAttributes attributes = new ServletRequestAttributes(request); String trace = request.getParameter("trace"); - return extract(attributes, trace != null && !"false".equals(trace.toLowerCase())); + return extract(attributes, trace != null && !"false".equals(trace.toLowerCase()), + true); } @Override - public Map extract(RequestAttributes attributes, boolean trace) { + public Map extract(RequestAttributes attributes, boolean trace, + boolean log) { Map map = new LinkedHashMap(); map.put("timestamp", new Date()); try { @@ -105,7 +107,9 @@ public class BasicErrorController implements ErrorController { stackTrace.flush(); map.put("trace", stackTrace.toString()); } - this.logger.error(error); + if (log) { + this.logger.error(error); + } } else { Object message = attributes.getAttribute("javax.servlet.error.message", @@ -117,7 +121,9 @@ public class BasicErrorController implements ErrorController { catch (Exception ex) { map.put(ERROR_KEY, ex.getClass().getName()); map.put("message", ex.getMessage()); - this.logger.error(ex); + if (log) { + this.logger.error(ex); + } return map; } } diff --git a/spring-boot-actuator/src/main/java/org/springframework/boot/actuate/web/ErrorController.java b/spring-boot-actuator/src/main/java/org/springframework/boot/actuate/web/ErrorController.java index 597b79a1a6c..f7d1009871a 100644 --- a/spring-boot-actuator/src/main/java/org/springframework/boot/actuate/web/ErrorController.java +++ b/spring-boot-actuator/src/main/java/org/springframework/boot/actuate/web/ErrorController.java @@ -38,8 +38,10 @@ public interface ErrorController { * Extract a useful model of the error from the request attributes. * @param attributes the request attributes * @param trace flag to indicate that stack trace information should be included + * @param log flag to indicate that an error should be logged * @return a model containing error messages and codes etc. */ - public Map extract(RequestAttributes attributes, boolean trace); + public Map extract(RequestAttributes attributes, boolean trace, + boolean log); } diff --git a/spring-boot/src/main/java/org/springframework/boot/context/web/ErrorWrapperEmbeddedServletContainerFactory.java b/spring-boot/src/main/java/org/springframework/boot/context/web/ErrorWrapperEmbeddedServletContainerFactory.java new file mode 100644 index 00000000000..dc94846abe9 --- /dev/null +++ b/spring-boot/src/main/java/org/springframework/boot/context/web/ErrorWrapperEmbeddedServletContainerFactory.java @@ -0,0 +1,200 @@ +/* + * Copyright 2012-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.boot.context.web; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import javax.servlet.Filter; +import javax.servlet.FilterChain; +import javax.servlet.FilterConfig; +import javax.servlet.ServletException; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import javax.servlet.http.HttpServletResponseWrapper; + +import org.springframework.boot.context.embedded.AbstractEmbeddedServletContainerFactory; +import org.springframework.boot.context.embedded.EmbeddedServletContainer; +import org.springframework.boot.context.embedded.EmbeddedServletContainerCustomizer; +import org.springframework.boot.context.embedded.EmbeddedServletContainerException; +import org.springframework.boot.context.embedded.EmbeddedServletContainerFactory; +import org.springframework.boot.context.embedded.ErrorPage; +import org.springframework.boot.context.embedded.ServletContextInitializer; +import org.springframework.core.Ordered; +import org.springframework.core.annotation.Order; +import org.springframework.stereotype.Component; + +/** + * A special {@link EmbeddedServletContainerFactory} for non-embedded applications (i.e. + * deployed WAR files). It registers error pages and handles application errors by + * filtering requests and forwarding to the error pages instead of letting the container + * handle them. Error pages are a feature of the servlet spec but there is no Java API for + * registering them in the spec. This filter works around that by accepting error page + * registrations from Spring Boot's {@link EmbeddedServletContainerCustomizer} (any beans + * of that type in the context will be applied to this container). + * + * @author Dave Syer + * + */ +@Component +@Order(Ordered.HIGHEST_PRECEDENCE) +public class ErrorWrapperEmbeddedServletContainerFactory extends + AbstractEmbeddedServletContainerFactory implements Filter { + + private String global; + + private Map statuses = new HashMap(); + + private Map, String> exceptions = new HashMap, String>(); + + @Override + public void init(FilterConfig filterConfig) throws ServletException { + } + + @Override + public void doFilter(ServletRequest request, ServletResponse response, + FilterChain chain) throws IOException, ServletException { + String errorPath; + ErrorWrapperResponse wrapped = new ErrorWrapperResponse( + (HttpServletResponse) response); + try { + chain.doFilter(request, wrapped); + int status = wrapped.getStatus(); + if (status >= 400) { + errorPath = this.statuses.containsKey(status) ? this.statuses.get(status) + : this.global; + if (errorPath != null) { + request.setAttribute("javax.servlet.error.status_code", status); + request.setAttribute("javax.servlet.error.message", + wrapped.getMessage()); + ((HttpServletRequest) request).getRequestDispatcher(errorPath) + .forward(request, response); + } + else { + ((HttpServletResponse) response).sendError(status, + wrapped.getMessage()); + } + } + } + catch (Throwable e) { + Class cls = e.getClass(); + errorPath = this.exceptions.containsKey(cls) ? this.exceptions.get(cls) + : this.global; + if (errorPath != null) { + request.setAttribute("javax.servlet.error.status_code", 500); + request.setAttribute("javax.servlet.error.exception", e); + request.setAttribute("javax.servlet.error.message", e.getMessage()); + wrapped.sendError(500, e.getMessage()); + ((HttpServletRequest) request).getRequestDispatcher(errorPath).forward( + request, response); + } + else { + rethrow(e); + } + } + } + + private void rethrow(Throwable e) throws IOException, ServletException { + if (e instanceof RuntimeException) { + throw (RuntimeException) e; + } + if (e instanceof Error) { + throw (Error) e; + } + if (e instanceof IOException) { + throw (IOException) e; + } + if (e instanceof ServletException) { + throw (ServletException) e; + } + throw new IllegalStateException("Unidentified Exception", e); + } + + @Override + public EmbeddedServletContainer getEmbeddedServletContainer( + ServletContextInitializer... initializers) { + return new EmbeddedServletContainer() { + + @Override + public void start() throws EmbeddedServletContainerException { + } + + @Override + public void stop() throws EmbeddedServletContainerException { + } + + @Override + public int getPort() { + return -1; + } + }; + } + + @Override + public void addErrorPages(ErrorPage... errorPages) { + for (ErrorPage errorPage : errorPages) { + if (errorPage.isGlobal()) { + this.global = errorPage.getPath(); + } + else if (errorPage.getStatus() != null) { + this.statuses.put(errorPage.getStatus().value(), errorPage.getPath()); + } + else { + this.exceptions.put(errorPage.getException(), errorPage.getPath()); + } + } + } + + @Override + public void destroy() { + } + + private static class ErrorWrapperResponse extends HttpServletResponseWrapper { + + private int status; + private String message; + + public ErrorWrapperResponse(HttpServletResponse response) { + super(response); + } + + @Override + public void sendError(int status) throws IOException { + sendError(status, null); + } + + @Override + public void sendError(int status, String message) throws IOException { + this.status = status; + this.message = message; + } + + @Override + public int getStatus() { + return this.status; + } + + public String getMessage() { + return this.message; + } + + } + +} diff --git a/spring-boot/src/main/java/org/springframework/boot/context/web/SpringBootServletInitializer.java b/spring-boot/src/main/java/org/springframework/boot/context/web/SpringBootServletInitializer.java index d3d9bbd05cf..12fbcd94c6f 100644 --- a/spring-boot/src/main/java/org/springframework/boot/context/web/SpringBootServletInitializer.java +++ b/spring-boot/src/main/java/org/springframework/boot/context/web/SpringBootServletInitializer.java @@ -84,6 +84,8 @@ public abstract class SpringBootServletInitializer implements WebApplicationInit servletContext)); application.contextClass(AnnotationConfigEmbeddedWebApplicationContext.class); application = configure(application); + // Ensure error pages ar registered + application.sources(ErrorWrapperEmbeddedServletContainerFactory.class); return (WebApplicationContext) application.run(); } diff --git a/spring-boot/src/test/java/org/springframework/boot/context/web/ErrorWrapperEmbeddedServletContainerFactoryTests.java b/spring-boot/src/test/java/org/springframework/boot/context/web/ErrorWrapperEmbeddedServletContainerFactoryTests.java new file mode 100644 index 00000000000..9be037c5d19 --- /dev/null +++ b/spring-boot/src/test/java/org/springframework/boot/context/web/ErrorWrapperEmbeddedServletContainerFactoryTests.java @@ -0,0 +1,107 @@ +/* + * Copyright 2012-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.boot.context.web; + +import java.io.IOException; + +import javax.servlet.ServletException; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; +import javax.servlet.http.HttpServletResponse; +import javax.servlet.http.HttpServletResponseWrapper; + +import org.junit.Test; +import org.springframework.boot.context.embedded.ErrorPage; +import org.springframework.http.HttpStatus; +import org.springframework.mock.web.MockFilterChain; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; + +import static org.junit.Assert.assertEquals; + +/** + * @author Dave Syer + */ +public class ErrorWrapperEmbeddedServletContainerFactoryTests { + + private ErrorWrapperEmbeddedServletContainerFactory filter = new ErrorWrapperEmbeddedServletContainerFactory(); + private MockHttpServletRequest request = new MockHttpServletRequest(); + private MockHttpServletResponse response = new MockHttpServletResponse(); + private MockFilterChain chain = new MockFilterChain(); + + @Test + public void notAnError() throws Exception { + this.filter.doFilter(this.request, this.response, this.chain); + assertEquals(this.request, this.chain.getRequest()); + assertEquals(this.response, + ((HttpServletResponseWrapper) this.chain.getResponse()).getResponse()); + } + + @Test + public void globalError() throws Exception { + this.filter.addErrorPages(new ErrorPage("/error")); + this.chain = new MockFilterChain() { + @Override + public void doFilter(ServletRequest request, ServletResponse response) + throws IOException, ServletException { + ((HttpServletResponse) response).sendError(400, "BAD"); + super.doFilter(request, response); + } + }; + this.filter.doFilter(this.request, this.response, this.chain); + assertEquals(400, + ((HttpServletResponseWrapper) this.chain.getResponse()).getStatus()); + assertEquals(400, this.request.getAttribute("javax.servlet.error.status_code")); + assertEquals("BAD", this.request.getAttribute("javax.servlet.error.message")); + } + + @Test + public void statusError() throws Exception { + this.filter.addErrorPages(new ErrorPage(HttpStatus.BAD_REQUEST, "/400")); + this.chain = new MockFilterChain() { + @Override + public void doFilter(ServletRequest request, ServletResponse response) + throws IOException, ServletException { + ((HttpServletResponse) response).sendError(400, "BAD"); + super.doFilter(request, response); + } + }; + this.filter.doFilter(this.request, this.response, this.chain); + assertEquals(400, + ((HttpServletResponseWrapper) this.chain.getResponse()).getStatus()); + assertEquals(400, this.request.getAttribute("javax.servlet.error.status_code")); + assertEquals("BAD", this.request.getAttribute("javax.servlet.error.message")); + } + + @Test + public void exceptionError() throws Exception { + this.filter.addErrorPages(new ErrorPage(RuntimeException.class, "/500")); + this.chain = new MockFilterChain() { + @Override + public void doFilter(ServletRequest request, ServletResponse response) + throws IOException, ServletException { + super.doFilter(request, response); + throw new RuntimeException("BAD"); + } + }; + this.filter.doFilter(this.request, this.response, this.chain); + assertEquals(500, + ((HttpServletResponseWrapper) this.chain.getResponse()).getStatus()); + assertEquals(500, this.request.getAttribute("javax.servlet.error.status_code")); + assertEquals("BAD", this.request.getAttribute("javax.servlet.error.message")); + } +}