@ -20,11 +20,13 @@ import java.io.IOException;
@@ -20,11 +20,13 @@ import java.io.IOException;
import java.util.Enumeration ;
import java.util.HashMap ;
import java.util.Map ;
import java.util.concurrent.atomic.AtomicBoolean ;
import javax.servlet.RequestDispatcher ;
import javax.servlet.ServletException ;
import javax.servlet.ServletRequest ;
import javax.servlet.ServletResponse ;
import javax.servlet.http.HttpServletRequest ;
import javax.servlet.http.HttpServletResponse ;
import javax.servlet.http.HttpServletResponseWrapper ;
@ -65,7 +67,8 @@ public class ErrorPageFilterTests {
@@ -65,7 +67,8 @@ public class ErrorPageFilterTests {
private MockHttpServletResponse response = new MockHttpServletResponse ( ) ;
private MockFilterChain chain = new MockFilterChain ( ) ;
private MockFilterChain chain = new TestFilterChain ( ( request , response , chain ) - > {
} ) ;
@Rule
public OutputCapture output = new OutputCapture ( ) ;
@ -82,15 +85,11 @@ public class ErrorPageFilterTests {
@@ -82,15 +85,11 @@ public class ErrorPageFilterTests {
@Test
public void notAnErrorButNotOK ( ) throws Exception {
this . chain = new MockFilterChain ( ) {
@Override
public void doFilter ( ServletRequest request , ServletResponse response )
throws IOException , ServletException {
( ( HttpServletResponse ) response ) . setStatus ( 201 ) ;
super . doFilter ( request , response ) ;
response . flushBuffer ( ) ;
}
} ;
this . chain = new TestFilterChain ( ( request , response , chain ) - > {
response . setStatus ( 201 ) ;
chain . call ( ) ;
response . flushBuffer ( ) ;
} ) ;
this . filter . doFilter ( this . request , this . response , this . chain ) ;
assertThat ( ( ( HttpServletResponse ) this . chain . getResponse ( ) ) . getStatus ( ) )
. isEqualTo ( 201 ) ;
@ -102,14 +101,8 @@ public class ErrorPageFilterTests {
@@ -102,14 +101,8 @@ public class ErrorPageFilterTests {
@Test
public void unauthorizedWithErrorPath ( ) throws Exception {
this . filter . addErrorPages ( new ErrorPage ( "/error" ) ) ;
this . chain = new MockFilterChain ( ) {
@Override
public void doFilter ( ServletRequest request , ServletResponse response )
throws IOException , ServletException {
( ( HttpServletResponse ) response ) . sendError ( 401 , "UNAUTHORIZED" ) ;
super . doFilter ( request , response ) ;
}
} ;
this . chain = new TestFilterChain (
( request , response , chain ) - > response . sendError ( 401 , "UNAUTHORIZED" ) ) ;
this . filter . doFilter ( this . request , this . response , this . chain ) ;
assertThat ( this . chain . getRequest ( ) ) . isEqualTo ( this . request ) ;
HttpServletResponseWrapper wrapper = ( HttpServletResponseWrapper ) this . chain
@ -126,14 +119,8 @@ public class ErrorPageFilterTests {
@@ -126,14 +119,8 @@ public class ErrorPageFilterTests {
public void responseCommitted ( ) throws Exception {
this . filter . addErrorPages ( new ErrorPage ( "/error" ) ) ;
this . response . setCommitted ( true ) ;
this . chain = new MockFilterChain ( ) {
@Override
public void doFilter ( ServletRequest request , ServletResponse response )
throws IOException , ServletException {
( ( HttpServletResponse ) response ) . sendError ( 400 , "BAD" ) ;
super . doFilter ( request , response ) ;
}
} ;
this . chain = new TestFilterChain (
( request , response , chain ) - > response . sendError ( 400 , "BAD" ) ) ;
this . filter . doFilter ( this . request , this . response , this . chain ) ;
assertThat ( this . chain . getRequest ( ) ) . isEqualTo ( this . request ) ;
assertThat ( ( ( HttpServletResponseWrapper ) this . chain . getResponse ( ) ) . getResponse ( ) )
@ -146,14 +133,8 @@ public class ErrorPageFilterTests {
@@ -146,14 +133,8 @@ public class ErrorPageFilterTests {
@Test
public void responseUncommittedWithoutErrorPage ( ) throws Exception {
this . chain = new MockFilterChain ( ) {
@Override
public void doFilter ( ServletRequest request , ServletResponse response )
throws IOException , ServletException {
( ( HttpServletResponse ) response ) . sendError ( 400 , "BAD" ) ;
super . doFilter ( request , response ) ;
}
} ;
this . chain = new TestFilterChain (
( request , response , chain ) - > response . sendError ( 400 , "BAD" ) ) ;
this . filter . doFilter ( this . request , this . response , this . chain ) ;
assertThat ( this . chain . getRequest ( ) ) . isEqualTo ( this . request ) ;
assertThat ( ( ( HttpServletResponseWrapper ) this . chain . getResponse ( ) ) . getResponse ( ) )
@ -166,15 +147,10 @@ public class ErrorPageFilterTests {
@@ -166,15 +147,10 @@ public class ErrorPageFilterTests {
@Test
public void oncePerRequest ( ) throws Exception {
this . chain = new MockFilterChain ( ) {
@Override
public void doFilter ( ServletRequest request , ServletResponse response )
throws IOException , ServletException {
( ( HttpServletResponse ) response ) . sendError ( 400 , "BAD" ) ;
assertThat ( request . getAttribute ( "FILTER.FILTERED" ) ) . isNotNull ( ) ;
super . doFilter ( request , response ) ;
}
} ;
this . chain = new TestFilterChain ( ( request , response , chain ) - > {
response . sendError ( 400 , "BAD" ) ;
assertThat ( request . getAttribute ( "FILTER.FILTERED" ) ) . isNotNull ( ) ;
} ) ;
this . filter . init ( new MockFilterConfig ( "FILTER" ) ) ;
this . filter . doFilter ( this . request , this . response , this . chain ) ;
}
@ -182,14 +158,8 @@ public class ErrorPageFilterTests {
@@ -182,14 +158,8 @@ public class ErrorPageFilterTests {
@Test
public void globalError ( ) throws Exception {
this . filter . addErrorPages ( new ErrorPage ( "/error" ) ) ;
this . chain = new MockFilterChain ( ) {
@Override
public void doFilter ( ServletRequest request , ServletResponse response )
throws IOException , ServletException {
( ( HttpServletResponse ) response ) . sendError ( 400 , "BAD" ) ;
super . doFilter ( request , response ) ;
}
} ;
this . chain = new TestFilterChain (
( request , response , chain ) - > response . sendError ( 400 , "BAD" ) ) ;
this . filter . doFilter ( this . request , this . response , this . chain ) ;
assertThat ( ( ( HttpServletResponseWrapper ) this . chain . getResponse ( ) ) . getStatus ( ) )
. isEqualTo ( 400 ) ;
@ -206,14 +176,8 @@ public class ErrorPageFilterTests {
@@ -206,14 +176,8 @@ public class ErrorPageFilterTests {
@Test
public void statusError ( ) throws Exception {
this . filter . addErrorPages ( new ErrorPage ( HttpStatus . BAD_REQUEST , "/400" ) ) ;
this . chain = new MockFilterChain ( ) {
@Override
public void doFilter ( ServletRequest request , ServletResponse response )
throws IOException , ServletException {
( ( HttpServletResponse ) response ) . sendError ( 400 , "BAD" ) ;
super . doFilter ( request , response ) ;
}
} ;
this . chain = new TestFilterChain (
( request , response , chain ) - > response . sendError ( 400 , "BAD" ) ) ;
this . filter . doFilter ( this . request , this . response , this . chain ) ;
assertThat ( ( ( HttpServletResponseWrapper ) this . chain . getResponse ( ) ) . getStatus ( ) )
. isEqualTo ( 400 ) ;
@ -230,15 +194,10 @@ public class ErrorPageFilterTests {
@@ -230,15 +194,10 @@ public class ErrorPageFilterTests {
@Test
public void statusErrorWithCommittedResponse ( ) throws Exception {
this . filter . addErrorPages ( new ErrorPage ( HttpStatus . BAD_REQUEST , "/400" ) ) ;
this . chain = new MockFilterChain ( ) {
@Override
public void doFilter ( ServletRequest request , ServletResponse response )
throws IOException , ServletException {
( ( HttpServletResponse ) response ) . sendError ( 400 , "BAD" ) ;
response . flushBuffer ( ) ;
super . doFilter ( request , response ) ;
}
} ;
this . chain = new TestFilterChain ( ( request , response , chain ) - > {
response . sendError ( 400 , "BAD" ) ;
response . flushBuffer ( ) ;
} ) ;
this . filter . doFilter ( this . request , this . response , this . chain ) ;
assertThat ( ( ( HttpServletResponseWrapper ) this . chain . getResponse ( ) ) . getStatus ( ) )
. isEqualTo ( 400 ) ;
@ -249,14 +208,10 @@ public class ErrorPageFilterTests {
@@ -249,14 +208,10 @@ public class ErrorPageFilterTests {
@Test
public void exceptionError ( ) throws Exception {
this . filter . addErrorPages ( new ErrorPage ( RuntimeException . class , "/500" ) ) ;
this . chain = new MockFilterChain ( ) {
@Override
public void doFilter ( ServletRequest request , ServletResponse response )
throws IOException , ServletException {
super . doFilter ( request , response ) ;
throw new RuntimeException ( "BAD" ) ;
}
} ;
this . chain = new TestFilterChain ( ( request , response , chain ) - > {
chain . call ( ) ;
throw new RuntimeException ( "BAD" ) ;
} ) ;
this . filter . doFilter ( this . request , this . response , this . chain ) ;
assertThat ( ( ( HttpServletResponseWrapper ) this . chain . getResponse ( ) ) . getStatus ( ) )
. isEqualTo ( 500 ) ;
@ -281,29 +236,20 @@ public class ErrorPageFilterTests {
@@ -281,29 +236,20 @@ public class ErrorPageFilterTests {
@Test
public void exceptionErrorWithCommittedResponse ( ) throws Exception {
this . filter . addErrorPages ( new ErrorPage ( RuntimeException . class , "/500" ) ) ;
this . chain = new MockFilterChain ( ) {
@Override
public void doFilter ( ServletRequest request , ServletResponse response )
throws IOException , ServletException {
super . doFilter ( request , response ) ;
response . flushBuffer ( ) ;
throw new RuntimeException ( "BAD" ) ;
}
} ;
this . chain = new TestFilterChain ( ( request , response , chain ) - > {
chain . call ( ) ;
response . flushBuffer ( ) ;
throw new RuntimeException ( "BAD" ) ;
} ) ;
this . filter . doFilter ( this . request , this . response , this . chain ) ;
assertThat ( this . response . getForwardedUrl ( ) ) . isNull ( ) ;
}
@Test
public void statusCode ( ) throws Exception {
this . chain = new MockFilterChain ( ) {
@Override
public void doFilter ( ServletRequest request , ServletResponse response )
throws IOException , ServletException {
assertThat ( ( ( HttpServletResponse ) response ) . getStatus ( ) ) . isEqualTo ( 200 ) ;
super . doFilter ( request , response ) ;
}
} ;
this . chain = new TestFilterChain ( ( request , response , chain ) - > {
assertThat ( response . getStatus ( ) ) . isEqualTo ( 200 ) ;
} ) ;
this . filter . doFilter ( this . request , this . response , this . chain ) ;
assertThat ( ( ( HttpServletResponseWrapper ) this . chain . getResponse ( ) ) . getStatus ( ) )
. isEqualTo ( 200 ) ;
@ -312,14 +258,10 @@ public class ErrorPageFilterTests {
@@ -312,14 +258,10 @@ public class ErrorPageFilterTests {
@Test
public void subClassExceptionError ( ) throws Exception {
this . filter . addErrorPages ( new ErrorPage ( RuntimeException . class , "/500" ) ) ;
this . chain = new MockFilterChain ( ) {
@Override
public void doFilter ( ServletRequest request , ServletResponse response )
throws IOException , ServletException {
super . doFilter ( request , response ) ;
throw new IllegalStateException ( "BAD" ) ;
}
} ;
this . chain = new TestFilterChain ( ( request , response , chain ) - > {
chain . call ( ) ;
throw new IllegalStateException ( "BAD" ) ;
} ) ;
this . filter . doFilter ( this . request , this . response , this . chain ) ;
assertThat ( ( ( HttpServletResponseWrapper ) this . chain . getResponse ( ) ) . getStatus ( ) )
. isEqualTo ( 500 ) ;
@ -355,14 +297,10 @@ public class ErrorPageFilterTests {
@@ -355,14 +297,10 @@ public class ErrorPageFilterTests {
throws Exception {
this . filter . addErrorPages ( new ErrorPage ( "/error" ) ) ;
this . request . setAsyncStarted ( true ) ;
this . chain = new MockFilterChain ( ) {
@Override
public void doFilter ( ServletRequest request , ServletResponse response )
throws IOException , ServletException {
super . doFilter ( request , response ) ;
throw new RuntimeException ( "BAD" ) ;
}
} ;
this . chain = new TestFilterChain ( ( request , response , chain ) - > {
chain . call ( ) ;
throw new RuntimeException ( "BAD" ) ;
} ) ;
this . filter . doFilter ( this . request , this . response , this . chain ) ;
assertThat ( this . chain . getRequest ( ) ) . isEqualTo ( this . request ) ;
assertThat ( ( ( HttpServletResponseWrapper ) this . chain . getResponse ( ) ) . getResponse ( ) )
@ -375,14 +313,10 @@ public class ErrorPageFilterTests {
@@ -375,14 +313,10 @@ public class ErrorPageFilterTests {
throws Exception {
this . filter . addErrorPages ( new ErrorPage ( "/error" ) ) ;
this . request . setAsyncStarted ( true ) ;
this . chain = new MockFilterChain ( ) {
@Override
public void doFilter ( ServletRequest request , ServletResponse response )
throws IOException , ServletException {
super . doFilter ( request , response ) ;
( ( HttpServletResponse ) response ) . sendError ( 400 , "BAD" ) ;
}
} ;
this . chain = new TestFilterChain ( ( request , response , chain ) - > {
chain . call ( ) ;
response . sendError ( 400 , "BAD" ) ;
} ) ;
this . filter . doFilter ( this . request , this . response , this . chain ) ;
assertThat ( this . chain . getRequest ( ) ) . isEqualTo ( this . request ) ;
assertThat ( ( ( HttpServletResponseWrapper ) this . chain . getResponse ( ) ) . getResponse ( ) )
@ -405,14 +339,10 @@ public class ErrorPageFilterTests {
@@ -405,14 +339,10 @@ public class ErrorPageFilterTests {
throws Exception {
this . filter . addErrorPages ( new ErrorPage ( "/error" ) ) ;
setUpAsyncDispatch ( ) ;
this . chain = new MockFilterChain ( ) {
@Override
public void doFilter ( ServletRequest request , ServletResponse response )
throws IOException , ServletException {
super . doFilter ( request , response ) ;
throw new RuntimeException ( "BAD" ) ;
}
} ;
this . chain = new TestFilterChain ( ( request , response , chain ) - > {
chain . call ( ) ;
throw new RuntimeException ( "BAD" ) ;
} ) ;
this . filter . doFilter ( this . request , this . response , this . chain ) ;
assertThat ( this . chain . getRequest ( ) ) . isEqualTo ( this . request ) ;
assertThat ( ( ( HttpServletResponseWrapper ) this . chain . getResponse ( ) ) . getResponse ( ) )
@ -425,14 +355,10 @@ public class ErrorPageFilterTests {
@@ -425,14 +355,10 @@ public class ErrorPageFilterTests {
throws Exception {
this . filter . addErrorPages ( new ErrorPage ( "/error" ) ) ;
setUpAsyncDispatch ( ) ;
this . chain = new MockFilterChain ( ) {
@Override
public void doFilter ( ServletRequest request , ServletResponse response )
throws IOException , ServletException {
super . doFilter ( request , response ) ;
( ( HttpServletResponse ) response ) . sendError ( 400 , "BAD" ) ;
}
} ;
this . chain = new TestFilterChain ( ( request , response , chain ) - > {
chain . call ( ) ;
response . sendError ( 400 , "BAD" ) ;
} ) ;
this . filter . doFilter ( this . request , this . response , this . chain ) ;
assertThat ( this . chain . getRequest ( ) ) . isEqualTo ( this . request ) ;
assertThat ( ( ( HttpServletResponseWrapper ) this . chain . getResponse ( ) ) . getResponse ( ) )
@ -455,16 +381,10 @@ public class ErrorPageFilterTests {
@@ -455,16 +381,10 @@ public class ErrorPageFilterTests {
throws IOException , ServletException {
this . request . setServletPath ( "/test" ) ;
this . filter . addErrorPages ( new ErrorPage ( "/error" ) ) ;
this . chain = new MockFilterChain ( ) {
@Override
public void doFilter ( ServletRequest request , ServletResponse response )
throws IOException , ServletException {
super . doFilter ( request , response ) ;
throw new RuntimeException ( ) ;
}
} ;
this . chain = new TestFilterChain ( ( request , response , chain ) - > {
chain . call ( ) ;
throw new RuntimeException ( ) ;
} ) ;
this . filter . doFilter ( this . request , this . response , this . chain ) ;
assertThat ( this . output . toString ( ) ) . contains ( "request [/test]" ) ;
}
@ -475,16 +395,10 @@ public class ErrorPageFilterTests {
@@ -475,16 +395,10 @@ public class ErrorPageFilterTests {
this . request . setServletPath ( "/test" ) ;
this . request . setPathInfo ( "/alpha" ) ;
this . filter . addErrorPages ( new ErrorPage ( "/error" ) ) ;
this . chain = new MockFilterChain ( ) {
@Override
public void doFilter ( ServletRequest request , ServletResponse response )
throws IOException , ServletException {
super . doFilter ( request , response ) ;
throw new RuntimeException ( ) ;
}
} ;
this . chain = new TestFilterChain ( ( request , response , chain ) - > {
chain . call ( ) ;
throw new RuntimeException ( ) ;
} ) ;
this . filter . doFilter ( this . request , this . response , this . chain ) ;
assertThat ( this . output . toString ( ) ) . contains ( "request [/test/alpha]" ) ;
}
@ -492,14 +406,10 @@ public class ErrorPageFilterTests {
@@ -492,14 +406,10 @@ public class ErrorPageFilterTests {
@Test
public void nestedServletExceptionIsUnwrapped ( ) throws Exception {
this . filter . addErrorPages ( new ErrorPage ( RuntimeException . class , "/500" ) ) ;
this . chain = new MockFilterChain ( ) {
@Override
public void doFilter ( ServletRequest request , ServletResponse response )
throws IOException , ServletException {
super . doFilter ( request , response ) ;
throw new NestedServletException ( "Wrapper" , new RuntimeException ( "BAD" ) ) ;
}
} ;
this . chain = new TestFilterChain ( ( request , response , chain ) - > {
chain . call ( ) ;
throw new NestedServletException ( "Wrapper" , new RuntimeException ( "BAD" ) ) ;
} ) ;
this . filter . doFilter ( this . request , this . response , this . chain ) ;
assertThat ( ( ( HttpServletResponseWrapper ) this . chain . getResponse ( ) ) . getStatus ( ) )
. isEqualTo ( 500 ) ;
@ -524,15 +434,10 @@ public class ErrorPageFilterTests {
@@ -524,15 +434,10 @@ public class ErrorPageFilterTests {
@Test
public void whenErrorIsSentAndWriterIsFlushedErrorIsSentToTheClient ( )
throws Exception {
this . chain = new MockFilterChain ( ) {
@Override
public void doFilter ( ServletRequest request , ServletResponse response )
throws IOException , ServletException {
( ( HttpServletResponse ) response ) . sendError ( 400 ) ;
response . getWriter ( ) . flush ( ) ;
super . doFilter ( request , response ) ;
}
} ;
this . chain = new TestFilterChain ( ( request , response , chain ) - > {
response . sendError ( 400 ) ;
response . getWriter ( ) . flush ( ) ;
} ) ;
this . filter . doFilter ( this . request , this . response , this . chain ) ;
assertThat ( this . response . getStatus ( ) ) . isEqualTo ( 400 ) ;
}
@ -551,6 +456,45 @@ public class ErrorPageFilterTests {
@@ -551,6 +456,45 @@ public class ErrorPageFilterTests {
return this . request . getDispatcher ( path ) . getRequestAttributes ( ) ;
}
private static class TestFilterChain extends MockFilterChain {
private final FilterHandler handler ;
TestFilterChain ( FilterHandler handler ) {
this . handler = handler ;
}
@Override
public void doFilter ( ServletRequest request , ServletResponse response )
throws IOException , ServletException {
AtomicBoolean called = new AtomicBoolean ( ) ;
Chain chain = ( ) - > {
if ( called . compareAndSet ( false , true ) ) {
super . doFilter ( request , response ) ;
}
} ;
this . handler . handle ( ( HttpServletRequest ) request ,
( HttpServletResponse ) response , chain ) ;
chain . call ( ) ;
}
}
@FunctionalInterface
private interface FilterHandler {
void handle ( HttpServletRequest request , HttpServletResponse response , Chain chain )
throws IOException , ServletException ;
}
@FunctionalInterface
private interface Chain {
void call ( ) throws IOException , ServletException ;
}
private static final class DispatchRecordingMockHttpServletRequest
extends MockHttpServletRequest {