@ -19,7 +19,9 @@ import java.security.Principal;
@@ -19,7 +19,9 @@ import java.security.Principal;
import java.time.Instant ;
import java.time.temporal.ChronoUnit ;
import java.util.Collections ;
import java.util.HashMap ;
import java.util.HashSet ;
import java.util.Map ;
import java.util.Set ;
import org.junit.Before ;
@ -36,23 +38,28 @@ import org.springframework.security.oauth2.core.OAuth2RefreshToken;
@@ -36,23 +38,28 @@ import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.core.OAuth2RefreshToken2 ;
import org.springframework.security.oauth2.core.OAuth2TokenType ;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames ;
import org.springframework.security.oauth2.core.oidc.OidcIdToken ;
import org.springframework.security.oauth2.core.oidc.OidcScopes ;
import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames ;
import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm ;
import org.springframework.security.oauth2.jwt.JoseHeaderNames ;
import org.springframework.security.oauth2.jwt.Jwt ;
import org.springframework.security.oauth2.jwt.JwtEncoder ;
import org.springframework.security.oauth2.server.authorization.JwtEncodingContext ;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization ;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService ;
import org.springframework.security.oauth2.server.authorization.OAuth2TokenCustomizer ;
import org.springframework.security.oauth2.server.authorization.TestOAuth2Authorizations ;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient ;
import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients ;
import org.springframework.security.oauth2.server.authorization.JwtEncodingContext ;
import org.springframework.security.oauth2.server.authorization.OAuth2TokenCustomizer ;
import static org.assertj.core.api.Assertions.entry ;
import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy ;
import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat ;
import static org.mockito.ArgumentMatchers.any ;
import static org.mockito.ArgumentMatchers.eq ;
import static org.mockito.Mockito.mock ;
import static org.mockito.Mockito.times ;
import static org.mockito.Mockito.verify ;
import static org.mockito.Mockito.when ;
@ -61,6 +68,7 @@ import static org.mockito.Mockito.when;
@@ -61,6 +68,7 @@ import static org.mockito.Mockito.when;
*
* @author Alexey Nesterov
* @author Joe Grandja
* @author Anoop Garlapati
* @since 0 . 0 . 3
* /
public class OAuth2RefreshTokenAuthenticationProviderTests {
@ -156,6 +164,72 @@ public class OAuth2RefreshTokenAuthenticationProviderTests {
@@ -156,6 +164,72 @@ public class OAuth2RefreshTokenAuthenticationProviderTests {
assertThat ( updatedAuthorization . getRefreshToken ( ) ) . isEqualTo ( authorization . getRefreshToken ( ) ) ;
}
@Test
public void authenticateWhenValidRefreshTokenThenReturnIdToken ( ) {
RegisteredClient registeredClient = TestRegisteredClients . registeredClient ( ) . scope ( OidcScopes . OPENID ) . build ( ) ;
OAuth2Authorization authorization = TestOAuth2Authorizations . authorization ( registeredClient ) . build ( ) ;
when ( this . authorizationService . findByToken (
eq ( authorization . getRefreshToken ( ) . getToken ( ) . getTokenValue ( ) ) ,
eq ( OAuth2TokenType . REFRESH_TOKEN ) ) )
. thenReturn ( authorization ) ;
OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken ( registeredClient ) ;
OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken (
authorization . getRefreshToken ( ) . getToken ( ) . getTokenValue ( ) , clientPrincipal , null , null ) ;
OAuth2AccessTokenAuthenticationToken accessTokenAuthentication =
( OAuth2AccessTokenAuthenticationToken ) this . authenticationProvider . authenticate ( authentication ) ;
ArgumentCaptor < JwtEncodingContext > jwtEncodingContextCaptor = ArgumentCaptor . forClass ( JwtEncodingContext . class ) ;
verify ( this . jwtCustomizer , times ( 2 ) ) . customize ( jwtEncodingContextCaptor . capture ( ) ) ;
// Access Token context
JwtEncodingContext accessTokenContext = jwtEncodingContextCaptor . getAllValues ( ) . get ( 0 ) ;
assertThat ( accessTokenContext . getRegisteredClient ( ) ) . isEqualTo ( registeredClient ) ;
assertThat ( accessTokenContext . < Authentication > getPrincipal ( ) ) . isEqualTo ( authorization . getAttribute ( Principal . class . getName ( ) ) ) ;
assertThat ( accessTokenContext . getAuthorization ( ) ) . isEqualTo ( authorization ) ;
assertThat ( accessTokenContext . getAuthorizedScopes ( ) )
. isEqualTo ( authorization . getAttribute ( OAuth2Authorization . AUTHORIZED_SCOPE_ATTRIBUTE_NAME ) ) ;
assertThat ( accessTokenContext . getTokenType ( ) ) . isEqualTo ( OAuth2TokenType . ACCESS_TOKEN ) ;
assertThat ( accessTokenContext . getAuthorizationGrantType ( ) ) . isEqualTo ( AuthorizationGrantType . REFRESH_TOKEN ) ;
assertThat ( accessTokenContext . < OAuth2AuthorizationGrantAuthenticationToken > getAuthorizationGrant ( ) ) . isEqualTo ( authentication ) ;
assertThat ( accessTokenContext . getHeaders ( ) ) . isNotNull ( ) ;
assertThat ( accessTokenContext . getClaims ( ) ) . isNotNull ( ) ;
Map < String , Object > claims = new HashMap < > ( ) ;
accessTokenContext . getClaims ( ) . claims ( claims : : putAll ) ;
assertThat ( claims ) . flatExtracting ( OAuth2ParameterNames . SCOPE )
. containsExactlyInAnyOrder ( OidcScopes . OPENID , "scope1" ) ;
// ID Token context
JwtEncodingContext idTokenContext = jwtEncodingContextCaptor . getAllValues ( ) . get ( 1 ) ;
assertThat ( idTokenContext . getRegisteredClient ( ) ) . isEqualTo ( registeredClient ) ;
assertThat ( idTokenContext . < Authentication > getPrincipal ( ) ) . isEqualTo ( authorization . getAttribute ( Principal . class . getName ( ) ) ) ;
assertThat ( idTokenContext . getAuthorization ( ) ) . isEqualTo ( authorization ) ;
assertThat ( idTokenContext . getAuthorizedScopes ( ) )
. isEqualTo ( authorization . getAttribute ( OAuth2Authorization . AUTHORIZED_SCOPE_ATTRIBUTE_NAME ) ) ;
assertThat ( idTokenContext . getTokenType ( ) . getValue ( ) ) . isEqualTo ( OidcParameterNames . ID_TOKEN ) ;
assertThat ( idTokenContext . getAuthorizationGrantType ( ) ) . isEqualTo ( AuthorizationGrantType . REFRESH_TOKEN ) ;
assertThat ( idTokenContext . < OAuth2AuthorizationGrantAuthenticationToken > getAuthorizationGrant ( ) ) . isEqualTo ( authentication ) ;
assertThat ( idTokenContext . getHeaders ( ) ) . isNotNull ( ) ;
assertThat ( idTokenContext . getClaims ( ) ) . isNotNull ( ) ;
verify ( this . jwtEncoder , times ( 2 ) ) . encode ( any ( ) , any ( ) ) ; // Access token and ID Token
ArgumentCaptor < OAuth2Authorization > authorizationCaptor = ArgumentCaptor . forClass ( OAuth2Authorization . class ) ;
verify ( this . authorizationService ) . save ( authorizationCaptor . capture ( ) ) ;
OAuth2Authorization updatedAuthorization = authorizationCaptor . getValue ( ) ;
assertThat ( accessTokenAuthentication . getRegisteredClient ( ) . getId ( ) ) . isEqualTo ( updatedAuthorization . getRegisteredClientId ( ) ) ;
assertThat ( accessTokenAuthentication . getPrincipal ( ) ) . isEqualTo ( clientPrincipal ) ;
assertThat ( accessTokenAuthentication . getAccessToken ( ) ) . isEqualTo ( updatedAuthorization . getAccessToken ( ) . getToken ( ) ) ;
assertThat ( updatedAuthorization . getAccessToken ( ) ) . isNotEqualTo ( authorization . getAccessToken ( ) ) ;
OAuth2Authorization . Token < OidcIdToken > idToken = updatedAuthorization . getToken ( OidcIdToken . class ) ;
assertThat ( idToken ) . isNotNull ( ) ;
assertThat ( accessTokenAuthentication . getAdditionalParameters ( ) )
. containsExactly ( entry ( OidcParameterNames . ID_TOKEN , idToken . getToken ( ) . getTokenValue ( ) ) ) ;
assertThat ( accessTokenAuthentication . getRefreshToken ( ) ) . isEqualTo ( updatedAuthorization . getRefreshToken ( ) . getToken ( ) ) ;
// By default, refresh token is reused
assertThat ( updatedAuthorization . getRefreshToken ( ) ) . isEqualTo ( authorization . getRefreshToken ( ) ) ;
}
@Test
public void authenticateWhenReuseRefreshTokensFalseThenReturnNewRefreshToken ( ) {
RegisteredClient registeredClient = TestRegisteredClients . registeredClient ( )