@ -15,6 +15,23 @@
@@ -15,6 +15,23 @@
* /
package org.springframework.security.web.access ;
import static org.assertj.core.api.Assertions.assertThat ;
import static org.assertj.core.api.Assertions.assertThatThrownBy ;
import static org.assertj.core.api.Assertions.fail ;
import static org.mockito.Matchers.any ;
import static org.mockito.Mockito.doThrow ;
import static org.mockito.Mockito.mock ;
import static org.mockito.Mockito.verifyZeroInteractions ;
import java.io.IOException ;
import java.util.Locale ;
import javax.servlet.FilterChain ;
import javax.servlet.ServletException ;
import javax.servlet.http.HttpServletRequest ;
import javax.servlet.http.HttpServletResponse ;
import javax.servlet.http.HttpSession ;
import org.junit.After ;
import org.junit.Before ;
import org.junit.Test ;
@ -36,20 +53,6 @@ import org.springframework.security.web.WebAttributes;
@@ -36,20 +53,6 @@ import org.springframework.security.web.WebAttributes;
import org.springframework.security.web.savedrequest.HttpSessionRequestCache ;
import org.springframework.security.web.savedrequest.SavedRequest ;
import javax.servlet.FilterChain ;
import javax.servlet.ServletException ;
import javax.servlet.http.HttpServletRequest ;
import javax.servlet.http.HttpServletResponse ;
import javax.servlet.http.HttpSession ;
import java.io.IOException ;
import java.util.Locale ;
import static org.assertj.core.api.Assertions.assertThat ;
import static org.assertj.core.api.Assertions.fail ;
import static org.mockito.Matchers.any ;
import static org.mockito.Mockito.doThrow ;
import static org.mockito.Mockito.mock ;
/ * *
* Tests { @link ExceptionTranslationFilter } .
*
@ -302,7 +305,26 @@ public class ExceptionTranslationFilterTests {
@@ -302,7 +305,26 @@ public class ExceptionTranslationFilterTests {
}
}
private final AuthenticationEntryPoint mockEntryPoint = new AuthenticationEntryPoint ( ) {
@Test
public void doFilterWhenResponseCommittedThenRethrowsException ( ) throws Exception {
this . mockEntryPoint = mock ( AuthenticationEntryPoint . class ) ;
FilterChain chain = ( request , response ) - > {
HttpServletResponse httpResponse = ( HttpServletResponse ) response ;
httpResponse . sendError ( HttpServletResponse . SC_BAD_REQUEST ) ;
throw new AccessDeniedException ( "Denied" ) ;
} ;
MockHttpServletRequest request = new MockHttpServletRequest ( ) ;
MockHttpServletResponse response = new MockHttpServletResponse ( ) ;
ExceptionTranslationFilter filter = new ExceptionTranslationFilter ( mockEntryPoint ) ;
assertThatThrownBy ( ( ) - > filter . doFilter ( request , response , chain ) )
. isInstanceOf ( ServletException . class )
. hasCauseInstanceOf ( AccessDeniedException . class ) ;
verifyZeroInteractions ( mockEntryPoint ) ;
}
private AuthenticationEntryPoint mockEntryPoint = new AuthenticationEntryPoint ( ) {
public void commence ( HttpServletRequest request , HttpServletResponse response ,
AuthenticationException authException ) throws IOException ,
ServletException {