@ -15,16 +15,14 @@
@@ -15,16 +15,14 @@
* /
package org.springframework.security.oauth2.client.userinfo ;
import org.junit.Before ;
import org.junit.Test ;
import org.springframework.http.HttpHeaders ;
import org.springframework.http.HttpMethod ;
import org.springframework.http.MediaType ;
import org.springframework.http.RequestEntity ;
import org.springframework.security.oauth2.client.registration.ClientRegistration ;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations ;
import org.springframework.security.oauth2.core.AuthenticationMethod ;
import org.springframework.security.oauth2.core.AuthorizationGrantType ;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod ;
import org.springframework.security.oauth2.core.OAuth2AccessToken ;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames ;
import org.springframework.util.MultiValueMap ;
@ -43,35 +41,15 @@ import static org.springframework.http.MediaType.APPLICATION_FORM_URLENCODED_VAL
@@ -43,35 +41,15 @@ import static org.springframework.http.MediaType.APPLICATION_FORM_URLENCODED_VAL
* /
public class OAuth2UserRequestEntityConverterTests {
private OAuth2UserRequestEntityConverter converter = new OAuth2UserRequestEntityConverter ( ) ;
private OAuth2UserRequest userRequest ;
@Before
public void setup ( ) {
ClientRegistration clientRegistration = ClientRegistration . withRegistrationId ( "registration-1" )
. clientId ( "client-1" )
. clientSecret ( "secret" )
. clientAuthenticationMethod ( ClientAuthenticationMethod . BASIC )
. authorizationGrantType ( AuthorizationGrantType . AUTHORIZATION_CODE )
. redirectUriTemplate ( "https://client.com/callback/client-1" )
. scope ( "read" , "write" )
. authorizationUri ( "https://provider.com/oauth2/authorize" )
. tokenUri ( "https://provider.com/oauth2/token" )
. userInfoUri ( "https://provider.com/user" )
. userInfoAuthenticationMethod ( AuthenticationMethod . HEADER )
. userNameAttributeName ( "id" )
. build ( ) ;
OAuth2AccessToken accessToken = new OAuth2AccessToken (
OAuth2AccessToken . TokenType . BEARER , "access-token-1234" , Instant . now ( ) ,
Instant . now ( ) . plusSeconds ( 3600 ) , new LinkedHashSet < > ( Arrays . asList ( "read" , "write" ) ) ) ;
this . userRequest = new OAuth2UserRequest ( clientRegistration , accessToken ) ;
}
@SuppressWarnings ( "unchecked" )
@Test
public void convertWhenAuthenticationMethodHeaderThenGetRequest ( ) {
RequestEntity < ? > requestEntity = this . converter . convert ( this . userRequest ) ;
ClientRegistration clientRegistration = TestClientRegistrations . clientRegistration ( ) . build ( ) ;
OAuth2UserRequest userRequest = new OAuth2UserRequest (
clientRegistration , this . createAccessToken ( ) ) ;
ClientRegistration clientRegistration = this . userRequest . getClientRegistration ( ) ;
RequestEntity < ? > requestEntity = this . converter . convert ( userRequest ) ;
assertThat ( requestEntity . getMethod ( ) ) . isEqualTo ( HttpMethod . GET ) ;
assertThat ( requestEntity . getUrl ( ) . toASCIIString ( ) ) . isEqualTo (
@ -80,17 +58,17 @@ public class OAuth2UserRequestEntityConverterTests {
@@ -80,17 +58,17 @@ public class OAuth2UserRequestEntityConverterTests {
HttpHeaders headers = requestEntity . getHeaders ( ) ;
assertThat ( headers . getAccept ( ) ) . contains ( MediaType . APPLICATION_JSON_UTF8 ) ;
assertThat ( headers . getFirst ( HttpHeaders . AUTHORIZATION ) ) . isEqualTo (
"Bearer " + this . userRequest . getAccessToken ( ) . getTokenValue ( ) ) ;
"Bearer " + userRequest . getAccessToken ( ) . getTokenValue ( ) ) ;
}
@SuppressWarnings ( "unchecked" )
@Test
public void convertWhenAuthenticationMethodFormThenPostRequest ( ) {
ClientRegistration clientRegistration = this . from ( this . userRequest . getClientRegistration ( ) )
ClientRegistration clientRegistration = TestClientRegistrations . clientRegistration ( )
. userInfoAuthenticationMethod ( AuthenticationMethod . FORM )
. build ( ) ;
OAuth2UserRequest userRequest = new OAuth2UserRequest (
clientRegistration , this . userRequest . get AccessToken( ) ) ;
clientRegistration , this . create AccessToken( ) ) ;
RequestEntity < ? > requestEntity = this . converter . convert ( userRequest ) ;
@ -105,21 +83,13 @@ public class OAuth2UserRequestEntityConverterTests {
@@ -105,21 +83,13 @@ public class OAuth2UserRequestEntityConverterTests {
MultiValueMap < String , String > formParameters = ( MultiValueMap < String , String > ) requestEntity . getBody ( ) ;
assertThat ( formParameters . getFirst ( OAuth2ParameterNames . ACCESS_TOKEN ) ) . isEqualTo (
this . userRequest . getAccessToken ( ) . getTokenValue ( ) ) ;
userRequest . getAccessToken ( ) . getTokenValue ( ) ) ;
}
private ClientRegistration . Builder from ( ClientRegistration registration ) {
return ClientRegistration . withRegistrationId ( registration . getRegistrationId ( ) )
. clientId ( registration . getClientId ( ) )
. clientSecret ( registration . getClientSecret ( ) )
. clientAuthenticationMethod ( registration . getClientAuthenticationMethod ( ) )
. authorizationGrantType ( registration . getAuthorizationGrantType ( ) )
. redirectUriTemplate ( registration . getRedirectUriTemplate ( ) )
. scope ( registration . getScopes ( ) )
. authorizationUri ( registration . getProviderDetails ( ) . getAuthorizationUri ( ) )
. tokenUri ( registration . getProviderDetails ( ) . getTokenUri ( ) )
. userInfoUri ( registration . getProviderDetails ( ) . getUserInfoEndpoint ( ) . getUri ( ) )
. userNameAttributeName ( registration . getProviderDetails ( ) . getUserInfoEndpoint ( ) . getUserNameAttributeName ( ) )
. clientName ( registration . getClientName ( ) ) ;
private OAuth2AccessToken createAccessToken ( ) {
OAuth2AccessToken accessToken = new OAuth2AccessToken (
OAuth2AccessToken . TokenType . BEARER , "access-token-1234" , Instant . now ( ) ,
Instant . now ( ) . plusSeconds ( 3600 ) , new LinkedHashSet < > ( Arrays . asList ( "read" , "write" ) ) ) ;
return accessToken ;
}
}