@ -50,6 +50,7 @@ import static org.assertj.core.api.Assertions.assertThat;
@@ -50,6 +50,7 @@ import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException ;
import static org.mockito.ArgumentMatchers.any ;
import static org.mockito.BDDMockito.given ;
import static org.mockito.BDDMockito.willAnswer ;
import static org.mockito.BDDMockito.willThrow ;
import static org.mockito.Mockito.mock ;
import static org.mockito.Mockito.verify ;
@ -325,4 +326,22 @@ public class OAuth2AuthorizationRequestRedirectFilterTests {
@@ -325,4 +326,22 @@ public class OAuth2AuthorizationRequestRedirectFilterTests {
+ "redirect_uri=http://localhost/login/oauth2/code/registration-id" ) ;
}
// gh-11602
@Test
public void doFilterWhenNotAuthorizationRequestAndClientAuthorizationRequiredExceptionThrownThenSaveRequestBeforeCommitted ( )
throws Exception {
String requestUri = "/path" ;
MockHttpServletRequest request = new MockHttpServletRequest ( "GET" , requestUri ) ;
request . setServletPath ( requestUri ) ;
MockHttpServletResponse response = new MockHttpServletResponse ( ) ;
FilterChain filterChain = mock ( FilterChain . class ) ;
willAnswer ( ( invocation ) - > assertThat ( ( invocation . < HttpServletResponse > getArgument ( 1 ) ) . isCommitted ( ) ) . isFalse ( ) )
. given ( this . requestCache ) . saveRequest ( any ( HttpServletRequest . class ) , any ( HttpServletResponse . class ) ) ;
willThrow ( new ClientAuthorizationRequiredException ( this . registration1 . getRegistrationId ( ) ) ) . given ( filterChain )
. doFilter ( any ( ServletRequest . class ) , any ( ServletResponse . class ) ) ;
this . filter . doFilter ( request , response , filterChain ) ;
assertThat ( response . isCommitted ( ) ) . isTrue ( ) ;
}
}