@ -34,9 +34,9 @@ import org.springframework.security.oauth2.client.web.server.ServerOAuth2Authori
@@ -34,9 +34,9 @@ import org.springframework.security.oauth2.client.web.server.ServerOAuth2Authori
import org.springframework.security.oauth2.core.TestOAuth2AccessTokens ;
import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens ;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames ;
import org.springframework.util.StringUtils ;
import org.springframework.web.server.ServerWebExchange ;
import reactor.core.publisher.Mono ;
import reactor.util.context.Context ;
import java.util.Collections ;
import java.util.HashMap ;
@ -64,6 +64,7 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
@@ -64,6 +64,7 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
private Authentication principal ;
private OAuth2AuthorizedClient authorizedClient ;
private MockServerWebExchange serverWebExchange ;
private Context context ;
private ArgumentCaptor < OAuth2AuthorizationContext > authorizationContextCaptor ;
@SuppressWarnings ( "unchecked" )
@ -75,6 +76,8 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
@@ -75,6 +76,8 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
this . authorizedClientRepository = mock ( ServerOAuth2AuthorizedClientRepository . class ) ;
when ( this . authorizedClientRepository . loadAuthorizedClient (
anyString ( ) , any ( Authentication . class ) , any ( ServerWebExchange . class ) ) ) . thenReturn ( Mono . empty ( ) ) ;
when ( this . authorizedClientRepository . saveAuthorizedClient (
any ( OAuth2AuthorizedClient . class ) , any ( Authentication . class ) , any ( ServerWebExchange . class ) ) ) . thenReturn ( Mono . empty ( ) ) ;
this . authorizedClientProvider = mock ( ReactiveOAuth2AuthorizedClientProvider . class ) ;
when ( this . authorizedClientProvider . authorize ( any ( OAuth2AuthorizationContext . class ) ) ) . thenReturn ( Mono . empty ( ) ) ;
this . contextAttributesMapper = mock ( Function . class ) ;
@ -88,6 +91,7 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
@@ -88,6 +91,7 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
this . authorizedClient = new OAuth2AuthorizedClient ( this . clientRegistration , this . principal . getName ( ) ,
TestOAuth2AccessTokens . scopes ( "read" , "write" ) , TestOAuth2RefreshTokens . refreshToken ( ) ) ;
this . serverWebExchange = MockServerWebExchange . builder ( MockServerHttpRequest . get ( "/" ) ) . build ( ) ;
this . context = Context . of ( ServerWebExchange . class , this . serverWebExchange ) ;
this . authorizationContextCaptor = ArgumentCaptor . forClass ( OAuth2AuthorizationContext . class ) ;
}
@ -119,16 +123,6 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
@@ -119,16 +123,6 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
. hasMessage ( "contextAttributesMapper cannot be null" ) ;
}
@Test
public void authorizeWhenServerWebExchangeIsNullThenThrowIllegalArgumentException ( ) {
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest . withClientRegistrationId ( this . clientRegistration . getRegistrationId ( ) )
. principal ( this . principal )
. build ( ) ;
assertThatThrownBy ( ( ) - > this . authorizedClientManager . authorize ( authorizeRequest ) . block ( ) )
. isInstanceOf ( IllegalArgumentException . class )
. hasMessage ( "serverWebExchange cannot be null" ) ;
}
@Test
public void authorizeWhenRequestIsNullThenThrowIllegalArgumentException ( ) {
assertThatThrownBy ( ( ) - > this . authorizedClientManager . authorize ( null ) . block ( ) )
@ -140,9 +134,8 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
@@ -140,9 +134,8 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
public void authorizeWhenClientRegistrationNotFoundThenThrowIllegalArgumentException ( ) {
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest . withClientRegistrationId ( "invalid-registration-id" )
. principal ( this . principal )
. attribute ( ServerWebExchange . class . getName ( ) , this . serverWebExchange )
. build ( ) ;
assertThatThrownBy ( ( ) - > this . authorizedClientManager . authorize ( authorizeRequest ) . block ( ) )
assertThatThrownBy ( ( ) - > this . authorizedClientManager . authorize ( authorizeRequest ) . subscriberContext ( this . context ) . block ( ) )
. isInstanceOf ( IllegalArgumentException . class )
. hasMessage ( "Could not find ClientRegistration with id 'invalid-registration-id'" ) ;
}
@ -155,9 +148,9 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
@@ -155,9 +148,9 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest . withClientRegistrationId ( this . clientRegistration . getRegistrationId ( ) )
. principal ( this . principal )
. attribute ( ServerWebExchange . class . getName ( ) , this . serverWebExchange )
. build ( ) ;
OAuth2AuthorizedClient authorizedClient = this . authorizedClientManager . authorize ( authorizeRequest ) . block ( ) ;
OAuth2AuthorizedClient authorizedClient = this . authorizedClientManager . authorize ( authorizeRequest )
. subscriberContext ( this . context ) . block ( ) ;
verify ( this . authorizedClientProvider ) . authorize ( this . authorizationContextCaptor . capture ( ) ) ;
verify ( this . contextAttributesMapper ) . apply ( eq ( authorizeRequest ) ) ;
@ -168,8 +161,7 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
@@ -168,8 +161,7 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
assertThat ( authorizationContext . getPrincipal ( ) ) . isEqualTo ( this . principal ) ;
assertThat ( authorizedClient ) . isNull ( ) ;
verify ( this . authorizedClientRepository , never ( ) ) . saveAuthorizedClient (
any ( OAuth2AuthorizedClient . class ) , eq ( this . principal ) , eq ( this . serverWebExchange ) ) ;
verify ( this . authorizedClientRepository , never ( ) ) . saveAuthorizedClient ( any ( ) , any ( ) , any ( ) ) ;
}
@SuppressWarnings ( "unchecked" )
@ -177,15 +169,14 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
@@ -177,15 +169,14 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
public void authorizeWhenNotAuthorizedAndSupportedProviderThenAuthorized ( ) {
when ( this . clientRegistrationRepository . findByRegistrationId (
eq ( this . clientRegistration . getRegistrationId ( ) ) ) ) . thenReturn ( Mono . just ( this . clientRegistration ) ) ;
when ( this . authorizedClientProvider . authorize (
any ( OAuth2AuthorizationContext . class ) ) ) . thenReturn ( Mono . just ( this . authorizedClient ) ) ;
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest . withClientRegistrationId ( this . clientRegistration . getRegistrationId ( ) )
. principal ( this . principal )
. attribute ( ServerWebExchange . class . getName ( ) , this . serverWebExchange )
. build ( ) ;
OAuth2AuthorizedClient authorizedClient = this . authorizedClientManager . authorize ( authorizeRequest ) . block ( ) ;
OAuth2AuthorizedClient authorizedClient = this . authorizedClientManager . authorize ( authorizeRequest )
. subscriberContext ( this . context ) . block ( ) ;
verify ( this . authorizedClientProvider ) . authorize ( this . authorizationContextCaptor . capture ( ) ) ;
verify ( this . contextAttributesMapper ) . apply ( eq ( authorizeRequest ) ) ;
@ -200,6 +191,31 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
@@ -200,6 +191,31 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
eq ( this . authorizedClient ) , eq ( this . principal ) , eq ( this . serverWebExchange ) ) ;
}
@Test
public void authorizeWhenNotAuthorizedAndSupportedProviderAndExchangeUnavailableThenAuthorizedButNotSaved ( ) {
when ( this . clientRegistrationRepository . findByRegistrationId (
eq ( this . clientRegistration . getRegistrationId ( ) ) ) ) . thenReturn ( Mono . just ( this . clientRegistration ) ) ;
when ( this . authorizedClientProvider . authorize (
any ( OAuth2AuthorizationContext . class ) ) ) . thenReturn ( Mono . just ( this . authorizedClient ) ) ;
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest . withClientRegistrationId ( this . clientRegistration . getRegistrationId ( ) )
. principal ( this . principal )
. build ( ) ;
OAuth2AuthorizedClient authorizedClient = this . authorizedClientManager . authorize ( authorizeRequest ) . block ( ) ;
verify ( this . authorizedClientProvider ) . authorize ( this . authorizationContextCaptor . capture ( ) ) ;
verify ( this . contextAttributesMapper ) . apply ( eq ( authorizeRequest ) ) ;
OAuth2AuthorizationContext authorizationContext = this . authorizationContextCaptor . getValue ( ) ;
assertThat ( authorizationContext . getClientRegistration ( ) ) . isEqualTo ( this . clientRegistration ) ;
assertThat ( authorizationContext . getAuthorizedClient ( ) ) . isNull ( ) ;
assertThat ( authorizationContext . getPrincipal ( ) ) . isEqualTo ( this . principal ) ;
assertThat ( authorizedClient ) . isSameAs ( this . authorizedClient ) ;
verify ( this . authorizedClientRepository , never ( ) ) . saveAuthorizedClient ( any ( ) , any ( ) , any ( ) ) ;
}
@SuppressWarnings ( "unchecked" )
@Test
public void authorizeWhenAuthorizedAndSupportedProviderThenReauthorized ( ) {
@ -216,9 +232,9 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
@@ -216,9 +232,9 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest . withClientRegistrationId ( this . clientRegistration . getRegistrationId ( ) )
. principal ( this . principal )
. attribute ( ServerWebExchange . class . getName ( ) , this . serverWebExchange )
. build ( ) ;
OAuth2AuthorizedClient authorizedClient = this . authorizedClientManager . authorize ( authorizeRequest ) . block ( ) ;
OAuth2AuthorizedClient authorizedClient = this . authorizedClientManager . authorize ( authorizeRequest )
. subscriberContext ( this . context ) . block ( ) ;
verify ( this . authorizedClientProvider ) . authorize ( this . authorizationContextCaptor . capture ( ) ) ;
verify ( this . contextAttributesMapper ) . apply ( any ( ) ) ;
@ -241,21 +257,18 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
@@ -241,21 +257,18 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
when ( this . authorizedClientProvider . authorize ( any ( OAuth2AuthorizationContext . class ) ) ) . thenReturn ( Mono . just ( this . authorizedClient ) ) ;
// Set custom contextAttributesMapper capable of mapping the form parameters
this . authorizedClientManager . setContextAttributesMapper ( authorizeRequest - > {
ServerWebExchange serverWebExchange = authorizeRequest . getAttribute ( ServerWebExchange . class . getName ( ) ) ;
return Mono . just ( serverWebExchange )
this . authorizedClientManager . setContextAttributesMapper ( authorizeRequest - >
currentServerWebExchange ( )
. flatMap ( ServerWebExchange : : getFormData )
. map ( formData - > {
Map < String , Object > contextAttributes = new HashMap < > ( ) ;
String username = formData . getFirst ( OAuth2ParameterNames . USERNAME ) ;
contextAttributes . put ( OAuth2AuthorizationContext . USERNAME_ATTRIBUTE_NAME , username ) ;
String password = formData . getFirst ( OAuth2ParameterNames . PASSWORD ) ;
if ( StringUtils . hasText ( username ) & & StringUtils . hasText ( password ) ) {
contextAttributes . put ( OAuth2AuthorizationContext . USERNAME_ATTRIBUTE_NAME , username ) ;
contextAttributes . put ( OAuth2AuthorizationContext . PASSWORD_ATTRIBUTE_NAME , password ) ;
}
contextAttributes . put ( OAuth2AuthorizationContext . PASSWORD_ATTRIBUTE_NAME , password ) ;
return contextAttributes ;
} ) ;
} ) ;
} )
) ;
this . serverWebExchange = MockServerWebExchange . builder (
MockServerHttpRequest
@ -263,12 +276,12 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
@@ -263,12 +276,12 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
. contentType ( MediaType . APPLICATION_FORM_URLENCODED )
. body ( "username=username&password=password" ) )
. build ( ) ;
this . context = Context . of ( ServerWebExchange . class , this . serverWebExchange ) ;
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest . withClientRegistrationId ( this . clientRegistration . getRegistrationId ( ) )
. principal ( this . principal )
. attribute ( ServerWebExchange . class . getName ( ) , this . serverWebExchange )
. build ( ) ;
this . authorizedClientManager . authorize ( authorizeRequest ) . block ( ) ;
this . authorizedClientManager . authorize ( authorizeRequest ) . subscriberContext ( this . context ) . block ( ) ;
verify ( this . authorizedClientProvider ) . authorize ( this . authorizationContextCaptor . capture ( ) ) ;
@ -284,9 +297,9 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
@@ -284,9 +297,9 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
public void reauthorizeWhenUnsupportedProviderThenNotReauthorized ( ) {
OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest . withAuthorizedClient ( this . authorizedClient )
. principal ( this . principal )
. attribute ( ServerWebExchange . class . getName ( ) , this . serverWebExchange )
. build ( ) ;
OAuth2AuthorizedClient authorizedClient = this . authorizedClientManager . authorize ( reauthorizeRequest ) . block ( ) ;
OAuth2AuthorizedClient authorizedClient = this . authorizedClientManager . authorize ( reauthorizeRequest )
. subscriberContext ( this . context ) . block ( ) ;
verify ( this . authorizedClientProvider ) . authorize ( this . authorizationContextCaptor . capture ( ) ) ;
verify ( this . contextAttributesMapper ) . apply ( eq ( reauthorizeRequest ) ) ;
@ -297,8 +310,7 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
@@ -297,8 +310,7 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
assertThat ( authorizationContext . getPrincipal ( ) ) . isEqualTo ( this . principal ) ;
assertThat ( authorizedClient ) . isSameAs ( this . authorizedClient ) ;
verify ( this . authorizedClientRepository , never ( ) ) . saveAuthorizedClient (
any ( OAuth2AuthorizedClient . class ) , eq ( this . principal ) , eq ( this . serverWebExchange ) ) ;
verify ( this . authorizedClientRepository , never ( ) ) . saveAuthorizedClient ( any ( ) , any ( ) , any ( ) ) ;
}
@SuppressWarnings ( "unchecked" )
@ -312,9 +324,9 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
@@ -312,9 +324,9 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest . withAuthorizedClient ( this . authorizedClient )
. principal ( this . principal )
. attribute ( ServerWebExchange . class . getName ( ) , this . serverWebExchange )
. build ( ) ;
OAuth2AuthorizedClient authorizedClient = this . authorizedClientManager . authorize ( reauthorizeRequest ) . block ( ) ;
OAuth2AuthorizedClient authorizedClient = this . authorizedClientManager . authorize ( reauthorizeRequest )
. subscriberContext ( this . context ) . block ( ) ;
verify ( this . authorizedClientProvider ) . authorize ( this . authorizationContextCaptor . capture ( ) ) ;
verify ( this . contextAttributesMapper ) . apply ( eq ( reauthorizeRequest ) ) ;
@ -346,12 +358,12 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
@@ -346,12 +358,12 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
. get ( "/" )
. queryParam ( OAuth2ParameterNames . SCOPE , "read write" ) )
. build ( ) ;
this . context = Context . of ( ServerWebExchange . class , this . serverWebExchange ) ;
OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest . withAuthorizedClient ( this . authorizedClient )
. principal ( this . principal )
. attribute ( ServerWebExchange . class . getName ( ) , this . serverWebExchange )
. build ( ) ;
this . authorizedClientManager . authorize ( reauthorizeRequest ) . block ( ) ;
this . authorizedClientManager . authorize ( reauthorizeRequest ) . subscriberContext ( this . context ) . block ( ) ;
verify ( this . authorizedClientProvider ) . authorize ( this . authorizationContextCaptor . capture ( ) ) ;
@ -359,4 +371,10 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
@@ -359,4 +371,10 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
String [ ] requestScopeAttribute = authorizationContext . getAttribute ( OAuth2AuthorizationContext . REQUEST_SCOPE_ATTRIBUTE_NAME ) ;
assertThat ( requestScopeAttribute ) . contains ( "read" , "write" ) ;
}
private Mono < ServerWebExchange > currentServerWebExchange ( ) {
return Mono . subscriberContext ( )
. filter ( c - > c . hasKey ( ServerWebExchange . class ) )
. map ( c - > c . get ( ServerWebExchange . class ) ) ;
}
}