|
|
|
@ -16,6 +16,9 @@ |
|
|
|
|
|
|
|
|
|
|
|
package org.springframework.security.oauth2.client.web.reactive.function.client; |
|
|
|
package org.springframework.security.oauth2.client.web.reactive.function.client; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import org.reactivestreams.Subscription; |
|
|
|
|
|
|
|
import org.springframework.beans.factory.DisposableBean; |
|
|
|
|
|
|
|
import org.springframework.beans.factory.InitializingBean; |
|
|
|
import org.springframework.http.HttpHeaders; |
|
|
|
import org.springframework.http.HttpHeaders; |
|
|
|
import org.springframework.http.HttpMethod; |
|
|
|
import org.springframework.http.HttpMethod; |
|
|
|
import org.springframework.http.MediaType; |
|
|
|
import org.springframework.http.MediaType; |
|
|
|
@ -44,8 +47,12 @@ import org.springframework.web.reactive.function.client.ClientResponse; |
|
|
|
import org.springframework.web.reactive.function.client.ExchangeFilterFunction; |
|
|
|
import org.springframework.web.reactive.function.client.ExchangeFilterFunction; |
|
|
|
import org.springframework.web.reactive.function.client.ExchangeFunction; |
|
|
|
import org.springframework.web.reactive.function.client.ExchangeFunction; |
|
|
|
import org.springframework.web.reactive.function.client.WebClient; |
|
|
|
import org.springframework.web.reactive.function.client.WebClient; |
|
|
|
|
|
|
|
import reactor.core.CoreSubscriber; |
|
|
|
|
|
|
|
import reactor.core.publisher.Hooks; |
|
|
|
import reactor.core.publisher.Mono; |
|
|
|
import reactor.core.publisher.Mono; |
|
|
|
|
|
|
|
import reactor.core.publisher.Operators; |
|
|
|
import reactor.core.scheduler.Schedulers; |
|
|
|
import reactor.core.scheduler.Schedulers; |
|
|
|
|
|
|
|
import reactor.util.context.Context; |
|
|
|
|
|
|
|
|
|
|
|
import javax.servlet.http.HttpServletRequest; |
|
|
|
import javax.servlet.http.HttpServletRequest; |
|
|
|
import javax.servlet.http.HttpServletResponse; |
|
|
|
import javax.servlet.http.HttpServletResponse; |
|
|
|
@ -98,7 +105,9 @@ import static org.springframework.security.oauth2.core.web.reactive.function.OAu |
|
|
|
* @author Rob Winch |
|
|
|
* @author Rob Winch |
|
|
|
* @since 5.1 |
|
|
|
* @since 5.1 |
|
|
|
*/ |
|
|
|
*/ |
|
|
|
public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implements ExchangeFilterFunction { |
|
|
|
public final class ServletOAuth2AuthorizedClientExchangeFilterFunction |
|
|
|
|
|
|
|
implements ExchangeFilterFunction, InitializingBean, DisposableBean { |
|
|
|
|
|
|
|
|
|
|
|
/** |
|
|
|
/** |
|
|
|
* The request attribute name used to locate the {@link OAuth2AuthorizedClient}. |
|
|
|
* The request attribute name used to locate the {@link OAuth2AuthorizedClient}. |
|
|
|
*/ |
|
|
|
*/ |
|
|
|
@ -108,6 +117,8 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement |
|
|
|
private static final String HTTP_SERVLET_REQUEST_ATTR_NAME = HttpServletRequest.class.getName(); |
|
|
|
private static final String HTTP_SERVLET_REQUEST_ATTR_NAME = HttpServletRequest.class.getName(); |
|
|
|
private static final String HTTP_SERVLET_RESPONSE_ATTR_NAME = HttpServletResponse.class.getName(); |
|
|
|
private static final String HTTP_SERVLET_RESPONSE_ATTR_NAME = HttpServletResponse.class.getName(); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
private static final String REQUEST_CONTEXT_OPERATOR_KEY = RequestContextSubscriber.class.getName(); |
|
|
|
|
|
|
|
|
|
|
|
private Clock clock = Clock.systemUTC(); |
|
|
|
private Clock clock = Clock.systemUTC(); |
|
|
|
|
|
|
|
|
|
|
|
private Duration accessTokenExpiresSkew = Duration.ofMinutes(1); |
|
|
|
private Duration accessTokenExpiresSkew = Duration.ofMinutes(1); |
|
|
|
@ -123,7 +134,8 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement |
|
|
|
|
|
|
|
|
|
|
|
private String defaultClientRegistrationId; |
|
|
|
private String defaultClientRegistrationId; |
|
|
|
|
|
|
|
|
|
|
|
public ServletOAuth2AuthorizedClientExchangeFilterFunction() {} |
|
|
|
public ServletOAuth2AuthorizedClientExchangeFilterFunction() { |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
public ServletOAuth2AuthorizedClientExchangeFilterFunction( |
|
|
|
public ServletOAuth2AuthorizedClientExchangeFilterFunction( |
|
|
|
ClientRegistrationRepository clientRegistrationRepository, |
|
|
|
ClientRegistrationRepository clientRegistrationRepository, |
|
|
|
@ -132,6 +144,16 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement |
|
|
|
this.authorizedClientRepository = authorizedClientRepository; |
|
|
|
this.authorizedClientRepository = authorizedClientRepository; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@Override |
|
|
|
|
|
|
|
public void afterPropertiesSet() throws Exception { |
|
|
|
|
|
|
|
Hooks.onLastOperator(REQUEST_CONTEXT_OPERATOR_KEY, Operators.lift((s, sub) -> createRequestContextSubscriber(sub))); |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@Override |
|
|
|
|
|
|
|
public void destroy() throws Exception { |
|
|
|
|
|
|
|
Hooks.resetOnLastOperator(REQUEST_CONTEXT_OPERATOR_KEY); |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
/** |
|
|
|
/** |
|
|
|
* Sets the {@link OAuth2AccessTokenResponseClient} to be used for getting an {@link OAuth2AuthorizedClient} for |
|
|
|
* Sets the {@link OAuth2AccessTokenResponseClient} to be used for getting an {@link OAuth2AuthorizedClient} for |
|
|
|
* client_credentials grant. |
|
|
|
* client_credentials grant. |
|
|
|
@ -266,15 +288,36 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement |
|
|
|
|
|
|
|
|
|
|
|
@Override |
|
|
|
@Override |
|
|
|
public Mono<ClientResponse> filter(ClientRequest request, ExchangeFunction next) { |
|
|
|
public Mono<ClientResponse> filter(ClientRequest request, ExchangeFunction next) { |
|
|
|
Optional<OAuth2AuthorizedClient> attribute = request.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME) |
|
|
|
return Mono.just(request) |
|
|
|
.map(OAuth2AuthorizedClient.class::cast); |
|
|
|
.filter(req -> req.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME).isPresent()) |
|
|
|
return Mono.justOrEmpty(attribute) |
|
|
|
.switchIfEmpty(mergeRequestAttributesFromContext(request)) |
|
|
|
.flatMap(authorizedClient -> authorizedClient(request, next, authorizedClient)) |
|
|
|
.filter(req -> req.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME).isPresent()) |
|
|
|
|
|
|
|
.flatMap(req -> authorizedClient(req, next, getOAuth2AuthorizedClient(req.attributes()))) |
|
|
|
.map(authorizedClient -> bearer(request, authorizedClient)) |
|
|
|
.map(authorizedClient -> bearer(request, authorizedClient)) |
|
|
|
.flatMap(next::exchange) |
|
|
|
.flatMap(next::exchange) |
|
|
|
.switchIfEmpty(next.exchange(request)); |
|
|
|
.switchIfEmpty(next.exchange(request)); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
private Mono<ClientRequest> mergeRequestAttributesFromContext(ClientRequest request) { |
|
|
|
|
|
|
|
return Mono.just(ClientRequest.from(request)) |
|
|
|
|
|
|
|
.flatMap(builder -> Mono.subscriberContext() |
|
|
|
|
|
|
|
.map(ctx -> builder.attributes(attrs -> populateRequestAttributes(attrs, ctx)))) |
|
|
|
|
|
|
|
.map(ClientRequest.Builder::build); |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
private void populateRequestAttributes(Map<String, Object> attrs, Context ctx) { |
|
|
|
|
|
|
|
if (ctx.hasKey(HTTP_SERVLET_REQUEST_ATTR_NAME)) { |
|
|
|
|
|
|
|
attrs.putIfAbsent(HTTP_SERVLET_REQUEST_ATTR_NAME, ctx.get(HTTP_SERVLET_REQUEST_ATTR_NAME)); |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
if (ctx.hasKey(HTTP_SERVLET_RESPONSE_ATTR_NAME)) { |
|
|
|
|
|
|
|
attrs.putIfAbsent(HTTP_SERVLET_RESPONSE_ATTR_NAME, ctx.get(HTTP_SERVLET_RESPONSE_ATTR_NAME)); |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
if (ctx.hasKey(AUTHENTICATION_ATTR_NAME)) { |
|
|
|
|
|
|
|
attrs.putIfAbsent(AUTHENTICATION_ATTR_NAME, ctx.get(AUTHENTICATION_ATTR_NAME)); |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
populateDefaultOAuth2AuthorizedClient(attrs); |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
private void populateDefaultRequestResponse(Map<String, Object> attrs) { |
|
|
|
private void populateDefaultRequestResponse(Map<String, Object> attrs) { |
|
|
|
if (attrs.containsKey(HTTP_SERVLET_REQUEST_ATTR_NAME) && attrs.containsKey( |
|
|
|
if (attrs.containsKey(HTTP_SERVLET_REQUEST_ATTR_NAME) && attrs.containsKey( |
|
|
|
HTTP_SERVLET_RESPONSE_ATTR_NAME)) { |
|
|
|
HTTP_SERVLET_RESPONSE_ATTR_NAME)) { |
|
|
|
@ -425,6 +468,19 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement |
|
|
|
.build(); |
|
|
|
.build(); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
private <T> CoreSubscriber<T> createRequestContextSubscriber(CoreSubscriber<T> delegate) { |
|
|
|
|
|
|
|
HttpServletRequest request = null; |
|
|
|
|
|
|
|
HttpServletResponse response = null; |
|
|
|
|
|
|
|
ServletRequestAttributes requestAttributes = |
|
|
|
|
|
|
|
(ServletRequestAttributes) RequestContextHolder.getRequestAttributes(); |
|
|
|
|
|
|
|
if (requestAttributes != null) { |
|
|
|
|
|
|
|
request = requestAttributes.getRequest(); |
|
|
|
|
|
|
|
response = requestAttributes.getResponse(); |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
Authentication authentication = SecurityContextHolder.getContext().getAuthentication(); |
|
|
|
|
|
|
|
return new RequestContextSubscriber<>(delegate, request, response, authentication); |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
private static BodyInserters.FormInserter<String> refreshTokenBody(String refreshToken) { |
|
|
|
private static BodyInserters.FormInserter<String> refreshTokenBody(String refreshToken) { |
|
|
|
return BodyInserters |
|
|
|
return BodyInserters |
|
|
|
.fromFormData("grant_type", AuthorizationGrantType.REFRESH_TOKEN.getValue()) |
|
|
|
.fromFormData("grant_type", AuthorizationGrantType.REFRESH_TOKEN.getValue()) |
|
|
|
@ -498,4 +554,55 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement |
|
|
|
return new UnsupportedOperationException("Not Supported"); |
|
|
|
return new UnsupportedOperationException("Not Supported"); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
private static class RequestContextSubscriber<T> implements CoreSubscriber<T> { |
|
|
|
|
|
|
|
private static final String CONTEXT_DEFAULTED_ATTR_NAME = RequestContextSubscriber.class.getName().concat(".CONTEXT_DEFAULTED_ATTR_NAME"); |
|
|
|
|
|
|
|
private final CoreSubscriber<T> delegate; |
|
|
|
|
|
|
|
private final HttpServletRequest request; |
|
|
|
|
|
|
|
private final HttpServletResponse response; |
|
|
|
|
|
|
|
private final Authentication authentication; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
private RequestContextSubscriber(CoreSubscriber<T> delegate, |
|
|
|
|
|
|
|
HttpServletRequest request, |
|
|
|
|
|
|
|
HttpServletResponse response, |
|
|
|
|
|
|
|
Authentication authentication) { |
|
|
|
|
|
|
|
this.delegate = delegate; |
|
|
|
|
|
|
|
this.request = request; |
|
|
|
|
|
|
|
this.response = response; |
|
|
|
|
|
|
|
this.authentication = authentication; |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@Override |
|
|
|
|
|
|
|
public Context currentContext() { |
|
|
|
|
|
|
|
Context context = this.delegate.currentContext(); |
|
|
|
|
|
|
|
if (context.hasKey(CONTEXT_DEFAULTED_ATTR_NAME)) { |
|
|
|
|
|
|
|
return context; |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
return Context.of( |
|
|
|
|
|
|
|
CONTEXT_DEFAULTED_ATTR_NAME, Boolean.TRUE, |
|
|
|
|
|
|
|
HTTP_SERVLET_REQUEST_ATTR_NAME, this.request, |
|
|
|
|
|
|
|
HTTP_SERVLET_RESPONSE_ATTR_NAME, this.response, |
|
|
|
|
|
|
|
AUTHENTICATION_ATTR_NAME, this.authentication); |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@Override |
|
|
|
|
|
|
|
public void onSubscribe(Subscription s) { |
|
|
|
|
|
|
|
this.delegate.onSubscribe(s); |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@Override |
|
|
|
|
|
|
|
public void onNext(T t) { |
|
|
|
|
|
|
|
this.delegate.onNext(t); |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@Override |
|
|
|
|
|
|
|
public void onError(Throwable t) { |
|
|
|
|
|
|
|
this.delegate.onError(t); |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@Override |
|
|
|
|
|
|
|
public void onComplete() { |
|
|
|
|
|
|
|
this.delegate.onComplete(); |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|