@ -17,19 +17,51 @@
@@ -17,19 +17,51 @@
package org.springframework.security.oauth2.client.web.reactive.function.client ;
import org.junit.Test ;
import org.junit.runner.RunWith ;
import org.mockito.Mock ;
import org.mockito.junit.MockitoJUnitRunner ;
import org.springframework.core.codec.ByteBufferEncoder ;
import org.springframework.core.codec.CharSequenceEncoder ;
import org.springframework.http.HttpHeaders ;
import org.springframework.http.HttpMethod ;
import org.springframework.http.codec.EncoderHttpMessageWriter ;
import org.springframework.http.codec.FormHttpMessageWriter ;
import org.springframework.http.codec.HttpMessageWriter ;
import org.springframework.http.codec.ResourceHttpMessageWriter ;
import org.springframework.http.codec.ServerSentEventHttpMessageWriter ;
import org.springframework.http.codec.json.Jackson2JsonEncoder ;
import org.springframework.http.codec.multipart.MultipartHttpMessageWriter ;
import org.springframework.http.codec.xml.Jaxb2XmlEncoder ;
import org.springframework.http.server.reactive.ServerHttpRequest ;
import org.springframework.mock.http.client.reactive.MockClientHttpRequest ;
import org.springframework.security.authentication.TestingAuthenticationToken ;
import org.springframework.security.core.context.ReactiveSecurityContextHolder ;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient ;
import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientService ;
import org.springframework.security.oauth2.client.registration.ClientRegistration ;
import org.springframework.security.oauth2.core.AuthorizationGrantType ;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod ;
import org.springframework.security.oauth2.core.OAuth2AccessToken ;
import org.springframework.security.oauth2.core.OAuth2RefreshToken ;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse ;
import org.springframework.web.reactive.function.BodyInserter ;
import org.springframework.web.reactive.function.client.ClientRequest ;
import reactor.core.publisher.Mono ;
import java.net.URI ;
import java.time.Duration ;
import java.time.Instant ;
import java.util.ArrayList ;
import java.util.HashMap ;
import java.util.List ;
import java.util.Map ;
import java.util.Optional ;
import static org.assertj.core.api.Assertions.* ;
import static org.assertj.core.api.Assertions.assertThat ;
import static org.mockito.ArgumentMatchers.any ;
import static org.mockito.ArgumentMatchers.eq ;
import static org.mockito.Mockito.verify ;
import static org.mockito.Mockito.when ;
import static org.springframework.http.HttpMethod.GET ;
import static org.springframework.security.oauth2.client.web.reactive.function.client.OAuth2AuthorizedClientExchangeFilterFunction.oauth2AuthorizedClient ;
@ -37,7 +69,11 @@ import static org.springframework.security.oauth2.client.web.reactive.function.c
@@ -37,7 +69,11 @@ import static org.springframework.security.oauth2.client.web.reactive.function.c
* @author Rob Winch
* @since 5 . 1
* /
@RunWith ( MockitoJUnitRunner . class )
public class OAuth2AuthorizedClientExchangeFilterFunctionTests {
@Mock
private ReactiveOAuth2AuthorizedClientService authorizedClientService ;
private OAuth2AuthorizedClientExchangeFilterFunction function = new OAuth2AuthorizedClientExchangeFilterFunction ( ) ;
private MockExchangeFunction exchange = new MockExchangeFunction ( ) ;
@ -57,7 +93,7 @@ public class OAuth2AuthorizedClientExchangeFilterFunctionTests {
@@ -57,7 +93,7 @@ public class OAuth2AuthorizedClientExchangeFilterFunctionTests {
. build ( ) ;
private OAuth2AccessToken accessToken = new OAuth2AccessToken ( OAuth2AccessToken . TokenType . BEARER ,
"token" ,
"token-0 " ,
Instant . now ( ) ,
Instant . now ( ) . plus ( Duration . ofDays ( 1 ) ) ) ;
@ -98,4 +134,180 @@ public class OAuth2AuthorizedClientExchangeFilterFunctionTests {
@@ -98,4 +134,180 @@ public class OAuth2AuthorizedClientExchangeFilterFunctionTests {
HttpHeaders headers = this . exchange . getRequest ( ) . headers ( ) ;
assertThat ( headers . get ( HttpHeaders . AUTHORIZATION ) ) . containsOnly ( "Bearer " + this . accessToken . getTokenValue ( ) ) ;
}
@Test
public void filterWhenRefreshRequiredThenRefresh ( ) {
when ( this . authorizedClientService . saveAuthorizedClient ( any ( ) , any ( ) ) ) . thenReturn ( Mono . empty ( ) ) ;
OAuth2AccessTokenResponse response = OAuth2AccessTokenResponse . withToken ( "token-1" )
. tokenType ( OAuth2AccessToken . TokenType . BEARER )
. expiresIn ( 3600 )
. refreshToken ( "refresh-1" )
. build ( ) ;
when ( this . exchange . getResponse ( ) . body ( any ( ) ) ) . thenReturn ( Mono . just ( response ) ) ;
Instant issuedAt = Instant . now ( ) . minus ( Duration . ofDays ( 1 ) ) ;
Instant accessTokenExpiresAt = issuedAt . plus ( Duration . ofHours ( 1 ) ) ;
Instant refreshTokenExpiresAt = Instant . now ( ) . plus ( Duration . ofHours ( 1 ) ) ;
this . accessToken = new OAuth2AccessToken ( this . accessToken . getTokenType ( ) ,
this . accessToken . getTokenValue ( ) ,
issuedAt ,
accessTokenExpiresAt ) ;
this . function = new OAuth2AuthorizedClientExchangeFilterFunction ( this . authorizedClientService ) ;
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken ( "refresh-token" , issuedAt , refreshTokenExpiresAt ) ;
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient ( this . github ,
"principalName" , this . accessToken , refreshToken ) ;
ClientRequest request = ClientRequest . create ( GET , URI . create ( "https://example.com" ) )
. attributes ( oauth2AuthorizedClient ( authorizedClient ) )
. build ( ) ;
TestingAuthenticationToken authentication = new TestingAuthenticationToken ( "test" , "this" ) ;
this . function . filter ( request , this . exchange )
. subscriberContext ( ReactiveSecurityContextHolder . withAuthentication ( authentication ) )
. block ( ) ;
verify ( this . authorizedClientService ) . saveAuthorizedClient ( any ( ) , eq ( authentication ) ) ;
List < ClientRequest > requests = this . exchange . getRequests ( ) ;
assertThat ( requests ) . hasSize ( 2 ) ;
ClientRequest request0 = requests . get ( 0 ) ;
assertThat ( request0 . headers ( ) . getFirst ( HttpHeaders . AUTHORIZATION ) ) . isEqualTo ( "Basic Y2xpZW50SWQ6Y2xpZW50U2VjcmV0" ) ;
assertThat ( request0 . url ( ) . toASCIIString ( ) ) . isEqualTo ( "https://github.com/login/oauth/access_token" ) ;
assertThat ( request0 . method ( ) ) . isEqualTo ( HttpMethod . POST ) ;
assertThat ( getBody ( request0 ) ) . isEqualTo ( "grant_type=refresh_token&refresh_token=refresh-token" ) ;
ClientRequest request1 = requests . get ( 1 ) ;
assertThat ( request1 . headers ( ) . getFirst ( HttpHeaders . AUTHORIZATION ) ) . isEqualTo ( "Bearer token-1" ) ;
assertThat ( request1 . url ( ) . toASCIIString ( ) ) . isEqualTo ( "https://example.com" ) ;
assertThat ( request1 . method ( ) ) . isEqualTo ( HttpMethod . GET ) ;
assertThat ( getBody ( request1 ) ) . isEmpty ( ) ;
}
@Test
public void filterWhenRefreshRequiredAndEmptyReactiveSecurityContextThenSaved ( ) {
when ( this . authorizedClientService . saveAuthorizedClient ( any ( ) , any ( ) ) ) . thenReturn ( Mono . empty ( ) ) ;
OAuth2AccessTokenResponse response = OAuth2AccessTokenResponse . withToken ( "token-1" )
. tokenType ( OAuth2AccessToken . TokenType . BEARER )
. expiresIn ( 3600 )
. refreshToken ( "refresh-1" )
. build ( ) ;
when ( this . exchange . getResponse ( ) . body ( any ( ) ) ) . thenReturn ( Mono . just ( response ) ) ;
Instant issuedAt = Instant . now ( ) . minus ( Duration . ofDays ( 1 ) ) ;
Instant accessTokenExpiresAt = issuedAt . plus ( Duration . ofHours ( 1 ) ) ;
Instant refreshTokenExpiresAt = Instant . now ( ) . plus ( Duration . ofHours ( 1 ) ) ;
this . accessToken = new OAuth2AccessToken ( this . accessToken . getTokenType ( ) ,
this . accessToken . getTokenValue ( ) ,
issuedAt ,
accessTokenExpiresAt ) ;
this . function = new OAuth2AuthorizedClientExchangeFilterFunction ( this . authorizedClientService ) ;
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken ( "refresh-token" , issuedAt , refreshTokenExpiresAt ) ;
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient ( this . github ,
"principalName" , this . accessToken , refreshToken ) ;
ClientRequest request = ClientRequest . create ( GET , URI . create ( "https://example.com" ) )
. attributes ( oauth2AuthorizedClient ( authorizedClient ) )
. build ( ) ;
this . function . filter ( request , this . exchange )
. block ( ) ;
verify ( this . authorizedClientService ) . saveAuthorizedClient ( any ( ) , any ( ) ) ;
List < ClientRequest > requests = this . exchange . getRequests ( ) ;
assertThat ( requests ) . hasSize ( 2 ) ;
ClientRequest request0 = requests . get ( 0 ) ;
assertThat ( request0 . headers ( ) . getFirst ( HttpHeaders . AUTHORIZATION ) ) . isEqualTo ( "Basic Y2xpZW50SWQ6Y2xpZW50U2VjcmV0" ) ;
assertThat ( request0 . url ( ) . toASCIIString ( ) ) . isEqualTo ( "https://github.com/login/oauth/access_token" ) ;
assertThat ( request0 . method ( ) ) . isEqualTo ( HttpMethod . POST ) ;
assertThat ( getBody ( request0 ) ) . isEqualTo ( "grant_type=refresh_token&refresh_token=refresh-token" ) ;
ClientRequest request1 = requests . get ( 1 ) ;
assertThat ( request1 . headers ( ) . getFirst ( HttpHeaders . AUTHORIZATION ) ) . isEqualTo ( "Bearer token-1" ) ;
assertThat ( request1 . url ( ) . toASCIIString ( ) ) . isEqualTo ( "https://example.com" ) ;
assertThat ( request1 . method ( ) ) . isEqualTo ( HttpMethod . GET ) ;
assertThat ( getBody ( request1 ) ) . isEmpty ( ) ;
}
@Test
public void filterWhenRefreshTokenNullThenShouldRefreshFalse ( ) {
this . function = new OAuth2AuthorizedClientExchangeFilterFunction ( this . authorizedClientService ) ;
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient ( this . github ,
"principalName" , this . accessToken ) ;
ClientRequest request = ClientRequest . create ( GET , URI . create ( "https://example.com" ) )
. attributes ( oauth2AuthorizedClient ( authorizedClient ) )
. build ( ) ;
this . function . filter ( request , this . exchange ) . block ( ) ;
List < ClientRequest > requests = this . exchange . getRequests ( ) ;
assertThat ( requests ) . hasSize ( 1 ) ;
ClientRequest request0 = requests . get ( 0 ) ;
assertThat ( request0 . headers ( ) . getFirst ( HttpHeaders . AUTHORIZATION ) ) . isEqualTo ( "Bearer token-0" ) ;
assertThat ( request0 . url ( ) . toASCIIString ( ) ) . isEqualTo ( "https://example.com" ) ;
assertThat ( request0 . method ( ) ) . isEqualTo ( HttpMethod . GET ) ;
assertThat ( getBody ( request0 ) ) . isEmpty ( ) ;
}
@Test
public void filterWhenNotExpiredThenShouldRefreshFalse ( ) {
this . function = new OAuth2AuthorizedClientExchangeFilterFunction ( this . authorizedClientService ) ;
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken ( "refresh-token" , this . accessToken . getIssuedAt ( ) , this . accessToken . getExpiresAt ( ) ) ;
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient ( this . github ,
"principalName" , this . accessToken , refreshToken ) ;
ClientRequest request = ClientRequest . create ( GET , URI . create ( "https://example.com" ) )
. attributes ( oauth2AuthorizedClient ( authorizedClient ) )
. build ( ) ;
this . function . filter ( request , this . exchange ) . block ( ) ;
List < ClientRequest > requests = this . exchange . getRequests ( ) ;
assertThat ( requests ) . hasSize ( 1 ) ;
ClientRequest request0 = requests . get ( 0 ) ;
assertThat ( request0 . headers ( ) . getFirst ( HttpHeaders . AUTHORIZATION ) ) . isEqualTo ( "Bearer token-0" ) ;
assertThat ( request0 . url ( ) . toASCIIString ( ) ) . isEqualTo ( "https://example.com" ) ;
assertThat ( request0 . method ( ) ) . isEqualTo ( HttpMethod . GET ) ;
assertThat ( getBody ( request0 ) ) . isEmpty ( ) ;
}
private static String getBody ( ClientRequest request ) {
final List < HttpMessageWriter < ? > > messageWriters = new ArrayList < > ( ) ;
messageWriters . add ( new EncoderHttpMessageWriter < > ( new ByteBufferEncoder ( ) ) ) ;
messageWriters . add ( new EncoderHttpMessageWriter < > ( CharSequenceEncoder . textPlainOnly ( ) ) ) ;
messageWriters . add ( new ResourceHttpMessageWriter ( ) ) ;
messageWriters . add ( new EncoderHttpMessageWriter < > ( new Jaxb2XmlEncoder ( ) ) ) ;
Jackson2JsonEncoder jsonEncoder = new Jackson2JsonEncoder ( ) ;
messageWriters . add ( new EncoderHttpMessageWriter < > ( jsonEncoder ) ) ;
messageWriters . add ( new ServerSentEventHttpMessageWriter ( jsonEncoder ) ) ;
messageWriters . add ( new FormHttpMessageWriter ( ) ) ;
messageWriters . add ( new EncoderHttpMessageWriter < > ( CharSequenceEncoder . allMimeTypes ( ) ) ) ;
messageWriters . add ( new MultipartHttpMessageWriter ( messageWriters ) ) ;
BodyInserter . Context context = new BodyInserter . Context ( ) {
@Override
public List < HttpMessageWriter < ? > > messageWriters ( ) {
return messageWriters ;
}
@Override
public Optional < ServerHttpRequest > serverRequest ( ) {
return Optional . empty ( ) ;
}
@Override
public Map < String , Object > hints ( ) {
return new HashMap < > ( ) ;
}
} ;
MockClientHttpRequest body = new MockClientHttpRequest ( HttpMethod . GET , "/" ) ;
request . body ( ) . insert ( body , context ) . block ( ) ;
return body . getBodyAsString ( ) . block ( ) ;
}
}