diff --git a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/servlet/support/ErrorPageFilter.java b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/servlet/support/ErrorPageFilter.java index cc40c344916..decd47dc3e0 100644 --- a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/servlet/support/ErrorPageFilter.java +++ b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/servlet/support/ErrorPageFilter.java @@ -1,5 +1,5 @@ /* - * Copyright 2012-2019 the original author or authors. + * Copyright 2012-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -137,7 +137,10 @@ public class ErrorPageFilter implements Filter, ErrorPageRegistry { catch (Throwable ex) { Throwable exceptionToHandle = ex; if (ex instanceof NestedServletException) { - exceptionToHandle = ((NestedServletException) ex).getRootCause(); + Throwable rootCause = ((NestedServletException) ex).getRootCause(); + if (rootCause != null) { + exceptionToHandle = rootCause; + } } handleException(request, response, wrapped, exceptionToHandle); response.flushBuffer(); diff --git a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/servlet/support/ErrorPageFilterTests.java b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/servlet/support/ErrorPageFilterTests.java index adcc8bb38eb..b0f183dc433 100644 --- a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/servlet/support/ErrorPageFilterTests.java +++ b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/servlet/support/ErrorPageFilterTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2012-2019 the original author or authors. + * Copyright 2012-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -43,6 +43,7 @@ import org.springframework.mock.web.MockFilterConfig; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.mock.web.MockRequestDispatcher; +import org.springframework.web.bind.MissingServletRequestParameterException; import org.springframework.web.context.request.async.DeferredResult; import org.springframework.web.context.request.async.StandardServletAsyncWebRequest; import org.springframework.web.context.request.async.WebAsyncManager; @@ -387,6 +388,30 @@ class ErrorPageFilterTests { assertThat(this.response.getForwardedUrl()).isEqualTo("/500"); } + @Test + void nestedServletExceptionWithNoCause() throws Exception { + this.filter.addErrorPages(new ErrorPage(MissingServletRequestParameterException.class, "/500")); + this.chain = new TestFilterChain((request, response, chain) -> { + chain.call(); + throw new MissingServletRequestParameterException("test", "string"); + }); + this.filter.doFilter(this.request, this.response, this.chain); + assertThat(((HttpServletResponseWrapper) this.chain.getResponse()).getStatus()).isEqualTo(500); + assertThat(this.request.getAttribute(RequestDispatcher.ERROR_STATUS_CODE)).isEqualTo(500); + assertThat(this.request.getAttribute(RequestDispatcher.ERROR_MESSAGE)) + .isEqualTo("Required string parameter 'test' is not present"); + Map requestAttributes = getAttributesForDispatch("/500"); + assertThat(requestAttributes.get(RequestDispatcher.ERROR_EXCEPTION_TYPE)) + .isEqualTo(MissingServletRequestParameterException.class); + assertThat(requestAttributes.get(RequestDispatcher.ERROR_EXCEPTION)) + .isInstanceOf(MissingServletRequestParameterException.class); + assertThat(this.request.getAttribute(RequestDispatcher.ERROR_EXCEPTION_TYPE)).isNull(); + assertThat(this.request.getAttribute(RequestDispatcher.ERROR_EXCEPTION)).isNull(); + assertThat(this.request.getAttribute(RequestDispatcher.ERROR_REQUEST_URI)).isEqualTo("/test/path"); + assertThat(this.response.isCommitted()).isTrue(); + assertThat(this.response.getForwardedUrl()).isEqualTo("/500"); + } + @Test void whenErrorIsSentAndWriterIsFlushedErrorIsSentToTheClient() throws Exception { this.chain = new TestFilterChain((request, response, chain) -> {