@ -1,5 +1,5 @@
@@ -1,5 +1,5 @@
/ *
* Copyright 2002 - 2019 the original author or authors .
* Copyright 2002 - 2024 the original author or authors .
*
* Licensed under the Apache License , Version 2 . 0 ( the "License" ) ;
* you may not use this file except in compliance with the License .
@ -18,35 +18,40 @@ package org.springframework.security.oauth2.client;
@@ -18,35 +18,40 @@ package org.springframework.security.oauth2.client;
import java.time.Duration ;
import java.time.Instant ;
import java.util.List ;
import org.junit.jupiter.api.BeforeEach ;
import org.junit.jupiter.api.Test ;
import org.springframework.http.HttpStatus ;
import org.springframework.http.RequestEntity ;
import org.springframework.http.ResponseEntity ;
import org.springframework.core.io.ClassPathResource ;
import org.springframework.http.HttpHeaders ;
import org.springframework.http.MediaType ;
import org.springframework.http.converter.FormHttpMessageConverter ;
import org.springframework.security.authentication.TestingAuthenticationToken ;
import org.springframework.security.core.Authentication ;
import org.springframework.security.oauth2.client.endpoint.DefaultClientCredentialsTokenResponseClient ;
import org.springframework.security.oauth2.client.endpoint.DefaultPasswordTokenResponseClient ;
import org.springframework.security.oauth2.client.endpoint.DefaultRefreshTokenTokenResponseClient ;
import org.springframework.security.oauth2.client.endpoint.RestClientClientCredentialsTokenResponseClient ;
import org.springframework.security.oauth2.client.endpoint.RestClientRefreshTokenTokenResponseClient ;
import org.springframework.security.oauth2.client.registration.ClientRegistration ;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations ;
import org.springframework.security.oauth2.core.OAuth2AccessToken ;
import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens ;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse ;
import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses ;
import org.springframework.web.client.RestOperations ;
import org.springframework.security.oauth2.core.http.converter.OAuth2AccessTokenResponseHttpMessageConverter ;
import org.springframework.test.web.client.ExpectedCount ;
import org.springframework.test.web.client.MockRestServiceServer ;
import org.springframework.web.client.RestClient ;
import org.springframework.web.client.RestTemplate ;
import static org.assertj.core.api.Assertions.assertThat ;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType ;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException ;
import static org.mockito.ArgumentMatchers.any ;
import static org.mockito.ArgumentMatchers.eq ;
import static org.mockito.BDDMockito.given ;
import static org.mockito.Mockito.mock ;
import static org.mockito.Mockito.times ;
import static org.mockito.Mockito.verify ;
import static org.springframework.test.web.client.ExpectedCount.once ;
import static org.springframework.test.web.client.ExpectedCount.times ;
import static org.springframework.test.web.client.match.MockRestRequestMatchers.requestTo ;
import static org.springframework.test.web.client.response.MockRestResponseCreators.withSuccess ;
/ * *
* Tests for { @link OAuth2AuthorizedClientProviderBuilder } .
@ -55,29 +60,30 @@ import static org.mockito.Mockito.verify;
@@ -55,29 +60,30 @@ import static org.mockito.Mockito.verify;
* /
public class OAuth2AuthorizedClientProviderBuilderTests {
private RestOperations accessToken Client ;
private RestClientClientCredentialsTokenResponseClient clientCredentialsTokenResponse Client ;
private DefaultClientCredentialsTokenResponseClient clientCredentialsTokenResponseClient ;
private DefaultRefreshTokenTokenResponseClient refreshTokenTokenResponseClient ;
private RestClientRefreshTokenTokenResponseClient refreshTokenTokenResponseClient ;
private DefaultPasswordTokenResponseClient passwordTokenResponseClient ;
private Authentication principal ;
@SuppressWarnings ( "unchecked" )
private MockRestServiceServer server ;
@BeforeEach
public void setup ( ) {
OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses . accessTokenResponse ( ) . build ( ) ;
this . accessTokenClient = mock ( RestOperations . class ) ;
given ( this . accessTokenClient . exchange ( any ( RequestEntity . class ) , eq ( OAuth2AccessTokenResponse . class ) ) )
. willReturn ( new ResponseEntity ( accessTokenResponse , HttpStatus . OK ) ) ;
this . refreshTokenTokenResponseClient = new DefaultRefreshTokenTokenResponseClient ( ) ;
this . refreshTokenTokenResponseClient . setRestOperations ( this . accessTokenClient ) ;
this . clientCredentialsTokenResponseClient = new DefaultClientCredentialsTokenResponseClient ( ) ;
this . clientCredentialsTokenResponseClient . setRestOperations ( this . accessTokenClient ) ;
// TODO: Use of RestTemplate in these tests can be removed when
// DefaultPasswordTokenResponseClient is removed.
RestTemplate accessTokenClient = new RestTemplate (
List . of ( new FormHttpMessageConverter ( ) , new OAuth2AccessTokenResponseHttpMessageConverter ( ) ) ) ;
this . server = MockRestServiceServer . bindTo ( accessTokenClient ) . build ( ) ;
RestClient restClient = RestClient . create ( accessTokenClient ) ;
this . refreshTokenTokenResponseClient = new RestClientRefreshTokenTokenResponseClient ( ) ;
this . refreshTokenTokenResponseClient . setRestClient ( restClient ) ;
this . clientCredentialsTokenResponseClient = new RestClientClientCredentialsTokenResponseClient ( ) ;
this . clientCredentialsTokenResponseClient . setRestClient ( restClient ) ;
this . passwordTokenResponseClient = new DefaultPasswordTokenResponseClient ( ) ;
this . passwordTokenResponseClient . setRestOperations ( this . accessTokenClient ) ;
this . passwordTokenResponseClient . setRestOperations ( accessTokenClient ) ;
this . principal = new TestingAuthenticationToken ( "principal" , "password" ) ;
}
@ -104,6 +110,8 @@ public class OAuth2AuthorizedClientProviderBuilderTests {
@@ -104,6 +110,8 @@ public class OAuth2AuthorizedClientProviderBuilderTests {
@Test
public void buildWhenRefreshTokenProviderThenProviderReauthorizes ( ) {
mockAccessTokenResponse ( once ( ) ) ;
OAuth2AuthorizedClientProvider authorizedClientProvider = OAuth2AuthorizedClientProviderBuilder . builder ( )
. refreshToken ( ( configurer ) - > configurer . accessTokenResponseClient ( this . refreshTokenTokenResponseClient ) )
. build ( ) ;
@ -118,11 +126,13 @@ public class OAuth2AuthorizedClientProviderBuilderTests {
@@ -118,11 +126,13 @@ public class OAuth2AuthorizedClientProviderBuilderTests {
// @formatter:on
OAuth2AuthorizedClient reauthorizedClient = authorizedClientProvider . authorize ( authorizationContext ) ;
assertThat ( reauthorizedClient ) . isNotNull ( ) ;
verify ( this . accessTokenClient ) . exchange ( any ( RequestEntity . class ) , eq ( OAuth2AccessTokenResponse . class ) ) ;
this . server . verify ( ) ;
}
@Test
public void buildWhenClientCredentialsProviderThenProviderAuthorizes ( ) {
mockAccessTokenResponse ( once ( ) ) ;
OAuth2AuthorizedClientProvider authorizedClientProvider = OAuth2AuthorizedClientProviderBuilder . builder ( )
. clientCredentials (
( configurer ) - > configurer . accessTokenResponseClient ( this . clientCredentialsTokenResponseClient ) )
@ -135,11 +145,13 @@ public class OAuth2AuthorizedClientProviderBuilderTests {
@@ -135,11 +145,13 @@ public class OAuth2AuthorizedClientProviderBuilderTests {
// @formatter:on
OAuth2AuthorizedClient authorizedClient = authorizedClientProvider . authorize ( authorizationContext ) ;
assertThat ( authorizedClient ) . isNotNull ( ) ;
verify ( this . accessTokenClient ) . exchange ( any ( RequestEntity . class ) , eq ( OAuth2AccessTokenResponse . class ) ) ;
this . server . verify ( ) ;
}
@Test
public void buildWhenPasswordProviderThenProviderAuthorizes ( ) {
mockAccessTokenResponse ( once ( ) ) ;
// @formatter:off
OAuth2AuthorizedClientProvider authorizedClientProvider = OAuth2AuthorizedClientProviderBuilder . builder ( )
. password ( ( configurer ) - > configurer . accessTokenResponseClient ( this . passwordTokenResponseClient ) )
@ -153,11 +165,13 @@ public class OAuth2AuthorizedClientProviderBuilderTests {
@@ -153,11 +165,13 @@ public class OAuth2AuthorizedClientProviderBuilderTests {
// @formatter:on
OAuth2AuthorizedClient authorizedClient = authorizedClientProvider . authorize ( authorizationContext ) ;
assertThat ( authorizedClient ) . isNotNull ( ) ;
verify ( this . accessTokenClient ) . exchange ( any ( RequestEntity . class ) , eq ( OAuth2AccessTokenResponse . class ) ) ;
this . server . verify ( ) ;
}
@Test
public void buildWhenAllProvidersThenProvidersAuthorize ( ) {
mockAccessTokenResponse ( times ( 3 ) ) ;
OAuth2AuthorizedClientProvider authorizedClientProvider = OAuth2AuthorizedClientProviderBuilder . builder ( )
. authorizationCode ( )
. refreshToken ( ( configurer ) - > configurer . accessTokenResponseClient ( this . refreshTokenTokenResponseClient ) )
@ -184,8 +198,6 @@ public class OAuth2AuthorizedClientProviderBuilderTests {
@@ -184,8 +198,6 @@ public class OAuth2AuthorizedClientProviderBuilderTests {
. build ( ) ;
OAuth2AuthorizedClient reauthorizedClient = authorizedClientProvider . authorize ( refreshTokenContext ) ;
assertThat ( reauthorizedClient ) . isNotNull ( ) ;
verify ( this . accessTokenClient , times ( 1 ) ) . exchange ( any ( RequestEntity . class ) ,
eq ( OAuth2AccessTokenResponse . class ) ) ;
// client_credentials
// @formatter:off
OAuth2AuthorizationContext clientCredentialsContext = OAuth2AuthorizationContext
@ -195,8 +207,6 @@ public class OAuth2AuthorizedClientProviderBuilderTests {
@@ -195,8 +207,6 @@ public class OAuth2AuthorizedClientProviderBuilderTests {
// @formatter:on
authorizedClient = authorizedClientProvider . authorize ( clientCredentialsContext ) ;
assertThat ( authorizedClient ) . isNotNull ( ) ;
verify ( this . accessTokenClient , times ( 2 ) ) . exchange ( any ( RequestEntity . class ) ,
eq ( OAuth2AccessTokenResponse . class ) ) ;
// password
// @formatter:off
OAuth2AuthorizationContext passwordContext = OAuth2AuthorizationContext
@ -208,8 +218,7 @@ public class OAuth2AuthorizedClientProviderBuilderTests {
@@ -208,8 +218,7 @@ public class OAuth2AuthorizedClientProviderBuilderTests {
// @formatter:on
authorizedClient = authorizedClientProvider . authorize ( passwordContext ) ;
assertThat ( authorizedClient ) . isNotNull ( ) ;
verify ( this . accessTokenClient , times ( 3 ) ) . exchange ( any ( RequestEntity . class ) ,
eq ( OAuth2AccessTokenResponse . class ) ) ;
this . server . verify ( ) ;
}
@Test
@ -234,4 +243,10 @@ public class OAuth2AuthorizedClientProviderBuilderTests {
@@ -234,4 +243,10 @@ public class OAuth2AuthorizedClientProviderBuilderTests {
return new OAuth2AccessToken ( OAuth2AccessToken . TokenType . BEARER , "access-token-1234" , issuedAt , expiresAt ) ;
}
private void mockAccessTokenResponse ( ExpectedCount expectedCount ) {
this . server . expect ( expectedCount , requestTo ( "https://example.com/login/oauth/access_token" ) )
. andRespond ( withSuccess ( ) . header ( HttpHeaders . CONTENT_TYPE , MediaType . APPLICATION_JSON_VALUE )
. body ( new ClassPathResource ( "access-token-response.json" ) ) ) ;
}
}