@ -18,15 +18,18 @@ package org.springframework.security.oauth2.client;
@@ -18,15 +18,18 @@ package org.springframework.security.oauth2.client;
import java.util.Collections ;
import java.util.Map ;
import java.util.function.Consumer ;
import org.junit.jupiter.api.Test ;
import org.springframework.security.authentication.TestingAuthenticationToken ;
import org.springframework.security.core.Authentication ;
import org.springframework.security.oauth2.client.registration.ClientRegistration ;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository ;
import org.springframework.security.oauth2.client.registration.InMemoryClientRegistrationRepository ;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations ;
import org.springframework.security.oauth2.core.OAuth2AccessToken ;
import org.springframework.security.oauth2.core.OAuth2RefreshToken ;
import static org.assertj.core.api.Assertions.assertThat ;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException ;
@ -126,7 +129,7 @@ public class InMemoryOAuth2AuthorizedClientServiceTests {
@@ -126,7 +129,7 @@ public class InMemoryOAuth2AuthorizedClientServiceTests {
this . authorizedClientService . saveAuthorizedClient ( authorizedClient , authentication ) ;
OAuth2AuthorizedClient loadedAuthorizedClient = this . authorizedClientService
. loadAuthorizedClient ( this . registration1 . getRegistrationId ( ) , this . principalName1 ) ;
assertAuthorizedClientEquals ( authorizedClient , loadedAuthorizedClient ) ;
assertThat ( loadedAuthorizedClient ) . satisfies ( isEqualTo ( authorizedClient ) ) ;
}
@Test
@ -134,27 +137,27 @@ public class InMemoryOAuth2AuthorizedClientServiceTests {
@@ -134,27 +137,27 @@ public class InMemoryOAuth2AuthorizedClientServiceTests {
ClientRegistration updatedRegistration = ClientRegistration . withClientRegistration ( this . registration1 )
. clientSecret ( "updated secret" )
. build ( ) ;
ClientRegistrationRepository repository = mock ( ClientRegistrationRepository . class ) ;
given ( repository . findByRegistrationId ( this . registration1 . getRegistrationId ( ) ) ) . willReturn ( this . registration1 ,
updatedRegistration ) ;
Authentication authentication = mock ( Authentication . class ) ;
given ( authentication . getName ( ) ) . willReturn ( this . principalName1 ) ;
ClientRegistrationRepository clientRegistrationRepository = mock ( ClientRegistrationRepository . class ) ;
given ( clientRegistrationRepository . findByRegistrationId ( this . registration1 . getRegistrationId ( ) ) )
. willReturn ( this . registration1 , updatedRegistration ) ;
InMemoryOAuth2AuthorizedClientService service = new InMemoryOAuth2AuthorizedClientService ( repository ) ;
InMemoryOAuth2AuthorizedClientService authorizedClientService = new InMemoryOAuth2AuthorizedClientService (
clientRegistrationRepository ) ;
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient ( this . registration1 , this . principalName1 ,
mock ( OAuth2AccessToken . class ) ) ;
service . saveAuthorizedClient ( authorizedClient , authentication ) ;
OAuth2AuthorizedClient cachedAuthorizedClient = new OAuth2AuthorizedClient ( this . registration1 ,
this . principalName1 , mock ( OAuth2AccessToken . class ) , mock ( OAuth2RefreshToken . class ) ) ;
authorizedClientService . saveAuthorizedClient ( cachedAuthorizedClient ,
new TestingAuthenticationToken ( this . principalName1 , null ) ) ;
OAuth2AuthorizedClient authorizedClientWithUpdatedRegistration = new OAuth2AuthorizedClient ( updatedRegistration ,
this . principalName1 , mock ( OAuth2AccessToken . class ) ) ;
OAuth2AuthorizedClient firstLoadedClient = service . loadAuthorizedClient ( this . registration1 . getRegistrationId ( ) ,
this . principalName1 ) ;
OAuth2AuthorizedClient secondLoadedClient = service . loadAuthorizedClient ( this . registration1 . getRegistrationId ( ) ,
this . principalName1 ) ;
assertAuthorizedClientEquals ( authorizedClient , firstLoadedClient ) ;
assertAuthorizedClientEquals ( authorizedClientWithUpdatedRegistration , secondLoadedClient ) ;
this . principalName1 , mock ( OAuth2AccessToken . class ) , mock ( OAuth2RefreshToken . class ) ) ;
OAuth2AuthorizedClient firstLoadedClient = authorizedClientService
. loadAuthorizedClient ( this . registration1 . getRegistrationId ( ) , this . principalName1 ) ;
OAuth2AuthorizedClient secondLoadedClient = authorizedClientService
. loadAuthorizedClient ( this . registration1 . getRegistrationId ( ) , this . principalName1 ) ;
assertThat ( firstLoadedClient ) . satisfies ( isEqualTo ( cachedAuthorizedClient ) ) ;
assertThat ( secondLoadedClient ) . satisfies ( isEqualTo ( authorizedClientWithUpdatedRegistration ) ) ;
}
@Test
@ -178,7 +181,7 @@ public class InMemoryOAuth2AuthorizedClientServiceTests {
@@ -178,7 +181,7 @@ public class InMemoryOAuth2AuthorizedClientServiceTests {
this . authorizedClientService . saveAuthorizedClient ( authorizedClient , authentication ) ;
OAuth2AuthorizedClient loadedAuthorizedClient = this . authorizedClientService
. loadAuthorizedClient ( this . registration3 . getRegistrationId ( ) , this . principalName2 ) ;
assertAuthorizedClientEquals ( authorizedClient , loadedAuthorizedClient ) ;
assertThat ( loadedAuthorizedClient ) . satisfies ( isEqualTo ( authorizedClient ) ) ;
}
@Test
@ -210,29 +213,38 @@ public class InMemoryOAuth2AuthorizedClientServiceTests {
@@ -210,29 +213,38 @@ public class InMemoryOAuth2AuthorizedClientServiceTests {
assertThat ( loadedAuthorizedClient ) . isNull ( ) ;
}
private static void assertAuthorizedClientEquals ( OAuth2AuthorizedClient expected , OAuth2AuthorizedClient actual ) {
assertThat ( actual ) . isNotNull ( ) ;
assertThat ( actual . getClientRegistration ( ) . getRegistrationId ( ) )
. isEqualTo ( expected . getClientRegistration ( ) . getRegistrationId ( ) ) ;
assertThat ( actual . getClientRegistration ( ) . getClientName ( ) )
. isEqualTo ( expected . getClientRegistration ( ) . getClientName ( ) ) ;
assertThat ( actual . getClientRegistration ( ) . getRedirectUri ( ) )
. isEqualTo ( expected . getClientRegistration ( ) . getRedirectUri ( ) ) ;
assertThat ( actual . getClientRegistration ( ) . getAuthorizationGrantType ( ) )
. isEqualTo ( expected . getClientRegistration ( ) . getAuthorizationGrantType ( ) ) ;
assertThat ( actual . getClientRegistration ( ) . getClientAuthenticationMethod ( ) )
. isEqualTo ( expected . getClientRegistration ( ) . getClientAuthenticationMethod ( ) ) ;
assertThat ( actual . getClientRegistration ( ) . getClientId ( ) )
. isEqualTo ( expected . getClientRegistration ( ) . getClientId ( ) ) ;
assertThat ( actual . getClientRegistration ( ) . getClientSecret ( ) )
. isEqualTo ( expected . getClientRegistration ( ) . getClientSecret ( ) ) ;
assertThat ( actual . getPrincipalName ( ) ) . isEqualTo ( expected . getPrincipalName ( ) ) ;
assertThat ( actual . getAccessToken ( ) . getTokenType ( ) ) . isEqualTo ( expected . getAccessToken ( ) . getTokenType ( ) ) ;
assertThat ( actual . getAccessToken ( ) . getTokenValue ( ) ) . isEqualTo ( expected . getAccessToken ( ) . getTokenValue ( ) ) ;
assertThat ( actual . getAccessToken ( ) . getIssuedAt ( ) ) . isEqualTo ( expected . getAccessToken ( ) . getIssuedAt ( ) ) ;
assertThat ( actual . getAccessToken ( ) . getExpiresAt ( ) ) . isEqualTo ( expected . getAccessToken ( ) . getExpiresAt ( ) ) ;
assertThat ( actual . getAccessToken ( ) . getScopes ( ) ) . isEqualTo ( expected . getAccessToken ( ) . getScopes ( ) ) ;
assertThat ( actual . getRefreshToken ( ) ) . isEqualTo ( expected . getRefreshToken ( ) ) ;
private static Consumer < OAuth2AuthorizedClient > isEqualTo ( OAuth2AuthorizedClient expected ) {
return ( actual ) - > {
assertThat ( actual ) . isNotNull ( ) ;
assertThat ( actual . getClientRegistration ( ) . getRegistrationId ( ) )
. isEqualTo ( expected . getClientRegistration ( ) . getRegistrationId ( ) ) ;
assertThat ( actual . getClientRegistration ( ) . getClientName ( ) )
. isEqualTo ( expected . getClientRegistration ( ) . getClientName ( ) ) ;
assertThat ( actual . getClientRegistration ( ) . getRedirectUri ( ) )
. isEqualTo ( expected . getClientRegistration ( ) . getRedirectUri ( ) ) ;
assertThat ( actual . getClientRegistration ( ) . getAuthorizationGrantType ( ) )
. isEqualTo ( expected . getClientRegistration ( ) . getAuthorizationGrantType ( ) ) ;
assertThat ( actual . getClientRegistration ( ) . getClientAuthenticationMethod ( ) )
. isEqualTo ( expected . getClientRegistration ( ) . getClientAuthenticationMethod ( ) ) ;
assertThat ( actual . getClientRegistration ( ) . getClientId ( ) )
. isEqualTo ( expected . getClientRegistration ( ) . getClientId ( ) ) ;
assertThat ( actual . getClientRegistration ( ) . getClientSecret ( ) )
. isEqualTo ( expected . getClientRegistration ( ) . getClientSecret ( ) ) ;
assertThat ( actual . getPrincipalName ( ) ) . isEqualTo ( expected . getPrincipalName ( ) ) ;
assertThat ( actual . getAccessToken ( ) . getTokenType ( ) ) . isEqualTo ( expected . getAccessToken ( ) . getTokenType ( ) ) ;
assertThat ( actual . getAccessToken ( ) . getTokenValue ( ) ) . isEqualTo ( expected . getAccessToken ( ) . getTokenValue ( ) ) ;
assertThat ( actual . getAccessToken ( ) . getIssuedAt ( ) ) . isEqualTo ( expected . getAccessToken ( ) . getIssuedAt ( ) ) ;
assertThat ( actual . getAccessToken ( ) . getExpiresAt ( ) ) . isEqualTo ( expected . getAccessToken ( ) . getExpiresAt ( ) ) ;
assertThat ( actual . getAccessToken ( ) . getScopes ( ) ) . isEqualTo ( expected . getAccessToken ( ) . getScopes ( ) ) ;
if ( expected . getRefreshToken ( ) ! = null ) {
assertThat ( actual . getRefreshToken ( ) ) . isNotNull ( ) ;
assertThat ( actual . getRefreshToken ( ) . getTokenValue ( ) )
. isEqualTo ( expected . getRefreshToken ( ) . getTokenValue ( ) ) ;
assertThat ( actual . getRefreshToken ( ) . getIssuedAt ( ) ) . isEqualTo ( expected . getRefreshToken ( ) . getIssuedAt ( ) ) ;
assertThat ( actual . getRefreshToken ( ) . getExpiresAt ( ) )
. isEqualTo ( expected . getRefreshToken ( ) . getExpiresAt ( ) ) ;
}
} ;
}
}