@ -13,16 +13,20 @@
@@ -13,16 +13,20 @@
* See the License for the specific language governing permissions and
* limitations under the License .
* /
package org.springframework.security.oauth2.client.authentication ;
package org.springframework.security.oauth2.client.web ;
import org.assertj.core.api.Assertions ;
import org.junit.Test ;
import org.mockito.ArgumentCaptor ;
import org.mockito.Matchers ;
import org.mockito.Mockito ;
import org.springframework.mock.web.MockHttpServletRequest ;
import org.springframework.mock.web.MockHttpServletResponse ;
import org.springframework.security.authentication.AuthenticationManager ;
import org.springframework.security.authentication.TestingAuthenticationToken ;
import org.springframework.security.core.Authentication ;
import org.springframework.security.core.AuthenticationException ;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationException ;
import org.springframework.security.oauth2.client.registration.ClientRegistration ;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository ;
import org.springframework.security.oauth2.core.OAuth2Error ;
@ -38,7 +42,6 @@ import javax.servlet.http.HttpServletResponse;
@@ -38,7 +42,6 @@ import javax.servlet.http.HttpServletResponse;
import static org.assertj.core.api.Assertions.assertThat ;
import static org.mockito.Matchers.any ;
import static org.mockito.Mockito.* ;
import static org.springframework.security.oauth2.client.authentication.TestUtil.* ;
/ * *
* Tests { @link AuthorizationCodeAuthenticationProcessingFilter } .
@ -49,28 +52,28 @@ public class AuthorizationCodeAuthenticationProcessingFilterTests {
@@ -49,28 +52,28 @@ public class AuthorizationCodeAuthenticationProcessingFilterTests {
@Test
public void doFilterWhenNotAuthorizationCodeResponseThenContinueChain ( ) throws Exception {
ClientRegistration clientRegistration = googleClientRegistration ( ) ;
ClientRegistration clientRegistration = TestUtil . googleClientRegistration ( ) ;
AuthorizationCodeAuthenticationProcessingFilter filter = spy ( setupFilter ( clientRegistration ) ) ;
AuthorizationCodeAuthenticationProcessingFilter filter = Mockito . spy ( setupFilter ( clientRegistration ) ) ;
String requestURI = "/path" ;
MockHttpServletRequest request = new MockHttpServletRequest ( "GET" , requestURI ) ;
request . setServletPath ( requestURI ) ;
MockHttpServletResponse response = new MockHttpServletResponse ( ) ;
FilterChain filterChain = mock ( FilterChain . class ) ;
FilterChain filterChain = Mockito . mock ( FilterChain . class ) ;
filter . doFilter ( request , response , filterChain ) ;
verify ( filterChain ) . doFilter ( any ( HttpServletRequest . class ) , any ( HttpServletResponse . class ) ) ;
verify ( filter , never ( ) ) . attemptAuthentication ( any ( HttpServletRequest . class ) , any ( HttpServletResponse . class ) ) ;
Mockito . verify ( filterChain ) . doFilter ( Matchers . any ( HttpServletRequest . class ) , Matchers . any ( HttpServletResponse . class ) ) ;
Mockito . verify ( filter , Mockito . never ( ) ) . attemptAuthentication ( Matchers . any ( HttpServletRequest . class ) , Matchers . any ( HttpServletResponse . class ) ) ;
}
@Test
public void doFilterWhenAuthorizationCodeErrorResponseThenAuthenticationFailureHandlerIsCalled ( ) throws Exception {
ClientRegistration clientRegistration = githubClientRegistration ( ) ;
ClientRegistration clientRegistration = TestUtil . githubClientRegistration ( ) ;
AuthorizationCodeAuthenticationProcessingFilter filter = spy ( setupFilter ( clientRegistration ) ) ;
AuthenticationFailureHandler failureHandler = mock ( AuthenticationFailureHandler . class ) ;
AuthorizationCodeAuthenticationProcessingFilter filter = Mockito . spy ( setupFilter ( clientRegistration ) ) ;
AuthenticationFailureHandler failureHandler = Mockito . mock ( AuthenticationFailureHandler . class ) ;
filter . setAuthenticationFailureHandler ( failureHandler ) ;
MockHttpServletRequest request = this . setupRequest ( clientRegistration ) ;
@ -78,25 +81,25 @@ public class AuthorizationCodeAuthenticationProcessingFilterTests {
@@ -78,25 +81,25 @@ public class AuthorizationCodeAuthenticationProcessingFilterTests {
request . addParameter ( OAuth2Parameter . ERROR , errorCode ) ;
request . addParameter ( OAuth2Parameter . STATE , "some state" ) ;
MockHttpServletResponse response = new MockHttpServletResponse ( ) ;
FilterChain filterChain = mock ( FilterChain . class ) ;
FilterChain filterChain = Mockito . mock ( FilterChain . class ) ;
filter . doFilter ( request , response , filterChain ) ;
verify ( filter ) . attemptAuthentication ( any ( HttpServletRequest . class ) , any ( HttpServletResponse . class ) ) ;
verify ( failureHandler ) . onAuthenticationFailure ( any ( HttpServletRequest . class ) , any ( HttpServletResponse . class ) ,
any ( AuthenticationException . class ) ) ;
Mockito . verify ( filter ) . attemptAuthentication ( Matchers . any ( HttpServletRequest . class ) , Matchers . any ( HttpServletResponse . class ) ) ;
Mockito . verify ( failureHandler ) . onAuthenticationFailure ( Matchers . any ( HttpServletRequest . class ) , Matchers . any ( HttpServletResponse . class ) ,
Matchers . any ( AuthenticationException . class ) ) ;
}
@Test
public void doFilterWhenAuthorizationCodeSuccessResponseThenAuthenticationSuccessHandlerIsCalled ( ) throws Exception {
TestingAuthenticationToken authentication = new TestingAuthenticationToken ( "joe" , "password" , "user" , "admin" ) ;
AuthenticationManager authenticationManager = mock ( AuthenticationManager . class ) ;
when ( authenticationManager . authenticate ( any ( Authentication . class ) ) ) . thenReturn ( authentication ) ;
AuthenticationManager authenticationManager = Mockito . mock ( AuthenticationManager . class ) ;
Mockito . when ( authenticationManager . authenticate ( Matchers . any ( Authentication . class ) ) ) . thenReturn ( authentication ) ;
ClientRegistration clientRegistration = githubClientRegistration ( ) ;
ClientRegistration clientRegistration = TestUtil . githubClientRegistration ( ) ;
AuthorizationCodeAuthenticationProcessingFilter filter = spy ( setupFilter ( authenticationManager , clientRegistration ) ) ;
AuthenticationSuccessHandler successHandler = mock ( AuthenticationSuccessHandler . class ) ;
AuthorizationCodeAuthenticationProcessingFilter filter = Mockito . spy ( setupFilter ( authenticationManager , clientRegistration ) ) ;
AuthenticationSuccessHandler successHandler = Mockito . mock ( AuthenticationSuccessHandler . class ) ;
filter . setAuthenticationSuccessHandler ( successHandler ) ;
AuthorizationRequestRepository authorizationRequestRepository = new HttpSessionAuthorizationRequestRepository ( ) ;
filter . setAuthorizationRequestRepository ( authorizationRequestRepository ) ;
@ -108,24 +111,24 @@ public class AuthorizationCodeAuthenticationProcessingFilterTests {
@@ -108,24 +111,24 @@ public class AuthorizationCodeAuthenticationProcessingFilterTests {
request . addParameter ( OAuth2Parameter . STATE , state ) ;
MockHttpServletResponse response = new MockHttpServletResponse ( ) ;
setupAuthorizationRequest ( authorizationRequestRepository , request , response , clientRegistration , state ) ;
FilterChain filterChain = mock ( FilterChain . class ) ;
FilterChain filterChain = Mockito . mock ( FilterChain . class ) ;
filter . doFilter ( request , response , filterChain ) ;
verify ( filter ) . attemptAuthentication ( any ( HttpServletRequest . class ) , any ( HttpServletResponse . class ) ) ;
Mockito . verify ( filter ) . attemptAuthentication ( Matchers . any ( HttpServletRequest . class ) , Matchers . any ( HttpServletResponse . class ) ) ;
ArgumentCaptor < Authentication > authenticationArgCaptor = ArgumentCaptor . forClass ( Authentication . class ) ;
verify ( successHandler ) . onAuthenticationSuccess ( any ( HttpServletRequest . class ) , any ( HttpServletResponse . class ) ,
Mockito . verify ( successHandler ) . onAuthenticationSuccess ( Matchers . any ( HttpServletRequest . class ) , Matchers . any ( HttpServletResponse . class ) ,
authenticationArgCaptor . capture ( ) ) ;
assertThat ( authenticationArgCaptor . getValue ( ) ) . isEqualTo ( authentication ) ;
Assertions . assertThat ( authenticationArgCaptor . getValue ( ) ) . isEqualTo ( authentication ) ;
}
@Test
public void doFilterWhenAuthorizationCodeSuccessResponseAndNoMatchingAuthorizationRequestThenThrowOAuth2AuthenticationExceptionAuthorizationRequestNotFound ( ) throws Exception {
ClientRegistration clientRegistration = githubClientRegistration ( ) ;
ClientRegistration clientRegistration = TestUtil . githubClientRegistration ( ) ;
AuthorizationCodeAuthenticationProcessingFilter filter = spy ( setupFilter ( clientRegistration ) ) ;
AuthenticationFailureHandler failureHandler = mock ( AuthenticationFailureHandler . class ) ;
AuthorizationCodeAuthenticationProcessingFilter filter = Mockito . spy ( setupFilter ( clientRegistration ) ) ;
AuthenticationFailureHandler failureHandler = Mockito . mock ( AuthenticationFailureHandler . class ) ;
filter . setAuthenticationFailureHandler ( failureHandler ) ;
MockHttpServletRequest request = this . setupRequest ( clientRegistration ) ;
@ -134,7 +137,7 @@ public class AuthorizationCodeAuthenticationProcessingFilterTests {
@@ -134,7 +137,7 @@ public class AuthorizationCodeAuthenticationProcessingFilterTests {
request . addParameter ( OAuth2Parameter . CODE , authCode ) ;
request . addParameter ( OAuth2Parameter . STATE , state ) ;
MockHttpServletResponse response = new MockHttpServletResponse ( ) ;
FilterChain filterChain = mock ( FilterChain . class ) ;
FilterChain filterChain = Mockito . mock ( FilterChain . class ) ;
filter . doFilter ( request , response , filterChain ) ;
@ -143,10 +146,10 @@ public class AuthorizationCodeAuthenticationProcessingFilterTests {
@@ -143,10 +146,10 @@ public class AuthorizationCodeAuthenticationProcessingFilterTests {
@Test
public void doFilterWhenAuthorizationCodeSuccessResponseWithInvalidStateParamThenThrowOAuth2AuthenticationExceptionInvalidStateParameter ( ) throws Exception {
ClientRegistration clientRegistration = githubClientRegistration ( ) ;
ClientRegistration clientRegistration = TestUtil . githubClientRegistration ( ) ;
AuthorizationCodeAuthenticationProcessingFilter filter = spy ( setupFilter ( clientRegistration ) ) ;
AuthenticationFailureHandler failureHandler = mock ( AuthenticationFailureHandler . class ) ;
AuthorizationCodeAuthenticationProcessingFilter filter = Mockito . spy ( setupFilter ( clientRegistration ) ) ;
AuthenticationFailureHandler failureHandler = Mockito . mock ( AuthenticationFailureHandler . class ) ;
filter . setAuthenticationFailureHandler ( failureHandler ) ;
AuthorizationRequestRepository authorizationRequestRepository = new HttpSessionAuthorizationRequestRepository ( ) ;
filter . setAuthorizationRequestRepository ( authorizationRequestRepository ) ;
@ -158,7 +161,7 @@ public class AuthorizationCodeAuthenticationProcessingFilterTests {
@@ -158,7 +161,7 @@ public class AuthorizationCodeAuthenticationProcessingFilterTests {
request . addParameter ( OAuth2Parameter . STATE , state ) ;
MockHttpServletResponse response = new MockHttpServletResponse ( ) ;
setupAuthorizationRequest ( authorizationRequestRepository , request , response , clientRegistration , "some state" ) ;
FilterChain filterChain = mock ( FilterChain . class ) ;
FilterChain filterChain = Mockito . mock ( FilterChain . class ) ;
filter . doFilter ( request , response , filterChain ) ;
@ -167,10 +170,10 @@ public class AuthorizationCodeAuthenticationProcessingFilterTests {
@@ -167,10 +170,10 @@ public class AuthorizationCodeAuthenticationProcessingFilterTests {
@Test
public void doFilterWhenAuthorizationCodeSuccessResponseWithInvalidRedirectUriParamThenThrowOAuth2AuthenticationExceptionInvalidRedirectUriParameter ( ) throws Exception {
ClientRegistration clientRegistration = githubClientRegistration ( ) ;
ClientRegistration clientRegistration = TestUtil . githubClientRegistration ( ) ;
AuthorizationCodeAuthenticationProcessingFilter filter = spy ( setupFilter ( clientRegistration ) ) ;
AuthenticationFailureHandler failureHandler = mock ( AuthenticationFailureHandler . class ) ;
AuthorizationCodeAuthenticationProcessingFilter filter = Mockito . spy ( setupFilter ( clientRegistration ) ) ;
AuthenticationFailureHandler failureHandler = Mockito . mock ( AuthenticationFailureHandler . class ) ;
filter . setAuthenticationFailureHandler ( failureHandler ) ;
AuthorizationRequestRepository authorizationRequestRepository = new HttpSessionAuthorizationRequestRepository ( ) ;
filter . setAuthorizationRequestRepository ( authorizationRequestRepository ) ;
@ -183,7 +186,7 @@ public class AuthorizationCodeAuthenticationProcessingFilterTests {
@@ -183,7 +186,7 @@ public class AuthorizationCodeAuthenticationProcessingFilterTests {
request . addParameter ( OAuth2Parameter . STATE , state ) ;
MockHttpServletResponse response = new MockHttpServletResponse ( ) ;
setupAuthorizationRequest ( authorizationRequestRepository , request , response , clientRegistration , state ) ;
FilterChain filterChain = mock ( FilterChain . class ) ;
FilterChain filterChain = Mockito . mock ( FilterChain . class ) ;
filter . doFilter ( request , response , filterChain ) ;
@ -194,21 +197,21 @@ public class AuthorizationCodeAuthenticationProcessingFilterTests {
@@ -194,21 +197,21 @@ public class AuthorizationCodeAuthenticationProcessingFilterTests {
AuthenticationFailureHandler failureHandler ,
String errorCode ) throws Exception {
verify ( filter ) . attemptAuthentication ( any ( HttpServletRequest . class ) , any ( HttpServletResponse . class ) ) ;
Mockito . verify ( filter ) . attemptAuthentication ( Matchers . any ( HttpServletRequest . class ) , Matchers . any ( HttpServletResponse . class ) ) ;
ArgumentCaptor < AuthenticationException > authenticationExceptionArgCaptor =
ArgumentCaptor . forClass ( AuthenticationException . class ) ;
verify ( failureHandler ) . onAuthenticationFailure ( any ( HttpServletRequest . class ) , any ( HttpServletResponse . class ) ,
Mockito . verify ( failureHandler ) . onAuthenticationFailure ( Matchers . any ( HttpServletRequest . class ) , Matchers . any ( HttpServletResponse . class ) ,
authenticationExceptionArgCaptor . capture ( ) ) ;
assertThat ( authenticationExceptionArgCaptor . getValue ( ) ) . isInstanceOf ( OAuth2AuthenticationException . class ) ;
Assertions . assertThat ( authenticationExceptionArgCaptor . getValue ( ) ) . isInstanceOf ( OAuth2AuthenticationException . class ) ;
OAuth2AuthenticationException oauth2AuthenticationException =
( OAuth2AuthenticationException ) authenticationExceptionArgCaptor . getValue ( ) ;
assertThat ( oauth2AuthenticationException . getErrorObject ( ) ) . isNotNull ( ) ;
assertThat ( oauth2AuthenticationException . getErrorObject ( ) . getErrorCode ( ) ) . isEqualTo ( errorCode ) ;
Assertions . assertThat ( oauth2AuthenticationException . getErrorObject ( ) ) . isNotNull ( ) ;
Assertions . assertThat ( oauth2AuthenticationException . getErrorObject ( ) . getErrorCode ( ) ) . isEqualTo ( errorCode ) ;
}
private AuthorizationCodeAuthenticationProcessingFilter setupFilter ( ClientRegistration . . . clientRegistrations ) throws Exception {
AuthenticationManager authenticationManager = mock ( AuthenticationManager . class ) ;
AuthenticationManager authenticationManager = Mockito . mock ( AuthenticationManager . class ) ;
return setupFilter ( authenticationManager , clientRegistrations ) ;
}
@ -216,7 +219,7 @@ public class AuthorizationCodeAuthenticationProcessingFilterTests {
@@ -216,7 +219,7 @@ public class AuthorizationCodeAuthenticationProcessingFilterTests {
private AuthorizationCodeAuthenticationProcessingFilter setupFilter (
AuthenticationManager authenticationManager , ClientRegistration . . . clientRegistrations ) throws Exception {
ClientRegistrationRepository clientRegistrationRepository = clientRegistrationRepository ( clientRegistrations ) ;
ClientRegistrationRepository clientRegistrationRepository = TestUtil . clientRegistrationRepository ( clientRegistrations ) ;
AuthorizationCodeAuthenticationProcessingFilter filter = new AuthorizationCodeAuthenticationProcessingFilter ( ) ;
filter . setClientRegistrationRepository ( clientRegistrationRepository ) ;
@ -244,11 +247,11 @@ public class AuthorizationCodeAuthenticationProcessingFilterTests {
@@ -244,11 +247,11 @@ public class AuthorizationCodeAuthenticationProcessingFilterTests {
}
private MockHttpServletRequest setupRequest ( ClientRegistration clientRegistration ) {
String requestURI = AUTHORIZE_BASE_URI + "/" + clientRegistration . getClientAlias ( ) ;
String requestURI = TestUtil . AUTHORIZE_BASE_URI + "/" + clientRegistration . getClientAlias ( ) ;
MockHttpServletRequest request = new MockHttpServletRequest ( "GET" , requestURI ) ;
request . setScheme ( DEFAULT_SCHEME ) ;
request . setServerName ( DEFAULT_SERVER_NAME ) ;
request . setServerPort ( DEFAULT_SERVER_PORT ) ;
request . setScheme ( TestUtil . DEFAULT_SCHEME ) ;
request . setServerName ( TestUtil . DEFAULT_SERVER_NAME ) ;
request . setServerPort ( TestUtil . DEFAULT_SERVER_PORT ) ;
request . setServletPath ( requestURI ) ;
return request ;
}