diff --git a/spring-boot/src/main/java/org/springframework/boot/web/support/ErrorPageFilter.java b/spring-boot/src/main/java/org/springframework/boot/web/support/ErrorPageFilter.java index 08c262f30d9..5ff0f2c40c9 100644 --- a/spring-boot/src/main/java/org/springframework/boot/web/support/ErrorPageFilter.java +++ b/spring-boot/src/main/java/org/springframework/boot/web/support/ErrorPageFilter.java @@ -1,5 +1,5 @@ /* - * Copyright 2012-2016 the original author or authors. + * Copyright 2012-2017 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -183,6 +183,8 @@ public class ErrorPageFilter implements Filter, ErrorPageRegistry { response.reset(); response.sendError(500, ex.getMessage()); request.getRequestDispatcher(path).forward(request, response); + request.removeAttribute(ERROR_EXCEPTION); + request.removeAttribute(ERROR_EXCEPTION_TYPE); } private String getDescription(HttpServletRequest request) { diff --git a/spring-boot/src/test/java/org/springframework/boot/web/support/ErrorPageFilterTests.java b/spring-boot/src/test/java/org/springframework/boot/web/support/ErrorPageFilterTests.java index 8f0b8d33828..53182442d70 100644 --- a/spring-boot/src/test/java/org/springframework/boot/web/support/ErrorPageFilterTests.java +++ b/spring-boot/src/test/java/org/springframework/boot/web/support/ErrorPageFilterTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2012-2016 the original author or authors. + * Copyright 2012-2017 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,6 +17,9 @@ package org.springframework.boot.web.support; import java.io.IOException; +import java.util.Enumeration; +import java.util.HashMap; +import java.util.Map; import javax.servlet.RequestDispatcher; import javax.servlet.ServletException; @@ -35,6 +38,7 @@ import org.springframework.mock.web.MockFilterChain; import org.springframework.mock.web.MockFilterConfig; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.mock.web.MockRequestDispatcher; import org.springframework.web.context.request.async.DeferredResult; import org.springframework.web.context.request.async.StandardServletAsyncWebRequest; import org.springframework.web.context.request.async.WebAsyncManager; @@ -57,8 +61,7 @@ public class ErrorPageFilterTests { private ErrorPageFilter filter = new ErrorPageFilter(); - private MockHttpServletRequest request = new MockHttpServletRequest("GET", - "/test/path"); + private DispatchRecordingMockHttpServletRequest request = new DispatchRecordingMockHttpServletRequest(); private MockHttpServletResponse response = new MockHttpServletResponse(); @@ -261,8 +264,14 @@ public class ErrorPageFilterTests { .isEqualTo(500); assertThat(this.request.getAttribute(RequestDispatcher.ERROR_MESSAGE)) .isEqualTo("BAD"); - assertThat(this.request.getAttribute(RequestDispatcher.ERROR_EXCEPTION_TYPE)) + Map requestAttributes = getAttributesForDispatch("/500"); + assertThat(requestAttributes.get(RequestDispatcher.ERROR_EXCEPTION_TYPE)) .isEqualTo(RuntimeException.class); + assertThat(requestAttributes.get(RequestDispatcher.ERROR_EXCEPTION)) + .isInstanceOf(RuntimeException.class); + assertThat(this.request.getAttribute(RequestDispatcher.ERROR_EXCEPTION_TYPE)) + .isNull(); + assertThat(this.request.getAttribute(RequestDispatcher.ERROR_EXCEPTION)).isNull(); assertThat(this.request.getAttribute(RequestDispatcher.ERROR_REQUEST_URI)) .isEqualTo("/test/path"); assertThat(this.response.isCommitted()).isTrue(); @@ -318,8 +327,14 @@ public class ErrorPageFilterTests { .isEqualTo(500); assertThat(this.request.getAttribute(RequestDispatcher.ERROR_MESSAGE)) .isEqualTo("BAD"); - assertThat(this.request.getAttribute(RequestDispatcher.ERROR_EXCEPTION_TYPE)) + Map requestAttributes = getAttributesForDispatch("/500"); + assertThat(requestAttributes.get(RequestDispatcher.ERROR_EXCEPTION_TYPE)) .isEqualTo(IllegalStateException.class); + assertThat(requestAttributes.get(RequestDispatcher.ERROR_EXCEPTION)) + .isInstanceOf(IllegalStateException.class); + assertThat(this.request.getAttribute(RequestDispatcher.ERROR_EXCEPTION_TYPE)) + .isNull(); + assertThat(this.request.getAttribute(RequestDispatcher.ERROR_EXCEPTION)).isNull(); assertThat(this.request.getAttribute(RequestDispatcher.ERROR_REQUEST_URI)) .isEqualTo("/test/path"); assertThat(this.response.isCommitted()).isTrue(); @@ -492,8 +507,14 @@ public class ErrorPageFilterTests { .isEqualTo(500); assertThat(this.request.getAttribute(RequestDispatcher.ERROR_MESSAGE)) .isEqualTo("BAD"); - assertThat(this.request.getAttribute(RequestDispatcher.ERROR_EXCEPTION_TYPE)) + Map requestAttributes = getAttributesForDispatch("/500"); + assertThat(requestAttributes.get(RequestDispatcher.ERROR_EXCEPTION_TYPE)) .isEqualTo(RuntimeException.class); + assertThat(requestAttributes.get(RequestDispatcher.ERROR_EXCEPTION)) + .isInstanceOf(RuntimeException.class); + assertThat(this.request.getAttribute(RequestDispatcher.ERROR_EXCEPTION_TYPE)) + .isNull(); + assertThat(this.request.getAttribute(RequestDispatcher.ERROR_EXCEPTION)).isNull(); assertThat(this.request.getAttribute(RequestDispatcher.ERROR_REQUEST_URI)) .isEqualTo("/test/path"); assertThat(this.response.isCommitted()).isTrue(); @@ -510,4 +531,60 @@ public class ErrorPageFilterTests { asyncManager.startDeferredResultProcessing(result); } + private Map getAttributesForDispatch(String path) { + return this.request.getDispatcher(path).getRequestAttributes(); + } + + private static final class DispatchRecordingMockHttpServletRequest + extends MockHttpServletRequest { + + private final Map dispatchers = new HashMap(); + + private DispatchRecordingMockHttpServletRequest() { + super("GET", "/test/path"); + } + + @Override + public RequestDispatcher getRequestDispatcher(String path) { + AttributeCapturingRequestDispatcher dispatcher = new AttributeCapturingRequestDispatcher( + path); + this.dispatchers.put(path, dispatcher); + return dispatcher; + } + + private AttributeCapturingRequestDispatcher getDispatcher(String path) { + return this.dispatchers.get(path); + } + + private static final class AttributeCapturingRequestDispatcher + extends MockRequestDispatcher { + + private final Map requestAttributes = new HashMap(); + + private AttributeCapturingRequestDispatcher(String resource) { + super(resource); + } + + @Override + public void forward(ServletRequest request, ServletResponse response) { + captureAttributes(request); + super.forward(request, response); + } + + private void captureAttributes(ServletRequest request) { + Enumeration names = request.getAttributeNames(); + while (names.hasMoreElements()) { + String name = names.nextElement(); + this.requestAttributes.put(name, request.getAttribute(name)); + } + } + + private Map getRequestAttributes() { + return this.requestAttributes; + } + + } + + } + }