@ -16,6 +16,9 @@
@@ -16,6 +16,9 @@
package org.springframework.security.oauth2.client.web.reactive.function.client ;
import org.reactivestreams.Subscription ;
import org.springframework.beans.factory.DisposableBean ;
import org.springframework.beans.factory.InitializingBean ;
import org.springframework.http.HttpHeaders ;
import org.springframework.http.HttpMethod ;
import org.springframework.http.MediaType ;
@ -44,8 +47,12 @@ import org.springframework.web.reactive.function.client.ClientResponse;
@@ -44,8 +47,12 @@ import org.springframework.web.reactive.function.client.ClientResponse;
import org.springframework.web.reactive.function.client.ExchangeFilterFunction ;
import org.springframework.web.reactive.function.client.ExchangeFunction ;
import org.springframework.web.reactive.function.client.WebClient ;
import reactor.core.CoreSubscriber ;
import reactor.core.publisher.Hooks ;
import reactor.core.publisher.Mono ;
import reactor.core.publisher.Operators ;
import reactor.core.scheduler.Schedulers ;
import reactor.util.context.Context ;
import javax.servlet.http.HttpServletRequest ;
import javax.servlet.http.HttpServletResponse ;
@ -98,7 +105,9 @@ import static org.springframework.security.oauth2.core.web.reactive.function.OAu
@@ -98,7 +105,9 @@ import static org.springframework.security.oauth2.core.web.reactive.function.OAu
* @author Rob Winch
* @since 5 . 1
* /
public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implements ExchangeFilterFunction {
public final class ServletOAuth2AuthorizedClientExchangeFilterFunction
implements ExchangeFilterFunction , InitializingBean , DisposableBean {
/ * *
* The request attribute name used to locate the { @link OAuth2AuthorizedClient } .
* /
@ -108,6 +117,8 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
@@ -108,6 +117,8 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
private static final String HTTP_SERVLET_REQUEST_ATTR_NAME = HttpServletRequest . class . getName ( ) ;
private static final String HTTP_SERVLET_RESPONSE_ATTR_NAME = HttpServletResponse . class . getName ( ) ;
private static final String REQUEST_CONTEXT_OPERATOR_KEY = RequestContextSubscriber . class . getName ( ) ;
private Clock clock = Clock . systemUTC ( ) ;
private Duration accessTokenExpiresSkew = Duration . ofMinutes ( 1 ) ;
@ -123,7 +134,8 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
@@ -123,7 +134,8 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
private String defaultClientRegistrationId ;
public ServletOAuth2AuthorizedClientExchangeFilterFunction ( ) { }
public ServletOAuth2AuthorizedClientExchangeFilterFunction ( ) {
}
public ServletOAuth2AuthorizedClientExchangeFilterFunction (
ClientRegistrationRepository clientRegistrationRepository ,
@ -132,6 +144,16 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
@@ -132,6 +144,16 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
this . authorizedClientRepository = authorizedClientRepository ;
}
@Override
public void afterPropertiesSet ( ) throws Exception {
Hooks . onLastOperator ( REQUEST_CONTEXT_OPERATOR_KEY , Operators . lift ( ( s , sub ) - > createRequestContextSubscriber ( sub ) ) ) ;
}
@Override
public void destroy ( ) throws Exception {
Hooks . resetOnLastOperator ( REQUEST_CONTEXT_OPERATOR_KEY ) ;
}
/ * *
* Sets the { @link OAuth2AccessTokenResponseClient } to be used for getting an { @link OAuth2AuthorizedClient } for
* client_credentials grant .
@ -266,15 +288,36 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
@@ -266,15 +288,36 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
@Override
public Mono < ClientResponse > filter ( ClientRequest request , ExchangeFunction next ) {
Optional < OAuth2AuthorizedClient > attribute = request . attribute ( OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME )
. map ( OAuth2AuthorizedClient . class : : cast ) ;
return Mono . justOrEmpty ( attribute )
. flatMap ( authorizedClient - > authorizedClient ( request , next , authorizedClient ) )
return Mono . just ( request )
. filter ( req - > req . attribute ( OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME ) . isPresent ( ) )
. switchIfEmpty ( mergeRequestAttributesFromContext ( request ) )
. filter ( req - > req . attribute ( OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME ) . isPresent ( ) )
. flatMap ( req - > authorizedClient ( req , next , getOAuth2AuthorizedClient ( req . attributes ( ) ) ) )
. map ( authorizedClient - > bearer ( request , authorizedClient ) )
. flatMap ( next : : exchange )
. switchIfEmpty ( next . exchange ( request ) ) ;
}
private Mono < ClientRequest > mergeRequestAttributesFromContext ( ClientRequest request ) {
return Mono . just ( ClientRequest . from ( request ) )
. flatMap ( builder - > Mono . subscriberContext ( )
. map ( ctx - > builder . attributes ( attrs - > populateRequestAttributes ( attrs , ctx ) ) ) )
. map ( ClientRequest . Builder : : build ) ;
}
private void populateRequestAttributes ( Map < String , Object > attrs , Context ctx ) {
if ( ctx . hasKey ( HTTP_SERVLET_REQUEST_ATTR_NAME ) ) {
attrs . putIfAbsent ( HTTP_SERVLET_REQUEST_ATTR_NAME , ctx . get ( HTTP_SERVLET_REQUEST_ATTR_NAME ) ) ;
}
if ( ctx . hasKey ( HTTP_SERVLET_RESPONSE_ATTR_NAME ) ) {
attrs . putIfAbsent ( HTTP_SERVLET_RESPONSE_ATTR_NAME , ctx . get ( HTTP_SERVLET_RESPONSE_ATTR_NAME ) ) ;
}
if ( ctx . hasKey ( AUTHENTICATION_ATTR_NAME ) ) {
attrs . putIfAbsent ( AUTHENTICATION_ATTR_NAME , ctx . get ( AUTHENTICATION_ATTR_NAME ) ) ;
}
populateDefaultOAuth2AuthorizedClient ( attrs ) ;
}
private void populateDefaultRequestResponse ( Map < String , Object > attrs ) {
if ( attrs . containsKey ( HTTP_SERVLET_REQUEST_ATTR_NAME ) & & attrs . containsKey (
HTTP_SERVLET_RESPONSE_ATTR_NAME ) ) {
@ -425,6 +468,19 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
@@ -425,6 +468,19 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
. build ( ) ;
}
private < T > CoreSubscriber < T > createRequestContextSubscriber ( CoreSubscriber < T > delegate ) {
HttpServletRequest request = null ;
HttpServletResponse response = null ;
ServletRequestAttributes requestAttributes =
( ServletRequestAttributes ) RequestContextHolder . getRequestAttributes ( ) ;
if ( requestAttributes ! = null ) {
request = requestAttributes . getRequest ( ) ;
response = requestAttributes . getResponse ( ) ;
}
Authentication authentication = SecurityContextHolder . getContext ( ) . getAuthentication ( ) ;
return new RequestContextSubscriber < > ( delegate , request , response , authentication ) ;
}
private static BodyInserters . FormInserter < String > refreshTokenBody ( String refreshToken ) {
return BodyInserters
. fromFormData ( "grant_type" , AuthorizationGrantType . REFRESH_TOKEN . getValue ( ) )
@ -498,4 +554,55 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
@@ -498,4 +554,55 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
return new UnsupportedOperationException ( "Not Supported" ) ;
}
}
private static class RequestContextSubscriber < T > implements CoreSubscriber < T > {
private static final String CONTEXT_DEFAULTED_ATTR_NAME = RequestContextSubscriber . class . getName ( ) . concat ( ".CONTEXT_DEFAULTED_ATTR_NAME" ) ;
private final CoreSubscriber < T > delegate ;
private final HttpServletRequest request ;
private final HttpServletResponse response ;
private final Authentication authentication ;
private RequestContextSubscriber ( CoreSubscriber < T > delegate ,
HttpServletRequest request ,
HttpServletResponse response ,
Authentication authentication ) {
this . delegate = delegate ;
this . request = request ;
this . response = response ;
this . authentication = authentication ;
}
@Override
public Context currentContext ( ) {
Context context = this . delegate . currentContext ( ) ;
if ( context . hasKey ( CONTEXT_DEFAULTED_ATTR_NAME ) ) {
return context ;
}
return Context . of (
CONTEXT_DEFAULTED_ATTR_NAME , Boolean . TRUE ,
HTTP_SERVLET_REQUEST_ATTR_NAME , this . request ,
HTTP_SERVLET_RESPONSE_ATTR_NAME , this . response ,
AUTHENTICATION_ATTR_NAME , this . authentication ) ;
}
@Override
public void onSubscribe ( Subscription s ) {
this . delegate . onSubscribe ( s ) ;
}
@Override
public void onNext ( T t ) {
this . delegate . onNext ( t ) ;
}
@Override
public void onError ( Throwable t ) {
this . delegate . onError ( t ) ;
}
@Override
public void onComplete ( ) {
this . delegate . onComplete ( ) ;
}
}
}