@ -1,5 +1,5 @@
@@ -1,5 +1,5 @@
/ *
* Copyright 2002 - 2018 the original author or authors .
* Copyright 2002 - 2019 the original author or authors .
*
* Licensed under the Apache License , Version 2 . 0 ( the "License" ) ;
* you may not use this file except in compliance with the License .
@ -19,6 +19,8 @@ package org.springframework.security.oauth2.client.web.reactive.function.client;
@@ -19,6 +19,8 @@ package org.springframework.security.oauth2.client.web.reactive.function.client;
import org.junit.Before ;
import org.junit.Test ;
import org.junit.runner.RunWith ;
import org.mockito.ArgumentCaptor ;
import org.mockito.Captor ;
import org.mockito.Mock ;
import org.mockito.junit.MockitoJUnitRunner ;
import org.springframework.core.codec.ByteBufferEncoder ;
@ -89,6 +91,9 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
@@ -89,6 +91,9 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
@Mock
private ServerWebExchange serverWebExchange ;
@Captor
private ArgumentCaptor < OAuth2AuthorizedClient > authorizedClientCaptor ;
private ServerOAuth2AuthorizedClientExchangeFilterFunction function ;
private MockExchangeFunction exchange = new MockExchangeFunction ( ) ;
@ -173,7 +178,62 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
@@ -173,7 +178,62 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
. subscriberContext ( ReactiveSecurityContextHolder . withAuthentication ( authentication ) )
. block ( ) ;
verify ( this . authorizedClientRepository ) . saveAuthorizedClient ( any ( ) , eq ( authentication ) , any ( ) ) ;
verify ( this . authorizedClientRepository ) . saveAuthorizedClient ( this . authorizedClientCaptor . capture ( ) , eq ( authentication ) , any ( ) ) ;
OAuth2AuthorizedClient newAuthorizedClient = authorizedClientCaptor . getValue ( ) ;
assertThat ( newAuthorizedClient . getAccessToken ( ) ) . isEqualTo ( response . getAccessToken ( ) ) ;
assertThat ( newAuthorizedClient . getRefreshToken ( ) ) . isEqualTo ( response . getRefreshToken ( ) ) ;
List < ClientRequest > requests = this . exchange . getRequests ( ) ;
assertThat ( requests ) . hasSize ( 2 ) ;
ClientRequest request0 = requests . get ( 0 ) ;
assertThat ( request0 . headers ( ) . getFirst ( HttpHeaders . AUTHORIZATION ) ) . isEqualTo ( "Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=" ) ;
assertThat ( request0 . url ( ) . toASCIIString ( ) ) . isEqualTo ( "https://example.com/login/oauth/access_token" ) ;
assertThat ( request0 . method ( ) ) . isEqualTo ( HttpMethod . POST ) ;
assertThat ( getBody ( request0 ) ) . isEqualTo ( "grant_type=refresh_token&refresh_token=refresh-token" ) ;
ClientRequest request1 = requests . get ( 1 ) ;
assertThat ( request1 . headers ( ) . getFirst ( HttpHeaders . AUTHORIZATION ) ) . isEqualTo ( "Bearer token-1" ) ;
assertThat ( request1 . url ( ) . toASCIIString ( ) ) . isEqualTo ( "https://example.com" ) ;
assertThat ( request1 . method ( ) ) . isEqualTo ( HttpMethod . GET ) ;
assertThat ( getBody ( request1 ) ) . isEmpty ( ) ;
}
@Test
public void filterWhenRefreshRequiredThenRefreshAndResponseDoesNotContainRefreshToken ( ) {
when ( this . authorizedClientRepository . saveAuthorizedClient ( any ( ) , any ( ) , any ( ) ) ) . thenReturn ( Mono . empty ( ) ) ;
OAuth2AccessTokenResponse response = OAuth2AccessTokenResponse . withToken ( "token-1" )
. tokenType ( OAuth2AccessToken . TokenType . BEARER )
. expiresIn ( 3600 )
// .refreshToken(xxx) // No refreshToken in response
. build ( ) ;
when ( this . exchange . getResponse ( ) . body ( any ( ) ) ) . thenReturn ( Mono . just ( response ) ) ;
Instant issuedAt = Instant . now ( ) . minus ( Duration . ofDays ( 1 ) ) ;
Instant accessTokenExpiresAt = issuedAt . plus ( Duration . ofHours ( 1 ) ) ;
this . accessToken = new OAuth2AccessToken ( this . accessToken . getTokenType ( ) ,
this . accessToken . getTokenValue ( ) ,
issuedAt ,
accessTokenExpiresAt ) ;
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken ( "refresh-token" , issuedAt ) ;
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient ( this . registration ,
"principalName" , this . accessToken , refreshToken ) ;
ClientRequest request = ClientRequest . create ( GET , URI . create ( "https://example.com" ) )
. attributes ( oauth2AuthorizedClient ( authorizedClient ) )
. build ( ) ;
TestingAuthenticationToken authentication = new TestingAuthenticationToken ( "test" , "this" ) ;
this . function . filter ( request , this . exchange )
. subscriberContext ( ReactiveSecurityContextHolder . withAuthentication ( authentication ) )
. block ( ) ;
verify ( this . authorizedClientRepository ) . saveAuthorizedClient ( this . authorizedClientCaptor . capture ( ) , eq ( authentication ) , any ( ) ) ;
OAuth2AuthorizedClient newAuthorizedClient = authorizedClientCaptor . getValue ( ) ;
assertThat ( newAuthorizedClient . getAccessToken ( ) ) . isEqualTo ( response . getAccessToken ( ) ) ;
assertThat ( newAuthorizedClient . getRefreshToken ( ) ) . isEqualTo ( authorizedClient . getRefreshToken ( ) ) ;
List < ClientRequest > requests = this . exchange . getRequests ( ) ;
assertThat ( requests ) . hasSize ( 2 ) ;