From 9247cd9c361bcf9fb51d3423d3f67fa59bda3feb Mon Sep 17 00:00:00 2001 From: Andy Wilkinson Date: Wed, 8 Feb 2017 11:45:59 +0000 Subject: [PATCH] Don't let standalone Tomcat render its error page after redirect Previously, if the configured error controller responded with a redirect to an error caused by an exception, standalone Tomcat would render its default error page for the original exception. This occurred because ErrorPageFilter sets the javax.servlet.error.exception request attribute prior to dispatching to the error controller and then does not clear it. As the request unwinds, Tomcat's ErrorReportValve notices that the attribute is set and renders an error page for the exception that is the attribute's value. This commit updates ErrorPageFilter to remove the javax.servlet.error.exception and javax.servlet.error.exception_type attributes upon successful completion of a forward to the error controller. This prevents Tomcat from rendering an error page for an exception that has already been handled by the error controller. Closes gh-7920 --- .../boot/web/support/ErrorPageFilter.java | 4 +- .../web/support/ErrorPageFilterTests.java | 89 +++++++++++++++++-- 2 files changed, 86 insertions(+), 7 deletions(-) diff --git a/spring-boot/src/main/java/org/springframework/boot/web/support/ErrorPageFilter.java b/spring-boot/src/main/java/org/springframework/boot/web/support/ErrorPageFilter.java index 08c262f30d9..5ff0f2c40c9 100644 --- a/spring-boot/src/main/java/org/springframework/boot/web/support/ErrorPageFilter.java +++ b/spring-boot/src/main/java/org/springframework/boot/web/support/ErrorPageFilter.java @@ -1,5 +1,5 @@ /* - * Copyright 2012-2016 the original author or authors. + * Copyright 2012-2017 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -183,6 +183,8 @@ public class ErrorPageFilter implements Filter, ErrorPageRegistry { response.reset(); response.sendError(500, ex.getMessage()); request.getRequestDispatcher(path).forward(request, response); + request.removeAttribute(ERROR_EXCEPTION); + request.removeAttribute(ERROR_EXCEPTION_TYPE); } private String getDescription(HttpServletRequest request) { diff --git a/spring-boot/src/test/java/org/springframework/boot/web/support/ErrorPageFilterTests.java b/spring-boot/src/test/java/org/springframework/boot/web/support/ErrorPageFilterTests.java index 8f0b8d33828..53182442d70 100644 --- a/spring-boot/src/test/java/org/springframework/boot/web/support/ErrorPageFilterTests.java +++ b/spring-boot/src/test/java/org/springframework/boot/web/support/ErrorPageFilterTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2012-2016 the original author or authors. + * Copyright 2012-2017 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,6 +17,9 @@ package org.springframework.boot.web.support; import java.io.IOException; +import java.util.Enumeration; +import java.util.HashMap; +import java.util.Map; import javax.servlet.RequestDispatcher; import javax.servlet.ServletException; @@ -35,6 +38,7 @@ import org.springframework.mock.web.MockFilterChain; import org.springframework.mock.web.MockFilterConfig; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.mock.web.MockRequestDispatcher; import org.springframework.web.context.request.async.DeferredResult; import org.springframework.web.context.request.async.StandardServletAsyncWebRequest; import org.springframework.web.context.request.async.WebAsyncManager; @@ -57,8 +61,7 @@ public class ErrorPageFilterTests { private ErrorPageFilter filter = new ErrorPageFilter(); - private MockHttpServletRequest request = new MockHttpServletRequest("GET", - "/test/path"); + private DispatchRecordingMockHttpServletRequest request = new DispatchRecordingMockHttpServletRequest(); private MockHttpServletResponse response = new MockHttpServletResponse(); @@ -261,8 +264,14 @@ public class ErrorPageFilterTests { .isEqualTo(500); assertThat(this.request.getAttribute(RequestDispatcher.ERROR_MESSAGE)) .isEqualTo("BAD"); - assertThat(this.request.getAttribute(RequestDispatcher.ERROR_EXCEPTION_TYPE)) + Map requestAttributes = getAttributesForDispatch("/500"); + assertThat(requestAttributes.get(RequestDispatcher.ERROR_EXCEPTION_TYPE)) .isEqualTo(RuntimeException.class); + assertThat(requestAttributes.get(RequestDispatcher.ERROR_EXCEPTION)) + .isInstanceOf(RuntimeException.class); + assertThat(this.request.getAttribute(RequestDispatcher.ERROR_EXCEPTION_TYPE)) + .isNull(); + assertThat(this.request.getAttribute(RequestDispatcher.ERROR_EXCEPTION)).isNull(); assertThat(this.request.getAttribute(RequestDispatcher.ERROR_REQUEST_URI)) .isEqualTo("/test/path"); assertThat(this.response.isCommitted()).isTrue(); @@ -318,8 +327,14 @@ public class ErrorPageFilterTests { .isEqualTo(500); assertThat(this.request.getAttribute(RequestDispatcher.ERROR_MESSAGE)) .isEqualTo("BAD"); - assertThat(this.request.getAttribute(RequestDispatcher.ERROR_EXCEPTION_TYPE)) + Map requestAttributes = getAttributesForDispatch("/500"); + assertThat(requestAttributes.get(RequestDispatcher.ERROR_EXCEPTION_TYPE)) .isEqualTo(IllegalStateException.class); + assertThat(requestAttributes.get(RequestDispatcher.ERROR_EXCEPTION)) + .isInstanceOf(IllegalStateException.class); + assertThat(this.request.getAttribute(RequestDispatcher.ERROR_EXCEPTION_TYPE)) + .isNull(); + assertThat(this.request.getAttribute(RequestDispatcher.ERROR_EXCEPTION)).isNull(); assertThat(this.request.getAttribute(RequestDispatcher.ERROR_REQUEST_URI)) .isEqualTo("/test/path"); assertThat(this.response.isCommitted()).isTrue(); @@ -492,8 +507,14 @@ public class ErrorPageFilterTests { .isEqualTo(500); assertThat(this.request.getAttribute(RequestDispatcher.ERROR_MESSAGE)) .isEqualTo("BAD"); - assertThat(this.request.getAttribute(RequestDispatcher.ERROR_EXCEPTION_TYPE)) + Map requestAttributes = getAttributesForDispatch("/500"); + assertThat(requestAttributes.get(RequestDispatcher.ERROR_EXCEPTION_TYPE)) .isEqualTo(RuntimeException.class); + assertThat(requestAttributes.get(RequestDispatcher.ERROR_EXCEPTION)) + .isInstanceOf(RuntimeException.class); + assertThat(this.request.getAttribute(RequestDispatcher.ERROR_EXCEPTION_TYPE)) + .isNull(); + assertThat(this.request.getAttribute(RequestDispatcher.ERROR_EXCEPTION)).isNull(); assertThat(this.request.getAttribute(RequestDispatcher.ERROR_REQUEST_URI)) .isEqualTo("/test/path"); assertThat(this.response.isCommitted()).isTrue(); @@ -510,4 +531,60 @@ public class ErrorPageFilterTests { asyncManager.startDeferredResultProcessing(result); } + private Map getAttributesForDispatch(String path) { + return this.request.getDispatcher(path).getRequestAttributes(); + } + + private static final class DispatchRecordingMockHttpServletRequest + extends MockHttpServletRequest { + + private final Map dispatchers = new HashMap(); + + private DispatchRecordingMockHttpServletRequest() { + super("GET", "/test/path"); + } + + @Override + public RequestDispatcher getRequestDispatcher(String path) { + AttributeCapturingRequestDispatcher dispatcher = new AttributeCapturingRequestDispatcher( + path); + this.dispatchers.put(path, dispatcher); + return dispatcher; + } + + private AttributeCapturingRequestDispatcher getDispatcher(String path) { + return this.dispatchers.get(path); + } + + private static final class AttributeCapturingRequestDispatcher + extends MockRequestDispatcher { + + private final Map requestAttributes = new HashMap(); + + private AttributeCapturingRequestDispatcher(String resource) { + super(resource); + } + + @Override + public void forward(ServletRequest request, ServletResponse response) { + captureAttributes(request); + super.forward(request, response); + } + + private void captureAttributes(ServletRequest request) { + Enumeration names = request.getAttributeNames(); + while (names.hasMoreElements()) { + String name = names.nextElement(); + this.requestAttributes.put(name, request.getAttribute(name)); + } + } + + private Map getRequestAttributes() { + return this.requestAttributes; + } + + } + + } + }