@ -16,11 +16,16 @@
package org.springframework.security.oauth2.client.endpoint ;
package org.springframework.security.oauth2.client.endpoint ;
import java.nio.charset.StandardCharsets ;
import java.time.Instant ;
import java.time.Instant ;
import java.util.Collections ;
import java.util.Collections ;
import java.util.HashMap ;
import java.util.HashMap ;
import java.util.Map ;
import java.util.Map ;
import java.util.function.Function ;
import javax.crypto.spec.SecretKeySpec ;
import com.nimbusds.jose.jwk.JWK ;
import okhttp3.mockwebserver.MockResponse ;
import okhttp3.mockwebserver.MockResponse ;
import okhttp3.mockwebserver.MockWebServer ;
import okhttp3.mockwebserver.MockWebServer ;
import okhttp3.mockwebserver.RecordedRequest ;
import okhttp3.mockwebserver.RecordedRequest ;
@ -36,6 +41,7 @@ import org.springframework.http.MediaType;
import org.springframework.http.ReactiveHttpInputMessage ;
import org.springframework.http.ReactiveHttpInputMessage ;
import org.springframework.security.oauth2.client.registration.ClientRegistration ;
import org.springframework.security.oauth2.client.registration.ClientRegistration ;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations ;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations ;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod ;
import org.springframework.security.oauth2.core.OAuth2AccessToken ;
import org.springframework.security.oauth2.core.OAuth2AccessToken ;
import org.springframework.security.oauth2.core.OAuth2AuthorizationException ;
import org.springframework.security.oauth2.core.OAuth2AuthorizationException ;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse ;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse ;
@ -44,6 +50,10 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequ
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse ;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse ;
import org.springframework.security.oauth2.core.endpoint.PkceParameterNames ;
import org.springframework.security.oauth2.core.endpoint.PkceParameterNames ;
import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses ;
import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses ;
import org.springframework.security.oauth2.jose.TestJwks ;
import org.springframework.security.oauth2.jose.TestKeys ;
import org.springframework.util.LinkedMultiValueMap ;
import org.springframework.util.MultiValueMap ;
import org.springframework.web.reactive.function.BodyExtractor ;
import org.springframework.web.reactive.function.BodyExtractor ;
import org.springframework.web.reactive.function.client.WebClient ;
import org.springframework.web.reactive.function.client.WebClient ;
@ -112,6 +122,75 @@ public class WebClientReactiveAuthorizationCodeTokenResponseClientTests {
assertThat ( accessTokenResponse . getAdditionalParameters ( ) ) . containsEntry ( "custom_parameter_2" , "custom-value-2" ) ;
assertThat ( accessTokenResponse . getAdditionalParameters ( ) ) . containsEntry ( "custom_parameter_2" , "custom-value-2" ) ;
}
}
@Test
public void getTokenResponseWhenAuthenticationClientSecretJwtThenFormParametersAreSent ( ) throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n" ;
// @formatter:on
this . server . enqueue ( jsonResponse ( accessTokenSuccessResponse ) ) ;
// @formatter:off
ClientRegistration clientRegistration = this . clientRegistration
. clientAuthenticationMethod ( ClientAuthenticationMethod . CLIENT_SECRET_JWT )
. clientSecret ( TestKeys . DEFAULT_ENCODED_SECRET_KEY )
. build ( ) ;
// @formatter:on
// Configure Jwt client authentication converter
SecretKeySpec secretKey = new SecretKeySpec (
clientRegistration . getClientSecret ( ) . getBytes ( StandardCharsets . UTF_8 ) , "HmacSHA256" ) ;
JWK jwk = TestJwks . jwk ( secretKey ) . build ( ) ;
Function < ClientRegistration , JWK > jwkResolver = ( registration ) - > jwk ;
configureJwtClientAuthenticationConverter ( jwkResolver ) ;
this . tokenResponseClient . getTokenResponse ( authorizationCodeGrantRequest ( clientRegistration ) ) . block ( ) ;
RecordedRequest actualRequest = this . server . takeRequest ( ) ;
assertThat ( actualRequest . getHeader ( HttpHeaders . AUTHORIZATION ) ) . isNull ( ) ;
assertThat ( actualRequest . getBody ( ) . readUtf8 ( ) ) . contains ( "grant_type=authorization_code" ,
"client_assertion_type=urn%3Aietf%3Aparams%3Aoauth%3Aclient-assertion-type%3Ajwt-bearer" ,
"client_assertion=" ) ;
}
@Test
public void getTokenResponseWhenAuthenticationPrivateKeyJwtThenFormParametersAreSent ( ) throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n" ;
// @formatter:on
this . server . enqueue ( jsonResponse ( accessTokenSuccessResponse ) ) ;
// @formatter:off
ClientRegistration clientRegistration = this . clientRegistration
. clientAuthenticationMethod ( ClientAuthenticationMethod . PRIVATE_KEY_JWT )
. build ( ) ;
// @formatter:on
// Configure Jwt client authentication converter
JWK jwk = TestJwks . DEFAULT_RSA_JWK ;
Function < ClientRegistration , JWK > jwkResolver = ( registration ) - > jwk ;
configureJwtClientAuthenticationConverter ( jwkResolver ) ;
this . tokenResponseClient . getTokenResponse ( authorizationCodeGrantRequest ( clientRegistration ) ) . block ( ) ;
RecordedRequest actualRequest = this . server . takeRequest ( ) ;
assertThat ( actualRequest . getHeader ( HttpHeaders . AUTHORIZATION ) ) . isNull ( ) ;
assertThat ( actualRequest . getBody ( ) . readUtf8 ( ) ) . contains ( "grant_type=authorization_code" ,
"client_assertion_type=urn%3Aietf%3Aparams%3Aoauth%3Aclient-assertion-type%3Ajwt-bearer" ,
"client_assertion=" ) ;
}
private void configureJwtClientAuthenticationConverter ( Function < ClientRegistration , JWK > jwkResolver ) {
NimbusJwtClientAuthenticationParametersConverter < OAuth2AuthorizationCodeGrantRequest > jwtClientAuthenticationConverter = new NimbusJwtClientAuthenticationParametersConverter < > (
jwkResolver ) ;
this . tokenResponseClient . addParametersConverter ( jwtClientAuthenticationConverter ) ;
}
// @Test
// @Test
// public void
// public void
// getTokenResponseWhenRedirectUriMalformedThenThrowIllegalArgumentException() throws
// getTokenResponseWhenRedirectUriMalformedThenThrowIllegalArgumentException() throws
@ -261,7 +340,10 @@ public class WebClientReactiveAuthorizationCodeTokenResponseClientTests {
}
}
private OAuth2AuthorizationCodeGrantRequest authorizationCodeGrantRequest ( ) {
private OAuth2AuthorizationCodeGrantRequest authorizationCodeGrantRequest ( ) {
ClientRegistration registration = this . clientRegistration . build ( ) ;
return authorizationCodeGrantRequest ( this . clientRegistration . build ( ) ) ;
}
private OAuth2AuthorizationCodeGrantRequest authorizationCodeGrantRequest ( ClientRegistration registration ) {
OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest . authorizationCode ( )
OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest . authorizationCode ( )
. clientId ( registration . getClientId ( ) ) . state ( "state" )
. clientId ( registration . getClientId ( ) ) . state ( "state" )
. authorizationUri ( registration . getProviderDetails ( ) . getAuthorizationUri ( ) )
. authorizationUri ( registration . getProviderDetails ( ) . getAuthorizationUri ( ) )
@ -414,6 +496,67 @@ public class WebClientReactiveAuthorizationCodeTokenResponseClientTests {
. isEqualTo ( "Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=" ) ;
. isEqualTo ( "Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=" ) ;
}
}
@Test
public void setParametersConverterWhenNullThenThrowIllegalArgumentException ( ) {
assertThatIllegalArgumentException ( ) . isThrownBy ( ( ) - > this . tokenResponseClient . setParametersConverter ( null ) )
. withMessage ( "parametersConverter cannot be null" ) ;
}
@Test
public void addParametersConverterWhenNullThenThrowIllegalArgumentException ( ) {
assertThatIllegalArgumentException ( ) . isThrownBy ( ( ) - > this . tokenResponseClient . addParametersConverter ( null ) )
. withMessage ( "parametersConverter cannot be null" ) ;
}
@Test
public void convertWhenParametersConverterAddedThenCalled ( ) throws Exception {
OAuth2AuthorizationCodeGrantRequest request = authorizationCodeGrantRequest ( ) ;
Converter < OAuth2AuthorizationCodeGrantRequest , MultiValueMap < String , String > > addedParametersConverter = mock (
Converter . class ) ;
MultiValueMap < String , String > parameters = new LinkedMultiValueMap < > ( ) ;
parameters . add ( "custom-parameter-name" , "custom-parameter-value" ) ;
given ( addedParametersConverter . convert ( request ) ) . willReturn ( parameters ) ;
this . tokenResponseClient . addParametersConverter ( addedParametersConverter ) ;
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n"
+ " \"token_type\":\"bearer\",\n"
+ " \"expires_in\":3600,\n"
+ " \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\"\n"
+ "}" ;
// @formatter:on
this . server . enqueue ( jsonResponse ( accessTokenSuccessResponse ) ) ;
this . tokenResponseClient . getTokenResponse ( request ) . block ( ) ;
verify ( addedParametersConverter ) . convert ( request ) ;
RecordedRequest actualRequest = this . server . takeRequest ( ) ;
assertThat ( actualRequest . getBody ( ) . readUtf8 ( ) ) . contains ( "grant_type=authorization_code" ,
"custom-parameter-name=custom-parameter-value" ) ;
}
@Test
public void convertWhenParametersConverterSetThenCalled ( ) throws Exception {
OAuth2AuthorizationCodeGrantRequest request = authorizationCodeGrantRequest ( ) ;
Converter < OAuth2AuthorizationCodeGrantRequest , MultiValueMap < String , String > > parametersConverter = mock (
Converter . class ) ;
MultiValueMap < String , String > parameters = new LinkedMultiValueMap < > ( ) ;
parameters . add ( "custom-parameter-name" , "custom-parameter-value" ) ;
given ( parametersConverter . convert ( request ) ) . willReturn ( parameters ) ;
this . tokenResponseClient . setParametersConverter ( parametersConverter ) ;
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n"
+ " \"token_type\":\"bearer\",\n"
+ " \"expires_in\":3600,\n"
+ " \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\"\n"
+ "}" ;
// @formatter:on
this . server . enqueue ( jsonResponse ( accessTokenSuccessResponse ) ) ;
this . tokenResponseClient . getTokenResponse ( request ) . block ( ) ;
verify ( parametersConverter ) . convert ( request ) ;
RecordedRequest actualRequest = this . server . takeRequest ( ) ;
assertThat ( actualRequest . getBody ( ) . readUtf8 ( ) ) . contains ( "custom-parameter-name=custom-parameter-value" ) ;
}
// gh-10260
// gh-10260
@Test
@Test
public void getTokenResponseWhenSuccessCustomResponseThenReturnAccessTokenResponse ( ) {
public void getTokenResponseWhenSuccessCustomResponseThenReturnAccessTokenResponse ( ) {