@ -15,6 +15,18 @@
@@ -15,6 +15,18 @@
* /
package org.springframework.security.oauth2.client.web.reactive.function.client ;
import java.net.URI ;
import java.time.Duration ;
import java.time.Instant ;
import java.util.ArrayList ;
import java.util.HashMap ;
import java.util.List ;
import java.util.Map ;
import java.util.Optional ;
import java.util.function.Consumer ;
import javax.servlet.http.HttpServletRequest ;
import javax.servlet.http.HttpServletResponse ;
import org.junit.After ;
import org.junit.Before ;
import org.junit.Test ;
@ -23,6 +35,8 @@ import org.mockito.ArgumentCaptor;
@@ -23,6 +35,8 @@ import org.mockito.ArgumentCaptor;
import org.mockito.Captor ;
import org.mockito.Mock ;
import org.mockito.junit.MockitoJUnitRunner ;
import reactor.util.context.Context ;
import org.springframework.core.codec.ByteBufferEncoder ;
import org.springframework.core.codec.CharSequenceEncoder ;
import org.springframework.http.HttpHeaders ;
@ -76,26 +90,26 @@ import org.springframework.web.context.request.ServletRequestAttributes;
@@ -76,26 +90,26 @@ import org.springframework.web.context.request.ServletRequestAttributes;
import org.springframework.web.reactive.function.BodyInserter ;
import org.springframework.web.reactive.function.client.ClientRequest ;
import org.springframework.web.reactive.function.client.WebClient ;
import reactor.util.context.Context ;
import javax.servlet.http.HttpServletRequest ;
import javax.servlet.http.HttpServletResponse ;
import java.net.URI ;
import java.time.Duration ;
import java.time.Instant ;
import java.util.ArrayList ;
import java.util.HashMap ;
import java.util.List ;
import java.util.Map ;
import java.util.Optional ;
import java.util.function.Consumer ;
import static org.assertj.core.api.Assertions.assertThat ;
import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy ;
import static org.mockito.Mockito.* ;
import static org.mockito.Mockito.any ;
import static org.mockito.Mockito.eq ;
import static org.mockito.Mockito.mock ;
import static org.mockito.Mockito.never ;
import static org.mockito.Mockito.verify ;
import static org.mockito.Mockito.verifyZeroInteractions ;
import static org.mockito.Mockito.when ;
import static org.springframework.http.HttpMethod.GET ;
import static org.springframework.security.oauth2.client.web.reactive.function.client.ServerOAuth2AuthorizedClientExchangeFilterFunction.clientRegistrationId ;
import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.* ;
import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.SECURITY_REACTOR_CONTEXT_ATTRIBUTES_KEY ;
import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.authentication ;
import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.getAuthentication ;
import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.getRequest ;
import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.getResponse ;
import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.httpServletRequest ;
import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.httpServletResponse ;
import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.oauth2AuthorizedClient ;
/ * *
* @author Rob Winch
@ -603,8 +617,6 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
@@ -603,8 +617,6 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
when ( this . authorizedClientRepository . loadAuthorizedClient ( eq ( authentication . getAuthorizedClientRegistrationId ( ) ) ,
eq ( authentication ) , eq ( servletRequest ) ) ) . thenReturn ( authorizedClient ) ;
when ( this . clientRegistrationRepository . findByRegistrationId ( eq ( authentication . getAuthorizedClientRegistrationId ( ) ) ) ) . thenReturn ( this . registration ) ;
// Default request attributes set
final ClientRequest request1 = ClientRequest . create ( GET , URI . create ( "https://example1.com" ) )
. attributes ( attrs - > attrs . putAll ( getDefaultRequestAttributes ( ) ) ) . build ( ) ;