@ -24,6 +24,7 @@ import java.util.Set;
import reactor.core.publisher.Mono ;
import reactor.core.publisher.Mono ;
import org.springframework.core.convert.converter.Converter ;
import org.springframework.http.HttpHeaders ;
import org.springframework.http.HttpHeaders ;
import org.springframework.http.MediaType ;
import org.springframework.http.MediaType ;
import org.springframework.security.oauth2.client.registration.ClientRegistration ;
import org.springframework.security.oauth2.client.registration.ClientRegistration ;
@ -65,6 +66,8 @@ public abstract class AbstractWebClientReactiveOAuth2AccessTokenResponseClient<T
private WebClient webClient = WebClient . builder ( ) . build ( ) ;
private WebClient webClient = WebClient . builder ( ) . build ( ) ;
private Converter < T , HttpHeaders > headersConverter = this : : populateTokenRequestHeaders ;
AbstractWebClientReactiveOAuth2AccessTokenResponseClient ( ) {
AbstractWebClientReactiveOAuth2AccessTokenResponseClient ( ) {
}
}
@ -74,7 +77,12 @@ public abstract class AbstractWebClientReactiveOAuth2AccessTokenResponseClient<T
// @formatter:off
// @formatter:off
return Mono . defer ( ( ) - > this . webClient . post ( )
return Mono . defer ( ( ) - > this . webClient . post ( )
. uri ( clientRegistration ( grantRequest ) . getProviderDetails ( ) . getTokenUri ( ) )
. uri ( clientRegistration ( grantRequest ) . getProviderDetails ( ) . getTokenUri ( ) )
. headers ( ( headers ) - > populateTokenRequestHeaders ( grantRequest , headers ) )
. headers ( ( headers ) - > {
HttpHeaders headersToAdd = getHeadersConverter ( ) . convert ( grantRequest ) ;
if ( headersToAdd ! = null ) {
headers . addAll ( headersToAdd ) ;
}
} )
. body ( createTokenRequestBody ( grantRequest ) )
. body ( createTokenRequestBody ( grantRequest ) )
. exchange ( )
. exchange ( )
. flatMap ( ( response ) - > readTokenResponse ( grantRequest , response ) )
. flatMap ( ( response ) - > readTokenResponse ( grantRequest , response ) )
@ -92,9 +100,10 @@ public abstract class AbstractWebClientReactiveOAuth2AccessTokenResponseClient<T
/ * *
/ * *
* Populates the headers for the token request .
* Populates the headers for the token request .
* @param grantRequest the grant request
* @param grantRequest the grant request
* @param headers the headers to populate
* @return the headers populated for the token request
* /
* /
private void populateTokenRequestHeaders ( T grantRequest , HttpHeaders headers ) {
private HttpHeaders populateTokenRequestHeaders ( T grantRequest ) {
HttpHeaders headers = new HttpHeaders ( ) ;
ClientRegistration clientRegistration = clientRegistration ( grantRequest ) ;
ClientRegistration clientRegistration = clientRegistration ( grantRequest ) ;
headers . setContentType ( MediaType . APPLICATION_FORM_URLENCODED ) ;
headers . setContentType ( MediaType . APPLICATION_FORM_URLENCODED ) ;
headers . setAccept ( Collections . singletonList ( MediaType . APPLICATION_JSON ) ) ;
headers . setAccept ( Collections . singletonList ( MediaType . APPLICATION_JSON ) ) ;
@ -104,6 +113,7 @@ public abstract class AbstractWebClientReactiveOAuth2AccessTokenResponseClient<T
String clientSecret = encodeClientCredential ( clientRegistration . getClientSecret ( ) ) ;
String clientSecret = encodeClientCredential ( clientRegistration . getClientSecret ( ) ) ;
headers . setBasicAuth ( clientId , clientSecret ) ;
headers . setBasicAuth ( clientId , clientSecret ) ;
}
}
return headers ;
}
}
private static String encodeClientCredential ( String clientCredential ) {
private static String encodeClientCredential ( String clientCredential ) {
@ -230,4 +240,55 @@ public abstract class AbstractWebClientReactiveOAuth2AccessTokenResponseClient<T
this . webClient = webClient ;
this . webClient = webClient ;
}
}
/ * *
* Returns the { @link Converter } used for converting the
* { @link AbstractOAuth2AuthorizationGrantRequest } instance to a { @link HttpHeaders }
* used in the OAuth 2 . 0 Access Token Request headers .
* @return the { @link Converter } used for converting the
* { @link AbstractOAuth2AuthorizationGrantRequest } to { @link HttpHeaders }
* /
final Converter < T , HttpHeaders > getHeadersConverter ( ) {
return this . headersConverter ;
}
/ * *
* Sets the { @link Converter } used for converting the
* { @link AbstractOAuth2AuthorizationGrantRequest } instance to a { @link HttpHeaders }
* used in the OAuth 2 . 0 Access Token Request headers .
* @param headersConverter the { @link Converter } used for converting the
* { @link AbstractOAuth2AuthorizationGrantRequest } to { @link HttpHeaders }
* @since 5 . 6
* /
public final void setHeadersConverter ( Converter < T , HttpHeaders > headersConverter ) {
Assert . notNull ( headersConverter , "headersConverter cannot be null" ) ;
this . headersConverter = headersConverter ;
}
/ * *
* Add ( compose ) the provided { @code headersConverter } to the current
* { @link Converter } used for converting the
* { @link AbstractOAuth2AuthorizationGrantRequest } instance to a { @link HttpHeaders }
* used in the OAuth 2 . 0 Access Token Request headers .
* @param headersConverter the { @link Converter } to add ( compose ) to the current
* { @link Converter } used for converting the
* { @link AbstractOAuth2AuthorizationGrantRequest } to a { @link HttpHeaders }
* @since 5 . 6
* /
public final void addHeadersConverter ( Converter < T , HttpHeaders > headersConverter ) {
Assert . notNull ( headersConverter , "headersConverter cannot be null" ) ;
Converter < T , HttpHeaders > currentHeadersConverter = this . headersConverter ;
this . headersConverter = ( authorizationGrantRequest ) - > {
// Append headers using a Composite Converter
HttpHeaders headers = currentHeadersConverter . convert ( authorizationGrantRequest ) ;
if ( headers = = null ) {
headers = new HttpHeaders ( ) ;
}
HttpHeaders headersToAdd = headersConverter . convert ( authorizationGrantRequest ) ;
if ( headersToAdd ! = null ) {
headers . addAll ( headersToAdd ) ;
}
return headers ;
} ;
}
}
}