@ -46,12 +46,16 @@ import org.springframework.security.core.authority.AuthorityUtils;
@@ -46,12 +46,16 @@ import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.core.context.SecurityContextHolder ;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient ;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken ;
import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient ;
import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest ;
import org.springframework.security.oauth2.client.registration.ClientRegistration ;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository ;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations ;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository ;
import org.springframework.security.oauth2.core.OAuth2AccessToken ;
import org.springframework.security.oauth2.core.OAuth2RefreshToken ;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse ;
import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses ;
import org.springframework.security.oauth2.core.user.OAuth2User ;
import org.springframework.web.context.request.RequestContextHolder ;
import org.springframework.web.context.request.ServletRequestAttributes ;
@ -89,6 +93,10 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
@@ -89,6 +93,10 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
@Mock
private OAuth2AuthorizedClientRepository authorizedClientRepository ;
@Mock
private ClientRegistrationRepository clientRegistrationRepository ;
@Mock
private OAuth2AccessTokenResponseClient < OAuth2ClientCredentialsGrantRequest > clientCredentialsTokenResponseClient ;
@Mock
private WebClient . RequestHeadersSpec < ? > spec ;
@Captor
private ArgumentCaptor < Consumer < Map < String , Object > > > attrs ;
@ -148,7 +156,8 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
@@ -148,7 +156,8 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
@Test
public void defaultRequestAuthenticationWhenAuthenticationSetThenAuthenticationSet ( ) {
this . function = new ServletOAuth2AuthorizedClientExchangeFilterFunction ( this . authorizedClientRepository ) ;
this . function = new ServletOAuth2AuthorizedClientExchangeFilterFunction ( this . clientRegistrationRepository ,
this . authorizedClientRepository ) ;
SecurityContextHolder . getContext ( ) . setAuthentication ( this . authentication ) ;
Map < String , Object > attrs = getDefaultRequestAttributes ( ) ;
assertThat ( getAuthentication ( attrs ) ) . isEqualTo ( this . authentication ) ;
@ -157,7 +166,8 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
@@ -157,7 +166,8 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
@Test
public void defaultRequestOAuth2AuthorizedClientWhenOAuth2AuthorizationClientAndClientIdThenNotOverride ( ) {
this . function = new ServletOAuth2AuthorizedClientExchangeFilterFunction ( this . authorizedClientRepository ) ;
this . function = new ServletOAuth2AuthorizedClientExchangeFilterFunction ( this . clientRegistrationRepository ,
this . authorizedClientRepository ) ;
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient ( this . registration ,
"principalName" , this . accessToken ) ;
oauth2AuthorizedClient ( authorizedClient ) . accept ( this . result ) ;
@ -168,7 +178,8 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
@@ -168,7 +178,8 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
@Test
public void defaultRequestOAuth2AuthorizedClientWhenAuthenticationNullAndClientRegistrationIdNullThenOAuth2AuthorizedClientNull ( ) {
this . function = new ServletOAuth2AuthorizedClientExchangeFilterFunction ( this . authorizedClientRepository ) ;
this . function = new ServletOAuth2AuthorizedClientExchangeFilterFunction ( this . clientRegistrationRepository ,
this . authorizedClientRepository ) ;
Map < String , Object > attrs = getDefaultRequestAttributes ( ) ;
assertThat ( getOAuth2AuthorizedClient ( attrs ) ) . isNull ( ) ;
verifyZeroInteractions ( this . authorizedClientRepository ) ;
@ -176,7 +187,8 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
@@ -176,7 +187,8 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
@Test
public void defaultRequestOAuth2AuthorizedClientWhenAuthenticationWrongTypeAndClientRegistrationIdNullThenOAuth2AuthorizedClientNull ( ) {
this . function = new ServletOAuth2AuthorizedClientExchangeFilterFunction ( this . authorizedClientRepository ) ;
this . function = new ServletOAuth2AuthorizedClientExchangeFilterFunction ( this . clientRegistrationRepository ,
this . authorizedClientRepository ) ;
Map < String , Object > attrs = getDefaultRequestAttributes ( ) ;
assertThat ( getOAuth2AuthorizedClient ( attrs ) ) . isNull ( ) ;
verifyZeroInteractions ( this . authorizedClientRepository ) ;
@ -196,7 +208,8 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
@@ -196,7 +208,8 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
@Test
public void defaultRequestOAuth2AuthorizedClientWhenDefaultTrueAndAuthenticationAndClientRegistrationIdNullThenOAuth2AuthorizedClient ( ) {
this . function = new ServletOAuth2AuthorizedClientExchangeFilterFunction ( this . authorizedClientRepository ) ;
this . function = new ServletOAuth2AuthorizedClientExchangeFilterFunction ( this . clientRegistrationRepository ,
this . authorizedClientRepository ) ;
this . function . setDefaultOAuth2AuthorizedClient ( true ) ;
OAuth2User user = mock ( OAuth2User . class ) ;
List < GrantedAuthority > authorities = AuthorityUtils . createAuthorityList ( "ROLE_USER" ) ;
@ -214,7 +227,8 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
@@ -214,7 +227,8 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
@Test
public void defaultRequestOAuth2AuthorizedClientWhenDefaultFalseAndAuthenticationAndClientRegistrationIdNullThenOAuth2AuthorizedClient ( ) {
this . function = new ServletOAuth2AuthorizedClientExchangeFilterFunction ( this . authorizedClientRepository ) ;
this . function = new ServletOAuth2AuthorizedClientExchangeFilterFunction ( this . clientRegistrationRepository ,
this . authorizedClientRepository ) ;
OAuth2User user = mock ( OAuth2User . class ) ;
List < GrantedAuthority > authorities = AuthorityUtils . createAuthorityList ( "ROLE_USER" ) ;
OAuth2AuthenticationToken token = new OAuth2AuthenticationToken ( user , authorities , "id" ) ;
@ -227,7 +241,8 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
@@ -227,7 +241,8 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
@Test
public void defaultRequestOAuth2AuthorizedClientWhenAuthenticationAndClientRegistrationIdThenIdIsExplicit ( ) {
this . function = new ServletOAuth2AuthorizedClientExchangeFilterFunction ( this . authorizedClientRepository ) ;
this . function = new ServletOAuth2AuthorizedClientExchangeFilterFunction ( this . clientRegistrationRepository ,
this . authorizedClientRepository ) ;
OAuth2User user = mock ( OAuth2User . class ) ;
List < GrantedAuthority > authorities = AuthorityUtils . createAuthorityList ( "ROLE_USER" ) ;
OAuth2AuthenticationToken token = new OAuth2AuthenticationToken ( user , authorities , "id" ) ;
@ -245,9 +260,8 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
@@ -245,9 +260,8 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
@Test
public void defaultRequestOAuth2AuthorizedClientWhenAuthenticationNullAndClientRegistrationIdThenOAuth2AuthorizedClient ( ) {
this . function = new ServletOAuth2AuthorizedClientExchangeFilterFunction ( this . authorizedClientRepository ) ;
OAuth2User user = mock ( OAuth2User . class ) ;
List < GrantedAuthority > authorities = AuthorityUtils . createAuthorityList ( "ROLE_USER" ) ;
this . function = new ServletOAuth2AuthorizedClientExchangeFilterFunction ( this . clientRegistrationRepository ,
this . authorizedClientRepository ) ;
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient ( this . registration ,
"principalName" , this . accessToken ) ;
when ( this . authorizedClientRepository . loadAuthorizedClient ( any ( ) , any ( ) , any ( ) ) ) . thenReturn ( authorizedClient ) ;
@ -259,6 +273,41 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
@@ -259,6 +273,41 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
verify ( this . authorizedClientRepository ) . loadAuthorizedClient ( eq ( "id" ) , any ( ) , any ( ) ) ;
}
@Test
public void defaultRequestWhenClientCredentialsThenAuthorizedClient ( ) {
this . registration = TestClientRegistrations . clientCredentials ( ) . build ( ) ;
this . function = new ServletOAuth2AuthorizedClientExchangeFilterFunction ( this . clientRegistrationRepository ,
this . authorizedClientRepository ) ;
this . function . setClientCredentialsTokenResponseClient ( this . clientCredentialsTokenResponseClient ) ;
when ( this . clientRegistrationRepository . findByRegistrationId ( any ( ) ) ) . thenReturn ( this . registration ) ;
OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses
. accessTokenResponse ( ) . build ( ) ;
when ( this . clientCredentialsTokenResponseClient . getTokenResponse ( any ( ) ) ) . thenReturn (
accessTokenResponse ) ;
clientRegistrationId ( this . registration . getRegistrationId ( ) ) . accept ( this . result ) ;
Map < String , Object > attrs = getDefaultRequestAttributes ( ) ;
OAuth2AuthorizedClient authorizedClient = getOAuth2AuthorizedClient ( attrs ) ;
assertThat ( authorizedClient . getAccessToken ( ) ) . isEqualTo ( accessTokenResponse . getAccessToken ( ) ) ;
assertThat ( authorizedClient . getClientRegistration ( ) ) . isEqualTo ( this . registration ) ;
assertThat ( authorizedClient . getPrincipalName ( ) ) . isEqualTo ( "anonymousUser" ) ;
assertThat ( authorizedClient . getRefreshToken ( ) ) . isEqualTo ( accessTokenResponse . getRefreshToken ( ) ) ;
}
@Test
public void defaultRequestWhenClientIdNotFoundThenIllegalArgumentException ( ) {
this . registration = TestClientRegistrations . clientCredentials ( ) . build ( ) ;
this . function = new ServletOAuth2AuthorizedClientExchangeFilterFunction ( this . clientRegistrationRepository ,
this . authorizedClientRepository ) ;
clientRegistrationId ( this . registration . getRegistrationId ( ) ) . accept ( this . result ) ;
assertThatCode ( ( ) - > getDefaultRequestAttributes ( ) )
. isInstanceOf ( IllegalArgumentException . class ) ;
}
private Map < String , Object > getDefaultRequestAttributes ( ) {
this . function . defaultRequest ( ) . accept ( this . spec ) ;
verify ( this . spec ) . attributes ( this . attrs . capture ( ) ) ;
@ -322,7 +371,8 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
@@ -322,7 +371,8 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
this . accessToken . getTokenValue ( ) ,
issuedAt ,
accessTokenExpiresAt ) ;
this . function = new ServletOAuth2AuthorizedClientExchangeFilterFunction ( this . authorizedClientRepository ) ;
this . function = new ServletOAuth2AuthorizedClientExchangeFilterFunction ( this . clientRegistrationRepository ,
this . authorizedClientRepository ) ;
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken ( "refresh-token" , issuedAt , refreshTokenExpiresAt ) ;
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient ( this . registration ,
@ -368,7 +418,8 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
@@ -368,7 +418,8 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
this . accessToken . getTokenValue ( ) ,
issuedAt ,
accessTokenExpiresAt ) ;
this . function = new ServletOAuth2AuthorizedClientExchangeFilterFunction ( this . authorizedClientRepository ) ;
this . function = new ServletOAuth2AuthorizedClientExchangeFilterFunction ( this . clientRegistrationRepository ,
this . authorizedClientRepository ) ;
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken ( "refresh-token" , issuedAt , refreshTokenExpiresAt ) ;
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient ( this . registration ,
@ -400,7 +451,8 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
@@ -400,7 +451,8 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
@Test
public void filterWhenRefreshTokenNullThenShouldRefreshFalse ( ) {
this . function = new ServletOAuth2AuthorizedClientExchangeFilterFunction ( this . authorizedClientRepository ) ;
this . function = new ServletOAuth2AuthorizedClientExchangeFilterFunction ( this . clientRegistrationRepository ,
this . authorizedClientRepository ) ;
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient ( this . registration ,
"principalName" , this . accessToken ) ;
@ -422,7 +474,8 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
@@ -422,7 +474,8 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
@Test
public void filterWhenNotExpiredThenShouldRefreshFalse ( ) {
this . function = new ServletOAuth2AuthorizedClientExchangeFilterFunction ( this . authorizedClientRepository ) ;
this . function = new ServletOAuth2AuthorizedClientExchangeFilterFunction ( this . clientRegistrationRepository ,
this . authorizedClientRepository ) ;
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken ( "refresh-token" , this . accessToken . getIssuedAt ( ) , this . accessToken . getExpiresAt ( ) ) ;
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient ( this . registration ,