diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/result/method/annotation/OAuth2ClientArgumentResolver.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/result/method/annotation/OAuth2ClientArgumentResolver.java
new file mode 100644
index 0000000000..4be413668a
--- /dev/null
+++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/result/method/annotation/OAuth2ClientArgumentResolver.java
@@ -0,0 +1,147 @@
+/*
+ * Copyright 2002-2018 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.oauth2.client.web.reactive.result.method.annotation;
+
+import org.springframework.core.MethodParameter;
+import org.springframework.core.annotation.AnnotatedElementUtils;
+import org.springframework.security.core.Authentication;
+import org.springframework.security.core.context.ReactiveSecurityContextHolder;
+import org.springframework.security.core.context.SecurityContext;
+import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException;
+import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
+import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientService;
+import org.springframework.security.oauth2.client.annotation.OAuth2Client;
+import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
+import org.springframework.security.oauth2.client.registration.ClientRegistration;
+import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
+import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.util.Assert;
+import org.springframework.util.StringUtils;
+import org.springframework.web.reactive.BindingContext;
+import org.springframework.web.reactive.result.method.HandlerMethodArgumentResolver;
+import org.springframework.web.server.ServerWebExchange;
+
+import reactor.core.publisher.Mono;
+
+/**
+ * An implementation of a {@link HandlerMethodArgumentResolver} that is capable
+ * of resolving a method parameter into an argument value for the following types:
+ * {@link ClientRegistration}, {@link OAuth2AuthorizedClient} and {@link OAuth2AccessToken}.
+ *
+ *
+ * For example:
+ *
+ * @Controller
+ * public class MyController {
+ * @GetMapping("/client-registration")
+ * public Mono clientRegistration(@OAuth2Client("login-client") ClientRegistration clientRegistration) {
+ * // do something with clientRegistration
+ * }
+ *
+ * @GetMapping("/authorized-client")
+ * public Mono authorizedClient(@OAuth2Client("login-client") OAuth2AuthorizedClient authorizedClient) {
+ * // do something with authorizedClient
+ * }
+ *
+ * @GetMapping("/access-token")
+ * public Mono accessToken(@OAuth2Client("login-client") OAuth2AccessToken accessToken) {
+ * // do something with accessToken
+ * }
+ * }
+ *
+ *
+ * @author Rob Winch
+ * @since 5.1
+ * @see OAuth2Client
+ */
+public final class OAuth2ClientArgumentResolver implements HandlerMethodArgumentResolver {
+ private final ReactiveClientRegistrationRepository clientRegistrationRepository;
+ private final ReactiveOAuth2AuthorizedClientService authorizedClientService;
+
+ /**
+ * Constructs an {@code OAuth2ClientArgumentResolver} using the provided parameters.
+ *
+ * @param clientRegistrationRepository the repository of client registrations
+ * @param authorizedClientService the authorized client service
+ */
+ public OAuth2ClientArgumentResolver(ReactiveClientRegistrationRepository clientRegistrationRepository,
+ ReactiveOAuth2AuthorizedClientService authorizedClientService) {
+ Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null");
+ Assert.notNull(authorizedClientService, "authorizedClientService cannot be null");
+ this.clientRegistrationRepository = clientRegistrationRepository;
+ this.authorizedClientService = authorizedClientService;
+ }
+
+ @Override
+ public boolean supportsParameter(MethodParameter parameter) {
+ return AnnotatedElementUtils.findMergedAnnotation(parameter.getParameter(), OAuth2Client.class) != null;
+ }
+
+ @Override
+ public Mono resolveArgument(
+ MethodParameter parameter, BindingContext bindingContext, ServerWebExchange exchange) {
+ return Mono.defer(() -> {
+ OAuth2Client oauth2ClientAnnotation = AnnotatedElementUtils
+ .findMergedAnnotation(parameter.getParameter(), OAuth2Client.class);
+
+ Mono clientRegistrationId = Mono.justOrEmpty(oauth2ClientAnnotation.registrationId())
+ .filter(id -> !StringUtils.isEmpty(id))
+ .switchIfEmpty(clientRegistrationId())
+ .switchIfEmpty(Mono.defer(() -> Mono.error(new IllegalArgumentException(
+ "Unable to resolve the Client Registration Identifier. It must be provided via @OAuth2Client(\"client1\") or @OAuth2Client(registrationId = \"client1\")."))));
+
+ if (ClientRegistration.class.isAssignableFrom(parameter.getParameterType())) {
+ return clientRegistrationId.flatMap(id -> this.clientRegistrationRepository.findByRegistrationId(id)
+ .switchIfEmpty(Mono.defer(() -> Mono.error(new IllegalArgumentException(
+ "Unable to find ClientRegistration with registration identifier \""
+ + id + "\"."))))).cast(Object.class);
+ }
+
+ Mono principalName = ReactiveSecurityContextHolder.getContext()
+ .map(SecurityContext::getAuthentication).map(Authentication::getName);
+
+ Mono authorizedClient = Mono
+ .zip(clientRegistrationId, principalName).switchIfEmpty(
+ clientRegistrationId.flatMap(id -> Mono.error(new IllegalStateException(
+ "Unable to resolve the Authorized Client with registration identifier \""
+ + id
+ + "\". An \"authenticated\" or \"unauthenticated\" session is required. To allow for unauthenticated access, ensure ServerHttpSecurity.anonymous() is configured."))))
+ .flatMap(zipped -> {
+ String registrationId = zipped.getT1();
+ String username = zipped.getT2();
+ return this.authorizedClientService
+ .loadAuthorizedClient(registrationId, username).switchIfEmpty(Mono.defer(() -> Mono
+ .error(new ClientAuthorizationRequiredException(
+ registrationId))));
+ }).cast(OAuth2AuthorizedClient.class);
+
+ if (OAuth2AccessToken.class.isAssignableFrom(parameter.getParameterType())) {
+ return authorizedClient.map(OAuth2AuthorizedClient::getAccessToken);
+ }
+
+ return authorizedClient.cast(Object.class);
+ });
+ }
+
+ private Mono clientRegistrationId() {
+ return ReactiveSecurityContextHolder.getContext()
+ .map(SecurityContext::getAuthentication)
+ .filter(authentication -> authentication instanceof OAuth2AuthenticationToken)
+ .cast(OAuth2AuthenticationToken.class)
+ .map(OAuth2AuthenticationToken::getAuthorizedClientRegistrationId);
+ }
+}
diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/result/method/annotation/OAuth2ClientArgumentResolverTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/result/method/annotation/OAuth2ClientArgumentResolverTests.java
new file mode 100644
index 0000000000..aedcbaae57
--- /dev/null
+++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/result/method/annotation/OAuth2ClientArgumentResolverTests.java
@@ -0,0 +1,267 @@
+/*
+ * Copyright 2002-2018 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.oauth2.client.web.reactive.result.method.annotation;
+
+import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
+import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyString;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+import java.lang.reflect.Method;
+
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.mockito.junit.MockitoJUnitRunner;
+import org.springframework.core.MethodParameter;
+import org.springframework.security.authentication.TestingAuthenticationToken;
+import org.springframework.security.core.Authentication;
+import org.springframework.security.core.context.ReactiveSecurityContextHolder;
+import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException;
+import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
+import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientService;
+import org.springframework.security.oauth2.client.annotation.OAuth2Client;
+import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
+import org.springframework.security.oauth2.client.registration.ClientRegistration;
+import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
+import org.springframework.security.oauth2.core.AuthorizationGrantType;
+import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
+import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.util.ReflectionUtils;
+
+import reactor.core.publisher.Hooks;
+import reactor.core.publisher.Mono;
+import reactor.util.context.Context;
+
+/**
+ * @author Rob Winch
+ * @since 5.1
+ */
+@RunWith(MockitoJUnitRunner.class)
+public class OAuth2ClientArgumentResolverTests {
+ @Mock
+ private ReactiveClientRegistrationRepository clientRegistrationRepository;
+ @Mock
+ private ReactiveOAuth2AuthorizedClientService authorizedClientService;
+ private OAuth2ClientArgumentResolver argumentResolver;
+ private ClientRegistration clientRegistration;
+ private OAuth2AuthorizedClient authorizedClient;
+ private OAuth2AccessToken accessToken;
+
+ private Authentication authentication = new TestingAuthenticationToken("test", "this");
+
+ @Before
+ public void setUp() {
+ this.argumentResolver = new OAuth2ClientArgumentResolver(
+ this.clientRegistrationRepository, this.authorizedClientService);
+ this.clientRegistration = ClientRegistration.withRegistrationId("client1")
+ .clientId("client-id")
+ .clientSecret("secret")
+ .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
+ .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
+ .redirectUriTemplate("{baseUrl}/client1")
+ .scope("scope1", "scope2")
+ .authorizationUri("https://provider.com/oauth2/auth")
+ .tokenUri("https://provider.com/oauth2/token")
+ .clientName("Client 1")
+ .build();
+ when(this.clientRegistrationRepository.findByRegistrationId(anyString())).thenReturn(Mono.just(this.clientRegistration));
+ this.authorizedClient = mock(OAuth2AuthorizedClient.class);
+ when(this.authorizedClientService.loadAuthorizedClient(anyString(), any())).thenReturn(Mono.just(this.authorizedClient));
+ this.accessToken = mock(OAuth2AccessToken.class);
+ when(this.authorizedClient.getAccessToken()).thenReturn(this.accessToken);
+ Hooks.onOperatorDebug();
+ }
+
+ @Test
+ public void constructorWhenClientRegistrationRepositoryIsNullThenThrowIllegalArgumentException() {
+ assertThatThrownBy(() -> new OAuth2ClientArgumentResolver(null, this.authorizedClientService))
+ .isInstanceOf(IllegalArgumentException.class);
+ }
+
+ @Test
+ public void constructorWhenOAuth2AuthorizedClientServiceIsNullThenThrowIllegalArgumentException() {
+ assertThatThrownBy(() -> new OAuth2ClientArgumentResolver(this.clientRegistrationRepository, null))
+ .isInstanceOf(IllegalArgumentException.class);
+ }
+
+ @Test
+ public void supportsParameterWhenParameterTypeOAuth2AccessTokenThenTrue() {
+ MethodParameter methodParameter = this.getMethodParameter("paramTypeAccessToken", OAuth2AccessToken.class);
+ assertThat(this.argumentResolver.supportsParameter(methodParameter)).isTrue();
+ }
+
+ @Test
+ public void supportsParameterWhenParameterTypeOAuth2AccessTokenWithoutAnnotationThenFalse() {
+ MethodParameter methodParameter = this.getMethodParameter("paramTypeAccessTokenWithoutAnnotation", OAuth2AccessToken.class);
+ assertThat(this.argumentResolver.supportsParameter(methodParameter)).isFalse();
+ }
+
+ @Test
+ public void supportsParameterWhenParameterTypeOAuth2AuthorizedClientThenTrue() {
+ MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClient", OAuth2AuthorizedClient.class);
+ assertThat(this.argumentResolver.supportsParameter(methodParameter)).isTrue();
+ }
+
+ @Test
+ public void supportsParameterWhenParameterTypeOAuth2AuthorizedClientWithoutAnnotationThenFalse() {
+ MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClientWithoutAnnotation", OAuth2AuthorizedClient.class);
+ assertThat(this.argumentResolver.supportsParameter(methodParameter)).isFalse();
+ }
+
+ @Test
+ public void supportsParameterWhenParameterTypeClientRegistrationThenTrue() {
+ MethodParameter methodParameter = this.getMethodParameter("paramTypeClientRegistration", ClientRegistration.class);
+ assertThat(this.argumentResolver.supportsParameter(methodParameter)).isTrue();
+ }
+
+ @Test
+ public void supportsParameterWhenParameterTypeClientRegistrationWithoutAnnotationThenFalse() {
+ MethodParameter methodParameter = this.getMethodParameter("paramTypeClientRegistrationWithoutAnnotation", ClientRegistration.class);
+ assertThat(this.argumentResolver.supportsParameter(methodParameter)).isFalse();
+ }
+
+ @Test
+ public void supportsParameterWhenParameterTypeUnsupportedWithoutAnnotationThenFalse() {
+ MethodParameter methodParameter = this.getMethodParameter("paramTypeUnsupportedWithoutAnnotation", String.class);
+ assertThat(this.argumentResolver.supportsParameter(methodParameter)).isFalse();
+ }
+
+ @Test
+ public void resolveArgumentWhenRegistrationIdEmptyAndNotOAuth2AuthenticationThenThrowIllegalArgumentException() {
+ MethodParameter methodParameter = this.getMethodParameter("registrationIdEmpty", OAuth2AccessToken.class);
+ assertThatThrownBy(() -> resolveArgument(methodParameter))
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessage("Unable to resolve the Client Registration Identifier. It must be provided via @OAuth2Client(\"client1\") or @OAuth2Client(registrationId = \"client1\").");
+ }
+
+ @Test
+ public void resolveArgumentWhenRegistrationIdEmptyAndOAuth2AuthenticationThenResolves() {
+ this.authentication = mock(OAuth2AuthenticationToken.class);
+ when(this.authentication.getName()).thenReturn("client1");
+ when(((OAuth2AuthenticationToken) this.authentication).getAuthorizedClientRegistrationId()).thenReturn("client1");
+ MethodParameter methodParameter = this.getMethodParameter("registrationIdEmpty", OAuth2AccessToken.class);
+ resolveArgument(methodParameter);
+ }
+
+ @Test
+ public void resolveArgumentWhenClientRegistrationFoundThenResolves() {
+ MethodParameter methodParameter = this.getMethodParameter("paramTypeClientRegistration", ClientRegistration.class);
+ assertThat(resolveArgument(methodParameter)).isSameAs(this.clientRegistration);
+ }
+
+ @Test
+ public void resolveArgumentWhenClientRegistrationNotFoundThenThrowIllegalArgumentException() {
+ when(this.clientRegistrationRepository.findByRegistrationId(anyString())).thenReturn(Mono.empty());
+ MethodParameter methodParameter = this.getMethodParameter("paramTypeClientRegistration", ClientRegistration.class);
+ assertThatThrownBy(() -> resolveArgument(methodParameter))
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessage("Unable to find ClientRegistration with registration identifier \"client1\".");
+ }
+
+ @Test
+ public void resolveArgumentWhenParameterTypeOAuth2AuthorizedClientAndCurrentAuthenticationNullThenThrowIllegalStateException() {
+ this.authentication = null;
+ MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClient", OAuth2AuthorizedClient.class);
+ assertThatThrownBy(() -> resolveArgument(methodParameter))
+ .isInstanceOf(IllegalStateException.class)
+ .hasMessage("Unable to resolve the Authorized Client with registration identifier \"client1\". " +
+ "An \"authenticated\" or \"unauthenticated\" session is required. " +
+ "To allow for unauthenticated access, ensure ServerHttpSecurity.anonymous() is configured.");
+ }
+
+ @Test
+ public void resolveArgumentWhenOAuth2AuthorizedClientFoundThenResolves() {
+ MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClient", OAuth2AuthorizedClient.class);
+ assertThat(resolveArgument(methodParameter)).isSameAs(this.authorizedClient);
+ }
+
+ @Test
+ public void resolveArgumentWhenOAuth2AuthorizedClientNotFoundThenThrowClientAuthorizationRequiredException() {
+ when(this.authorizedClientService.loadAuthorizedClient(anyString(), any())).thenReturn(Mono.empty());
+ MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClient", OAuth2AuthorizedClient.class);
+ assertThatThrownBy(() -> resolveArgument(methodParameter))
+ .isInstanceOf(ClientAuthorizationRequiredException.class);
+ }
+
+ @Test
+ public void resolveArgumentWhenOAuth2AccessTokenAndOAuth2AuthorizedClientFoundThenResolves() {
+ MethodParameter methodParameter = this.getMethodParameter("paramTypeAccessToken", OAuth2AccessToken.class);
+ assertThat(resolveArgument(methodParameter)).isSameAs(this.authorizedClient.getAccessToken());
+ }
+
+ @Test
+ public void resolveArgumentWhenOAuth2AccessTokenAndOAuth2AuthorizedClientNotFoundThenThrowClientAuthorizationRequiredException() {
+ when(this.authorizedClientService.loadAuthorizedClient(anyString(), any())).thenReturn(Mono.empty());
+ MethodParameter methodParameter = this.getMethodParameter("paramTypeAccessToken", OAuth2AccessToken.class);
+ assertThatThrownBy(() -> resolveArgument(methodParameter))
+ .isInstanceOf(ClientAuthorizationRequiredException.class);
+ }
+
+ @Test
+ public void resolveArgumentWhenOAuth2AccessTokenAndAnnotationRegistrationIdSetThenResolves() {
+ MethodParameter methodParameter = this.getMethodParameter("paramTypeAccessTokenAnnotationRegistrationId", OAuth2AccessToken.class);
+ assertThat(resolveArgument(methodParameter)).isSameAs(this.authorizedClient.getAccessToken());
+ }
+
+ private Object resolveArgument(MethodParameter methodParameter) {
+ return this.argumentResolver.resolveArgument(methodParameter, null, null)
+ .subscriberContext(this.authentication == null ? Context.empty() : ReactiveSecurityContextHolder.withAuthentication(this.authentication))
+ .block();
+ }
+
+ private MethodParameter getMethodParameter(String methodName, Class>... paramTypes) {
+ Method method = ReflectionUtils.findMethod(
+ TestController.class, methodName, paramTypes);
+ return new MethodParameter(method, 0);
+ }
+
+ static class TestController {
+ void paramTypeAccessToken(@OAuth2Client("client1") OAuth2AccessToken accessToken) {
+ }
+
+ void paramTypeAccessTokenWithoutAnnotation(OAuth2AccessToken accessToken) {
+ }
+
+ void paramTypeAuthorizedClient(@OAuth2Client("client1") OAuth2AuthorizedClient authorizedClient) {
+ }
+
+ void paramTypeAuthorizedClientWithoutAnnotation(OAuth2AuthorizedClient authorizedClient) {
+ }
+
+ void paramTypeClientRegistration(@OAuth2Client("client1") ClientRegistration clientRegistration) {
+ }
+
+ void paramTypeClientRegistrationWithoutAnnotation(ClientRegistration clientRegistration) {
+ }
+
+ void paramTypeUnsupported(@OAuth2Client("client1") String param) {
+ }
+
+ void paramTypeUnsupportedWithoutAnnotation(String param) {
+ }
+
+ void registrationIdEmpty(@OAuth2Client OAuth2AccessToken accessToken) {
+ }
+
+ void paramTypeAccessTokenAnnotationRegistrationId(@OAuth2Client(registrationId = "client1") OAuth2AccessToken accessToken) {
+ }
+ }
+}