@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and
* See the License for the specific language governing permissions and
* limitations under the License .
* limitations under the License .
* /
* /
package org.springframework.security.oauth2.client.web.server ;
package org.springframework.security.oauth2.client.web ;
import org.junit.Before ;
import org.junit.Before ;
import org.junit.Test ;
import org.junit.Test ;
@ -24,11 +24,13 @@ import org.springframework.mock.web.server.MockServerWebExchange;
import org.springframework.security.authentication.TestingAuthenticationToken ;
import org.springframework.security.authentication.TestingAuthenticationToken ;
import org.springframework.security.core.Authentication ;
import org.springframework.security.core.Authentication ;
import org.springframework.security.oauth2.client.OAuth2AuthorizationContext ;
import org.springframework.security.oauth2.client.OAuth2AuthorizationContext ;
import org.springframework.security.oauth2.client.OAuth2AuthorizeRequest ;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient ;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient ;
import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientProvider ;
import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientProvider ;
import org.springframework.security.oauth2.client.registration.ClientRegistration ;
import org.springframework.security.oauth2.client.registration.ClientRegistration ;
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository ;
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository ;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations ;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations ;
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository ;
import org.springframework.security.oauth2.core.TestOAuth2AccessTokens ;
import org.springframework.security.oauth2.core.TestOAuth2AccessTokens ;
import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens ;
import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens ;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames ;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames ;
@ -48,16 +50,16 @@ import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.* ;
import static org.mockito.Mockito.* ;
/ * *
/ * *
* Tests for { @link DefaultServer OAuth2AuthorizedClientManager } .
* Tests for { @link DefaultReactive OAuth2AuthorizedClientManager } .
*
*
* @author Joe Grandja
* @author Joe Grandja
* /
* /
public class DefaultServer OAuth2AuthorizedClientManagerTests {
public class DefaultReactive OAuth2AuthorizedClientManagerTests {
private ReactiveClientRegistrationRepository clientRegistrationRepository ;
private ReactiveClientRegistrationRepository clientRegistrationRepository ;
private ServerOAuth2AuthorizedClientRepository authorizedClientRepository ;
private ServerOAuth2AuthorizedClientRepository authorizedClientRepository ;
private ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider ;
private ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider ;
private Function contextAttributesMapper ;
private Function contextAttributesMapper ;
private DefaultServer OAuth2AuthorizedClientManager authorizedClientManager ;
private DefaultReactive OAuth2AuthorizedClientManager authorizedClientManager ;
private ClientRegistration clientRegistration ;
private ClientRegistration clientRegistration ;
private Authentication principal ;
private Authentication principal ;
private OAuth2AuthorizedClient authorizedClient ;
private OAuth2AuthorizedClient authorizedClient ;
@ -77,7 +79,7 @@ public class DefaultServerOAuth2AuthorizedClientManagerTests {
when ( this . authorizedClientProvider . authorize ( any ( OAuth2AuthorizationContext . class ) ) ) . thenReturn ( Mono . empty ( ) ) ;
when ( this . authorizedClientProvider . authorize ( any ( OAuth2AuthorizationContext . class ) ) ) . thenReturn ( Mono . empty ( ) ) ;
this . contextAttributesMapper = mock ( Function . class ) ;
this . contextAttributesMapper = mock ( Function . class ) ;
when ( this . contextAttributesMapper . apply ( any ( ) ) ) . thenReturn ( Mono . just ( Collections . emptyMap ( ) ) ) ;
when ( this . contextAttributesMapper . apply ( any ( ) ) ) . thenReturn ( Mono . just ( Collections . emptyMap ( ) ) ) ;
this . authorizedClientManager = new DefaultServer OAuth2AuthorizedClientManager (
this . authorizedClientManager = new DefaultReactive OAuth2AuthorizedClientManager (
this . clientRegistrationRepository , this . authorizedClientRepository ) ;
this . clientRegistrationRepository , this . authorizedClientRepository ) ;
this . authorizedClientManager . setAuthorizedClientProvider ( this . authorizedClientProvider ) ;
this . authorizedClientManager . setAuthorizedClientProvider ( this . authorizedClientProvider ) ;
this . authorizedClientManager . setContextAttributesMapper ( this . contextAttributesMapper ) ;
this . authorizedClientManager . setContextAttributesMapper ( this . contextAttributesMapper ) ;
@ -91,14 +93,14 @@ public class DefaultServerOAuth2AuthorizedClientManagerTests {
@Test
@Test
public void constructorWhenClientRegistrationRepositoryIsNullThenThrowIllegalArgumentException ( ) {
public void constructorWhenClientRegistrationRepositoryIsNullThenThrowIllegalArgumentException ( ) {
assertThatThrownBy ( ( ) - > new DefaultServer OAuth2AuthorizedClientManager ( null , this . authorizedClientRepository ) )
assertThatThrownBy ( ( ) - > new DefaultReactive OAuth2AuthorizedClientManager ( null , this . authorizedClientRepository ) )
. isInstanceOf ( IllegalArgumentException . class )
. isInstanceOf ( IllegalArgumentException . class )
. hasMessage ( "clientRegistrationRepository cannot be null" ) ;
. hasMessage ( "clientRegistrationRepository cannot be null" ) ;
}
}
@Test
@Test
public void constructorWhenOAuth2AuthorizedClientRepositoryIsNullThenThrowIllegalArgumentException ( ) {
public void constructorWhenOAuth2AuthorizedClientRepositoryIsNullThenThrowIllegalArgumentException ( ) {
assertThatThrownBy ( ( ) - > new DefaultServer OAuth2AuthorizedClientManager ( this . clientRegistrationRepository , null ) )
assertThatThrownBy ( ( ) - > new DefaultReactive OAuth2AuthorizedClientManager ( this . clientRegistrationRepository , null ) )
. isInstanceOf ( IllegalArgumentException . class )
. isInstanceOf ( IllegalArgumentException . class )
. hasMessage ( "authorizedClientRepository cannot be null" ) ;
. hasMessage ( "authorizedClientRepository cannot be null" ) ;
}
}
@ -117,6 +119,16 @@ public class DefaultServerOAuth2AuthorizedClientManagerTests {
. hasMessage ( "contextAttributesMapper cannot be null" ) ;
. hasMessage ( "contextAttributesMapper cannot be null" ) ;
}
}
@Test
public void authorizeWhenServerWebExchangeIsNullThenThrowIllegalArgumentException ( ) {
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest . withClientRegistrationId ( this . clientRegistration . getRegistrationId ( ) )
. principal ( this . principal )
. build ( ) ;
assertThatThrownBy ( ( ) - > this . authorizedClientManager . authorize ( authorizeRequest ) . block ( ) )
. isInstanceOf ( IllegalArgumentException . class )
. hasMessage ( "serverWebExchange cannot be null" ) ;
}
@Test
@Test
public void authorizeWhenRequestIsNullThenThrowIllegalArgumentException ( ) {
public void authorizeWhenRequestIsNullThenThrowIllegalArgumentException ( ) {
assertThatThrownBy ( ( ) - > this . authorizedClientManager . authorize ( null ) . block ( ) )
assertThatThrownBy ( ( ) - > this . authorizedClientManager . authorize ( null ) . block ( ) )
@ -126,8 +138,10 @@ public class DefaultServerOAuth2AuthorizedClientManagerTests {
@Test
@Test
public void authorizeWhenClientRegistrationNotFoundThenThrowIllegalArgumentException ( ) {
public void authorizeWhenClientRegistrationNotFoundThenThrowIllegalArgumentException ( ) {
ServerOAuth2AuthorizeRequest authorizeRequest = new ServerOAuth2AuthorizeRequest (
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest . withClientRegistrationId ( "invalid-registration-id" )
"invalid-registration-id" , this . principal , this . serverWebExchange ) ;
. principal ( this . principal )
. attribute ( ServerWebExchange . class . getName ( ) , this . serverWebExchange )
. build ( ) ;
assertThatThrownBy ( ( ) - > this . authorizedClientManager . authorize ( authorizeRequest ) . block ( ) )
assertThatThrownBy ( ( ) - > this . authorizedClientManager . authorize ( authorizeRequest ) . block ( ) )
. isInstanceOf ( IllegalArgumentException . class )
. isInstanceOf ( IllegalArgumentException . class )
. hasMessage ( "Could not find ClientRegistration with id 'invalid-registration-id'" ) ;
. hasMessage ( "Could not find ClientRegistration with id 'invalid-registration-id'" ) ;
@ -139,8 +153,10 @@ public class DefaultServerOAuth2AuthorizedClientManagerTests {
when ( this . clientRegistrationRepository . findByRegistrationId (
when ( this . clientRegistrationRepository . findByRegistrationId (
eq ( this . clientRegistration . getRegistrationId ( ) ) ) ) . thenReturn ( Mono . just ( this . clientRegistration ) ) ;
eq ( this . clientRegistration . getRegistrationId ( ) ) ) ) . thenReturn ( Mono . just ( this . clientRegistration ) ) ;
ServerOAuth2AuthorizeRequest authorizeRequest = new ServerOAuth2AuthorizeRequest (
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest . withClientRegistrationId ( this . clientRegistration . getRegistrationId ( ) )
this . clientRegistration . getRegistrationId ( ) , this . principal , this . serverWebExchange ) ;
. principal ( this . principal )
. attribute ( ServerWebExchange . class . getName ( ) , this . serverWebExchange )
. build ( ) ;
OAuth2AuthorizedClient authorizedClient = this . authorizedClientManager . authorize ( authorizeRequest ) . block ( ) ;
OAuth2AuthorizedClient authorizedClient = this . authorizedClientManager . authorize ( authorizeRequest ) . block ( ) ;
verify ( this . authorizedClientProvider ) . authorize ( this . authorizationContextCaptor . capture ( ) ) ;
verify ( this . authorizedClientProvider ) . authorize ( this . authorizationContextCaptor . capture ( ) ) ;
@ -165,8 +181,10 @@ public class DefaultServerOAuth2AuthorizedClientManagerTests {
when ( this . authorizedClientProvider . authorize (
when ( this . authorizedClientProvider . authorize (
any ( OAuth2AuthorizationContext . class ) ) ) . thenReturn ( Mono . just ( this . authorizedClient ) ) ;
any ( OAuth2AuthorizationContext . class ) ) ) . thenReturn ( Mono . just ( this . authorizedClient ) ) ;
ServerOAuth2AuthorizeRequest authorizeRequest = new ServerOAuth2AuthorizeRequest (
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest . withClientRegistrationId ( this . clientRegistration . getRegistrationId ( ) )
this . clientRegistration . getRegistrationId ( ) , this . principal , this . serverWebExchange ) ;
. principal ( this . principal )
. attribute ( ServerWebExchange . class . getName ( ) , this . serverWebExchange )
. build ( ) ;
OAuth2AuthorizedClient authorizedClient = this . authorizedClientManager . authorize ( authorizeRequest ) . block ( ) ;
OAuth2AuthorizedClient authorizedClient = this . authorizedClientManager . authorize ( authorizeRequest ) . block ( ) ;
verify ( this . authorizedClientProvider ) . authorize ( this . authorizationContextCaptor . capture ( ) ) ;
verify ( this . authorizedClientProvider ) . authorize ( this . authorizationContextCaptor . capture ( ) ) ;
@ -196,8 +214,10 @@ public class DefaultServerOAuth2AuthorizedClientManagerTests {
when ( this . authorizedClientProvider . authorize ( any ( OAuth2AuthorizationContext . class ) ) ) . thenReturn ( Mono . just ( reauthorizedClient ) ) ;
when ( this . authorizedClientProvider . authorize ( any ( OAuth2AuthorizationContext . class ) ) ) . thenReturn ( Mono . just ( reauthorizedClient ) ) ;
ServerOAuth2AuthorizeRequest authorizeRequest = new ServerOAuth2AuthorizeRequest (
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest . withClientRegistrationId ( this . clientRegistration . getRegistrationId ( ) )
this . clientRegistration . getRegistrationId ( ) , this . principal , this . serverWebExchange ) ;
. principal ( this . principal )
. attribute ( ServerWebExchange . class . getName ( ) , this . serverWebExchange )
. build ( ) ;
OAuth2AuthorizedClient authorizedClient = this . authorizedClientManager . authorize ( authorizeRequest ) . block ( ) ;
OAuth2AuthorizedClient authorizedClient = this . authorizedClientManager . authorize ( authorizeRequest ) . block ( ) ;
verify ( this . authorizedClientProvider ) . authorize ( this . authorizationContextCaptor . capture ( ) ) ;
verify ( this . authorizedClientProvider ) . authorize ( this . authorizationContextCaptor . capture ( ) ) ;
@ -221,8 +241,9 @@ public class DefaultServerOAuth2AuthorizedClientManagerTests {
when ( this . authorizedClientProvider . authorize ( any ( OAuth2AuthorizationContext . class ) ) ) . thenReturn ( Mono . just ( this . authorizedClient ) ) ;
when ( this . authorizedClientProvider . authorize ( any ( OAuth2AuthorizationContext . class ) ) ) . thenReturn ( Mono . just ( this . authorizedClient ) ) ;
// Set custom contextAttributesMapper capable of mapping the form parameters
// Set custom contextAttributesMapper capable of mapping the form parameters
this . authorizedClientManager . setContextAttributesMapper ( authorizeRequest - >
this . authorizedClientManager . setContextAttributesMapper ( authorizeRequest - > {
Mono . just ( authorizeRequest . getServerWebExchange ( ) )
ServerWebExchange serverWebExchange = authorizeRequest . getAttribute ( ServerWebExchange . class . getName ( ) ) ;
return Mono . just ( serverWebExchange )
. flatMap ( ServerWebExchange : : getFormData )
. flatMap ( ServerWebExchange : : getFormData )
. map ( formData - > {
. map ( formData - > {
Map < String , Object > contextAttributes = new HashMap < > ( ) ;
Map < String , Object > contextAttributes = new HashMap < > ( ) ;
@ -233,8 +254,8 @@ public class DefaultServerOAuth2AuthorizedClientManagerTests {
contextAttributes . put ( OAuth2AuthorizationContext . PASSWORD_ATTRIBUTE_NAME , password ) ;
contextAttributes . put ( OAuth2AuthorizationContext . PASSWORD_ATTRIBUTE_NAME , password ) ;
}
}
return contextAttributes ;
return contextAttributes ;
} )
} ) ;
) ;
} ) ;
this . serverWebExchange = MockServerWebExchange . builder (
this . serverWebExchange = MockServerWebExchange . builder (
MockServerHttpRequest
MockServerHttpRequest
@ -243,8 +264,10 @@ public class DefaultServerOAuth2AuthorizedClientManagerTests {
. body ( "username=username&password=password" ) )
. body ( "username=username&password=password" ) )
. build ( ) ;
. build ( ) ;
ServerOAuth2AuthorizeRequest authorizeRequest = new ServerOAuth2AuthorizeRequest (
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest . withClientRegistrationId ( this . clientRegistration . getRegistrationId ( ) )
this . clientRegistration . getRegistrationId ( ) , this . principal , this . serverWebExchange ) ;
. principal ( this . principal )
. attribute ( ServerWebExchange . class . getName ( ) , this . serverWebExchange )
. build ( ) ;
this . authorizedClientManager . authorize ( authorizeRequest ) . block ( ) ;
this . authorizedClientManager . authorize ( authorizeRequest ) . block ( ) ;
verify ( this . authorizedClientProvider ) . authorize ( this . authorizationContextCaptor . capture ( ) ) ;
verify ( this . authorizedClientProvider ) . authorize ( this . authorizationContextCaptor . capture ( ) ) ;
@ -259,8 +282,10 @@ public class DefaultServerOAuth2AuthorizedClientManagerTests {
@SuppressWarnings ( "unchecked" )
@SuppressWarnings ( "unchecked" )
@Test
@Test
public void reauthorizeWhenUnsupportedProviderThenNotReauthorized ( ) {
public void reauthorizeWhenUnsupportedProviderThenNotReauthorized ( ) {
ServerOAuth2AuthorizeRequest reauthorizeRequest = new ServerOAuth2AuthorizeRequest (
OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest . withAuthorizedClient ( this . authorizedClient )
this . authorizedClient , this . principal , this . serverWebExchange ) ;
. principal ( this . principal )
. attribute ( ServerWebExchange . class . getName ( ) , this . serverWebExchange )
. build ( ) ;
OAuth2AuthorizedClient authorizedClient = this . authorizedClientManager . authorize ( reauthorizeRequest ) . block ( ) ;
OAuth2AuthorizedClient authorizedClient = this . authorizedClientManager . authorize ( reauthorizeRequest ) . block ( ) ;
verify ( this . authorizedClientProvider ) . authorize ( this . authorizationContextCaptor . capture ( ) ) ;
verify ( this . authorizedClientProvider ) . authorize ( this . authorizationContextCaptor . capture ( ) ) ;
@ -285,8 +310,10 @@ public class DefaultServerOAuth2AuthorizedClientManagerTests {
when ( this . authorizedClientProvider . authorize ( any ( OAuth2AuthorizationContext . class ) ) ) . thenReturn ( Mono . just ( reauthorizedClient ) ) ;
when ( this . authorizedClientProvider . authorize ( any ( OAuth2AuthorizationContext . class ) ) ) . thenReturn ( Mono . just ( reauthorizedClient ) ) ;
ServerOAuth2AuthorizeRequest reauthorizeRequest = new ServerOAuth2AuthorizeRequest (
OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest . withAuthorizedClient ( this . authorizedClient )
this . authorizedClient , this . principal , this . serverWebExchange ) ;
. principal ( this . principal )
. attribute ( ServerWebExchange . class . getName ( ) , this . serverWebExchange )
. build ( ) ;
OAuth2AuthorizedClient authorizedClient = this . authorizedClientManager . authorize ( reauthorizeRequest ) . block ( ) ;
OAuth2AuthorizedClient authorizedClient = this . authorizedClientManager . authorize ( reauthorizeRequest ) . block ( ) ;
verify ( this . authorizedClientProvider ) . authorize ( this . authorizationContextCaptor . capture ( ) ) ;
verify ( this . authorizedClientProvider ) . authorize ( this . authorizationContextCaptor . capture ( ) ) ;
@ -312,7 +339,7 @@ public class DefaultServerOAuth2AuthorizedClientManagerTests {
// Override the mock with the default
// Override the mock with the default
this . authorizedClientManager . setContextAttributesMapper (
this . authorizedClientManager . setContextAttributesMapper (
new DefaultServer OAuth2AuthorizedClientManager . DefaultContextAttributesMapper ( ) ) ;
new DefaultReactive OAuth2AuthorizedClientManager . DefaultContextAttributesMapper ( ) ) ;
this . serverWebExchange = MockServerWebExchange . builder (
this . serverWebExchange = MockServerWebExchange . builder (
MockServerHttpRequest
MockServerHttpRequest
@ -320,8 +347,10 @@ public class DefaultServerOAuth2AuthorizedClientManagerTests {
. queryParam ( OAuth2ParameterNames . SCOPE , "read write" ) )
. queryParam ( OAuth2ParameterNames . SCOPE , "read write" ) )
. build ( ) ;
. build ( ) ;
ServerOAuth2AuthorizeRequest reauthorizeRequest = new ServerOAuth2AuthorizeRequest (
OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest . withAuthorizedClient ( this . authorizedClient )
this . authorizedClient , this . principal , this . serverWebExchange ) ;
. principal ( this . principal )
. attribute ( ServerWebExchange . class . getName ( ) , this . serverWebExchange )
. build ( ) ;
this . authorizedClientManager . authorize ( reauthorizeRequest ) . block ( ) ;
this . authorizedClientManager . authorize ( reauthorizeRequest ) . block ( ) ;
verify ( this . authorizedClientProvider ) . authorize ( this . authorizationContextCaptor . capture ( ) ) ;
verify ( this . authorizedClientProvider ) . authorize ( this . authorizationContextCaptor . capture ( ) ) ;