From de959dbff6d5acb14e8b20fa12589cdfd3d0055b Mon Sep 17 00:00:00 2001 From: Rob Winch Date: Fri, 11 May 2018 00:36:20 -0500 Subject: [PATCH] Add OAuth2ClientArgumentResolver Issue: gh-4807 --- .../OAuth2ClientArgumentResolver.java | 147 ++++++++++ .../OAuth2ClientArgumentResolverTests.java | 267 ++++++++++++++++++ 2 files changed, 414 insertions(+) create mode 100644 oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/result/method/annotation/OAuth2ClientArgumentResolver.java create mode 100644 oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/result/method/annotation/OAuth2ClientArgumentResolverTests.java diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/result/method/annotation/OAuth2ClientArgumentResolver.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/result/method/annotation/OAuth2ClientArgumentResolver.java new file mode 100644 index 0000000000..4be413668a --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/result/method/annotation/OAuth2ClientArgumentResolver.java @@ -0,0 +1,147 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.web.reactive.result.method.annotation; + +import org.springframework.core.MethodParameter; +import org.springframework.core.annotation.AnnotatedElementUtils; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.context.ReactiveSecurityContextHolder; +import org.springframework.security.core.context.SecurityContext; +import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientService; +import org.springframework.security.oauth2.client.annotation.OAuth2Client; +import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; +import org.springframework.web.reactive.BindingContext; +import org.springframework.web.reactive.result.method.HandlerMethodArgumentResolver; +import org.springframework.web.server.ServerWebExchange; + +import reactor.core.publisher.Mono; + +/** + * An implementation of a {@link HandlerMethodArgumentResolver} that is capable + * of resolving a method parameter into an argument value for the following types: + * {@link ClientRegistration}, {@link OAuth2AuthorizedClient} and {@link OAuth2AccessToken}. + * + *

+ * For example: + *

+ * @Controller
+ * public class MyController {
+ *     @GetMapping("/client-registration")
+ *     public Mono clientRegistration(@OAuth2Client("login-client") ClientRegistration clientRegistration) {
+ *         // do something with clientRegistration
+ *     }
+ *
+ *     @GetMapping("/authorized-client")
+ *     public Mono authorizedClient(@OAuth2Client("login-client") OAuth2AuthorizedClient authorizedClient) {
+ *         // do something with authorizedClient
+ *     }
+ *
+ *     @GetMapping("/access-token")
+ *     public Mono accessToken(@OAuth2Client("login-client") OAuth2AccessToken accessToken) {
+ *         // do something with accessToken
+ *     }
+ * }
+ * 
+ * + * @author Rob Winch + * @since 5.1 + * @see OAuth2Client + */ +public final class OAuth2ClientArgumentResolver implements HandlerMethodArgumentResolver { + private final ReactiveClientRegistrationRepository clientRegistrationRepository; + private final ReactiveOAuth2AuthorizedClientService authorizedClientService; + + /** + * Constructs an {@code OAuth2ClientArgumentResolver} using the provided parameters. + * + * @param clientRegistrationRepository the repository of client registrations + * @param authorizedClientService the authorized client service + */ + public OAuth2ClientArgumentResolver(ReactiveClientRegistrationRepository clientRegistrationRepository, + ReactiveOAuth2AuthorizedClientService authorizedClientService) { + Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null"); + Assert.notNull(authorizedClientService, "authorizedClientService cannot be null"); + this.clientRegistrationRepository = clientRegistrationRepository; + this.authorizedClientService = authorizedClientService; + } + + @Override + public boolean supportsParameter(MethodParameter parameter) { + return AnnotatedElementUtils.findMergedAnnotation(parameter.getParameter(), OAuth2Client.class) != null; + } + + @Override + public Mono resolveArgument( + MethodParameter parameter, BindingContext bindingContext, ServerWebExchange exchange) { + return Mono.defer(() -> { + OAuth2Client oauth2ClientAnnotation = AnnotatedElementUtils + .findMergedAnnotation(parameter.getParameter(), OAuth2Client.class); + + Mono clientRegistrationId = Mono.justOrEmpty(oauth2ClientAnnotation.registrationId()) + .filter(id -> !StringUtils.isEmpty(id)) + .switchIfEmpty(clientRegistrationId()) + .switchIfEmpty(Mono.defer(() -> Mono.error(new IllegalArgumentException( + "Unable to resolve the Client Registration Identifier. It must be provided via @OAuth2Client(\"client1\") or @OAuth2Client(registrationId = \"client1\").")))); + + if (ClientRegistration.class.isAssignableFrom(parameter.getParameterType())) { + return clientRegistrationId.flatMap(id -> this.clientRegistrationRepository.findByRegistrationId(id) + .switchIfEmpty(Mono.defer(() -> Mono.error(new IllegalArgumentException( + "Unable to find ClientRegistration with registration identifier \"" + + id + "\"."))))).cast(Object.class); + } + + Mono principalName = ReactiveSecurityContextHolder.getContext() + .map(SecurityContext::getAuthentication).map(Authentication::getName); + + Mono authorizedClient = Mono + .zip(clientRegistrationId, principalName).switchIfEmpty( + clientRegistrationId.flatMap(id -> Mono.error(new IllegalStateException( + "Unable to resolve the Authorized Client with registration identifier \"" + + id + + "\". An \"authenticated\" or \"unauthenticated\" session is required. To allow for unauthenticated access, ensure ServerHttpSecurity.anonymous() is configured.")))) + .flatMap(zipped -> { + String registrationId = zipped.getT1(); + String username = zipped.getT2(); + return this.authorizedClientService + .loadAuthorizedClient(registrationId, username).switchIfEmpty(Mono.defer(() -> Mono + .error(new ClientAuthorizationRequiredException( + registrationId)))); + }).cast(OAuth2AuthorizedClient.class); + + if (OAuth2AccessToken.class.isAssignableFrom(parameter.getParameterType())) { + return authorizedClient.map(OAuth2AuthorizedClient::getAccessToken); + } + + return authorizedClient.cast(Object.class); + }); + } + + private Mono clientRegistrationId() { + return ReactiveSecurityContextHolder.getContext() + .map(SecurityContext::getAuthentication) + .filter(authentication -> authentication instanceof OAuth2AuthenticationToken) + .cast(OAuth2AuthenticationToken.class) + .map(OAuth2AuthenticationToken::getAuthorizedClientRegistrationId); + } +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/result/method/annotation/OAuth2ClientArgumentResolverTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/result/method/annotation/OAuth2ClientArgumentResolverTests.java new file mode 100644 index 0000000000..aedcbaae57 --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/result/method/annotation/OAuth2ClientArgumentResolverTests.java @@ -0,0 +1,267 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.web.reactive.result.method.annotation; + +import static org.assertj.core.api.AssertionsForClassTypes.assertThat; +import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.lang.reflect.Method; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnitRunner; +import org.springframework.core.MethodParameter; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.context.ReactiveSecurityContextHolder; +import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientService; +import org.springframework.security.oauth2.client.annotation.OAuth2Client; +import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.util.ReflectionUtils; + +import reactor.core.publisher.Hooks; +import reactor.core.publisher.Mono; +import reactor.util.context.Context; + +/** + * @author Rob Winch + * @since 5.1 + */ +@RunWith(MockitoJUnitRunner.class) +public class OAuth2ClientArgumentResolverTests { + @Mock + private ReactiveClientRegistrationRepository clientRegistrationRepository; + @Mock + private ReactiveOAuth2AuthorizedClientService authorizedClientService; + private OAuth2ClientArgumentResolver argumentResolver; + private ClientRegistration clientRegistration; + private OAuth2AuthorizedClient authorizedClient; + private OAuth2AccessToken accessToken; + + private Authentication authentication = new TestingAuthenticationToken("test", "this"); + + @Before + public void setUp() { + this.argumentResolver = new OAuth2ClientArgumentResolver( + this.clientRegistrationRepository, this.authorizedClientService); + this.clientRegistration = ClientRegistration.withRegistrationId("client1") + .clientId("client-id") + .clientSecret("secret") + .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) + .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) + .redirectUriTemplate("{baseUrl}/client1") + .scope("scope1", "scope2") + .authorizationUri("https://provider.com/oauth2/auth") + .tokenUri("https://provider.com/oauth2/token") + .clientName("Client 1") + .build(); + when(this.clientRegistrationRepository.findByRegistrationId(anyString())).thenReturn(Mono.just(this.clientRegistration)); + this.authorizedClient = mock(OAuth2AuthorizedClient.class); + when(this.authorizedClientService.loadAuthorizedClient(anyString(), any())).thenReturn(Mono.just(this.authorizedClient)); + this.accessToken = mock(OAuth2AccessToken.class); + when(this.authorizedClient.getAccessToken()).thenReturn(this.accessToken); + Hooks.onOperatorDebug(); + } + + @Test + public void constructorWhenClientRegistrationRepositoryIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new OAuth2ClientArgumentResolver(null, this.authorizedClientService)) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void constructorWhenOAuth2AuthorizedClientServiceIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new OAuth2ClientArgumentResolver(this.clientRegistrationRepository, null)) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void supportsParameterWhenParameterTypeOAuth2AccessTokenThenTrue() { + MethodParameter methodParameter = this.getMethodParameter("paramTypeAccessToken", OAuth2AccessToken.class); + assertThat(this.argumentResolver.supportsParameter(methodParameter)).isTrue(); + } + + @Test + public void supportsParameterWhenParameterTypeOAuth2AccessTokenWithoutAnnotationThenFalse() { + MethodParameter methodParameter = this.getMethodParameter("paramTypeAccessTokenWithoutAnnotation", OAuth2AccessToken.class); + assertThat(this.argumentResolver.supportsParameter(methodParameter)).isFalse(); + } + + @Test + public void supportsParameterWhenParameterTypeOAuth2AuthorizedClientThenTrue() { + MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClient", OAuth2AuthorizedClient.class); + assertThat(this.argumentResolver.supportsParameter(methodParameter)).isTrue(); + } + + @Test + public void supportsParameterWhenParameterTypeOAuth2AuthorizedClientWithoutAnnotationThenFalse() { + MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClientWithoutAnnotation", OAuth2AuthorizedClient.class); + assertThat(this.argumentResolver.supportsParameter(methodParameter)).isFalse(); + } + + @Test + public void supportsParameterWhenParameterTypeClientRegistrationThenTrue() { + MethodParameter methodParameter = this.getMethodParameter("paramTypeClientRegistration", ClientRegistration.class); + assertThat(this.argumentResolver.supportsParameter(methodParameter)).isTrue(); + } + + @Test + public void supportsParameterWhenParameterTypeClientRegistrationWithoutAnnotationThenFalse() { + MethodParameter methodParameter = this.getMethodParameter("paramTypeClientRegistrationWithoutAnnotation", ClientRegistration.class); + assertThat(this.argumentResolver.supportsParameter(methodParameter)).isFalse(); + } + + @Test + public void supportsParameterWhenParameterTypeUnsupportedWithoutAnnotationThenFalse() { + MethodParameter methodParameter = this.getMethodParameter("paramTypeUnsupportedWithoutAnnotation", String.class); + assertThat(this.argumentResolver.supportsParameter(methodParameter)).isFalse(); + } + + @Test + public void resolveArgumentWhenRegistrationIdEmptyAndNotOAuth2AuthenticationThenThrowIllegalArgumentException() { + MethodParameter methodParameter = this.getMethodParameter("registrationIdEmpty", OAuth2AccessToken.class); + assertThatThrownBy(() -> resolveArgument(methodParameter)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Unable to resolve the Client Registration Identifier. It must be provided via @OAuth2Client(\"client1\") or @OAuth2Client(registrationId = \"client1\")."); + } + + @Test + public void resolveArgumentWhenRegistrationIdEmptyAndOAuth2AuthenticationThenResolves() { + this.authentication = mock(OAuth2AuthenticationToken.class); + when(this.authentication.getName()).thenReturn("client1"); + when(((OAuth2AuthenticationToken) this.authentication).getAuthorizedClientRegistrationId()).thenReturn("client1"); + MethodParameter methodParameter = this.getMethodParameter("registrationIdEmpty", OAuth2AccessToken.class); + resolveArgument(methodParameter); + } + + @Test + public void resolveArgumentWhenClientRegistrationFoundThenResolves() { + MethodParameter methodParameter = this.getMethodParameter("paramTypeClientRegistration", ClientRegistration.class); + assertThat(resolveArgument(methodParameter)).isSameAs(this.clientRegistration); + } + + @Test + public void resolveArgumentWhenClientRegistrationNotFoundThenThrowIllegalArgumentException() { + when(this.clientRegistrationRepository.findByRegistrationId(anyString())).thenReturn(Mono.empty()); + MethodParameter methodParameter = this.getMethodParameter("paramTypeClientRegistration", ClientRegistration.class); + assertThatThrownBy(() -> resolveArgument(methodParameter)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Unable to find ClientRegistration with registration identifier \"client1\"."); + } + + @Test + public void resolveArgumentWhenParameterTypeOAuth2AuthorizedClientAndCurrentAuthenticationNullThenThrowIllegalStateException() { + this.authentication = null; + MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClient", OAuth2AuthorizedClient.class); + assertThatThrownBy(() -> resolveArgument(methodParameter)) + .isInstanceOf(IllegalStateException.class) + .hasMessage("Unable to resolve the Authorized Client with registration identifier \"client1\". " + + "An \"authenticated\" or \"unauthenticated\" session is required. " + + "To allow for unauthenticated access, ensure ServerHttpSecurity.anonymous() is configured."); + } + + @Test + public void resolveArgumentWhenOAuth2AuthorizedClientFoundThenResolves() { + MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClient", OAuth2AuthorizedClient.class); + assertThat(resolveArgument(methodParameter)).isSameAs(this.authorizedClient); + } + + @Test + public void resolveArgumentWhenOAuth2AuthorizedClientNotFoundThenThrowClientAuthorizationRequiredException() { + when(this.authorizedClientService.loadAuthorizedClient(anyString(), any())).thenReturn(Mono.empty()); + MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClient", OAuth2AuthorizedClient.class); + assertThatThrownBy(() -> resolveArgument(methodParameter)) + .isInstanceOf(ClientAuthorizationRequiredException.class); + } + + @Test + public void resolveArgumentWhenOAuth2AccessTokenAndOAuth2AuthorizedClientFoundThenResolves() { + MethodParameter methodParameter = this.getMethodParameter("paramTypeAccessToken", OAuth2AccessToken.class); + assertThat(resolveArgument(methodParameter)).isSameAs(this.authorizedClient.getAccessToken()); + } + + @Test + public void resolveArgumentWhenOAuth2AccessTokenAndOAuth2AuthorizedClientNotFoundThenThrowClientAuthorizationRequiredException() { + when(this.authorizedClientService.loadAuthorizedClient(anyString(), any())).thenReturn(Mono.empty()); + MethodParameter methodParameter = this.getMethodParameter("paramTypeAccessToken", OAuth2AccessToken.class); + assertThatThrownBy(() -> resolveArgument(methodParameter)) + .isInstanceOf(ClientAuthorizationRequiredException.class); + } + + @Test + public void resolveArgumentWhenOAuth2AccessTokenAndAnnotationRegistrationIdSetThenResolves() { + MethodParameter methodParameter = this.getMethodParameter("paramTypeAccessTokenAnnotationRegistrationId", OAuth2AccessToken.class); + assertThat(resolveArgument(methodParameter)).isSameAs(this.authorizedClient.getAccessToken()); + } + + private Object resolveArgument(MethodParameter methodParameter) { + return this.argumentResolver.resolveArgument(methodParameter, null, null) + .subscriberContext(this.authentication == null ? Context.empty() : ReactiveSecurityContextHolder.withAuthentication(this.authentication)) + .block(); + } + + private MethodParameter getMethodParameter(String methodName, Class... paramTypes) { + Method method = ReflectionUtils.findMethod( + TestController.class, methodName, paramTypes); + return new MethodParameter(method, 0); + } + + static class TestController { + void paramTypeAccessToken(@OAuth2Client("client1") OAuth2AccessToken accessToken) { + } + + void paramTypeAccessTokenWithoutAnnotation(OAuth2AccessToken accessToken) { + } + + void paramTypeAuthorizedClient(@OAuth2Client("client1") OAuth2AuthorizedClient authorizedClient) { + } + + void paramTypeAuthorizedClientWithoutAnnotation(OAuth2AuthorizedClient authorizedClient) { + } + + void paramTypeClientRegistration(@OAuth2Client("client1") ClientRegistration clientRegistration) { + } + + void paramTypeClientRegistrationWithoutAnnotation(ClientRegistration clientRegistration) { + } + + void paramTypeUnsupported(@OAuth2Client("client1") String param) { + } + + void paramTypeUnsupportedWithoutAnnotation(String param) { + } + + void registrationIdEmpty(@OAuth2Client OAuth2AccessToken accessToken) { + } + + void paramTypeAccessTokenAnnotationRegistrationId(@OAuth2Client(registrationId = "client1") OAuth2AccessToken accessToken) { + } + } +}