@ -1,5 +1,5 @@
@@ -1,5 +1,5 @@
/ *
* Copyright 2002 - 2018 the original author or authors .
* Copyright 2002 - 2019 the original author or authors .
*
* Licensed under the Apache License , Version 2 . 0 ( the "License" ) ;
* you may not use this file except in compliance with the License .
@ -28,16 +28,22 @@ import org.springframework.security.authentication.ReactiveAuthenticationManager
@@ -28,16 +28,22 @@ import org.springframework.security.authentication.ReactiveAuthenticationManager
import org.springframework.security.core.Authentication ;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthorizationCodeAuthenticationToken ;
import org.springframework.security.oauth2.client.authentication.TestOAuth2AuthorizationCodeAuthenticationTokens ;
import org.springframework.security.oauth2.client.registration.ClientRegistration ;
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository ;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations ;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange ;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest ;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse ;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames ;
import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests ;
import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationResponses ;
import org.springframework.security.web.server.authentication.ServerAuthenticationConverter ;
import org.springframework.web.server.handler.DefaultWebFilterChain ;
import reactor.core.publisher.Mono ;
import static org.assertj.core.api.Assertions.* ;
import static org.assertj.core.api.Assertions.assertThatCode ;
import static org.mockito.ArgumentMatchers.any ;
import static org.mockito.Mockito.verify ;
import static org.mockito.Mockito.verifyZeroInteractions ;
import static org.mockito.Mockito.when ;
import static org.mockito.Mockito.* ;
/ * *
* @author Rob Winch
@ -53,6 +59,9 @@ public class OAuth2AuthorizationCodeGrantWebFilterTests {
@@ -53,6 +59,9 @@ public class OAuth2AuthorizationCodeGrantWebFilterTests {
@Mock
private ServerOAuth2AuthorizedClientRepository authorizedClientRepository ;
private ServerAuthorizationRequestRepository < OAuth2AuthorizationRequest > authorizationRequestRepository =
new WebSessionOAuth2ServerAuthorizationRequestRepository ( ) ;
@Before
public void setup ( ) {
this . filter = new OAuth2AuthorizationCodeGrantWebFilter (
@ -101,25 +110,42 @@ public class OAuth2AuthorizationCodeGrantWebFilterTests {
@@ -101,25 +110,42 @@ public class OAuth2AuthorizationCodeGrantWebFilterTests {
@Test
public void filterWhenMatchThenAuthorizedClientSaved ( ) {
Mono < Authentication > authentication = Mono
. just ( TestOAuth2AuthorizationCodeAuthenticationTokens . unauthenticated ( ) ) ;
OAuth2AuthorizationRequest authorizationRequest = TestOAuth2AuthorizationRequests . request ( )
. redirectUri ( "/authorize/registration-id" )
. build ( ) ;
OAuth2AuthorizationResponse authorizationResponse = TestOAuth2AuthorizationResponses . success ( )
. redirectUri ( "/authorize/registration-id" )
. build ( ) ;
OAuth2AuthorizationExchange authorizationExchange =
new OAuth2AuthorizationExchange ( authorizationRequest , authorizationResponse ) ;
ClientRegistration registration = TestClientRegistrations . clientRegistration ( ) . build ( ) ;
Mono < Authentication > authentication = Mono . just (
new OAuth2AuthorizationCodeAuthenticationToken ( registration , authorizationExchange ) ) ;
OAuth2AuthorizationCodeAuthenticationToken authenticated = TestOAuth2AuthorizationCodeAuthenticationTokens
. authenticated ( ) ;
when ( this . authenticationManager . authenticate ( any ( ) ) ) . thenReturn (
Mono . just ( authenticated ) ) ;
when ( this . authorizedClientRepository . saveAuthorizedClient ( any ( ) , any ( ) , any ( ) ) )
. thenReturn ( Mono . empty ( ) ) ;
ServerAuthenticationConverter converter = e - > authentication ;
this . filter = new OAuth2AuthorizationCodeGrantWebFilter (
this . authenticationManager , converter , this . authorizedClientRepository ) ;
MockServerWebExchange exchange = MockServerWebExchange . from ( MockServerHttpRequest
. get ( "/authorize/oauth2/code/registration-id" ) ) ;
MockServerHttpRequest request = MockServerHttpRequest
. get ( "/authorize/registration-id" )
. queryParam ( OAuth2ParameterNames . CODE , "code" )
. queryParam ( OAuth2ParameterNames . STATE , "state" )
. build ( ) ;
MockServerWebExchange exchange = MockServerWebExchange . from ( request ) ;
DefaultWebFilterChain chain = new DefaultWebFilterChain (
e - > e . getResponse ( ) . setComplete ( ) ) ;
when ( this . authenticationManager . authenticate ( any ( ) ) ) . thenReturn ( Mono . just (
authenticated ) ) ;
when ( this . authorizedClientRepository . saveAuthorizedClient ( any ( ) , any ( ) , any ( ) ) )
. thenReturn ( Mono . empty ( ) ) ;
this . authorizationRequestRepository . saveAuthorizationRequest ( authorizationRequest , exchange ) . block ( ) ;
this . filter . filter ( exchange , chain ) . block ( ) ;
verify ( this . authorizedClientRepository ) . saveAuthorizedClient ( any ( ) , any ( AnonymousAuthenticationToken . class ) , any ( ) ) ;
}
}