diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveClientCredentialsTokenResponseClient.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveClientCredentialsTokenResponseClient.java index 2a3ab05d77..d34207c148 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveClientCredentialsTokenResponseClient.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveClientCredentialsTokenResponseClient.java @@ -21,6 +21,7 @@ import org.springframework.security.oauth2.client.registration.ClientRegistratio import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; import org.springframework.web.reactive.function.BodyInserters; @@ -112,4 +113,9 @@ public class WebClientReactiveClientCredentialsTokenResponseClient implements Re } return body; } + + public void setWebClient(WebClient webClient) { + Assert.notNull(webClient, "webClient cannot be null"); + this.webClient = webClient; + } } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveClientCredentialsTokenResponseClientTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveClientCredentialsTokenResponseClientTests.java index 6c227b336c..ee19f2dd97 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveClientCredentialsTokenResponseClientTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveClientCredentialsTokenResponseClientTests.java @@ -28,9 +28,11 @@ import org.springframework.security.oauth2.client.registration.ClientRegistratio import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.web.reactive.function.client.WebClient; import org.springframework.web.reactive.function.client.WebClientResponseException; -import static org.assertj.core.api.Assertions.*; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.*; /** * @author Rob Winch @@ -55,6 +57,7 @@ public class WebClientReactiveClientCredentialsTokenResponseClientTests { @After public void cleanup() throws Exception { + validateMockitoUsage(); this.server.shutdown(); } @@ -117,6 +120,31 @@ public class WebClientReactiveClientCredentialsTokenResponseClientTests { assertThat(response.getAccessToken().getScopes()).isEqualTo(registration.getScopes()); } + @Test(expected=IllegalArgumentException.class) + public void setWebClientNullThenIllegalArgumentException(){ + client.setWebClient(null); + } + + @Test + public void setWebClientCustomThenCustomClientIsUsed() { + WebClient customClient = mock(WebClient.class); + when(customClient.post()).thenReturn(WebClient.builder().build().post()); + + this.client.setWebClient(customClient); + ClientRegistration registration = this.clientRegistration.build(); + enqueueJson("{\n" + + " \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n" + + " \"token_type\":\"bearer\",\n" + + " \"expires_in\":3600,\n" + + " \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\"\n" + + "}"); + OAuth2ClientCredentialsGrantRequest request = new OAuth2ClientCredentialsGrantRequest(registration); + + OAuth2AccessTokenResponse response = this.client.getTokenResponse(request).block(); + + verify(customClient, atLeastOnce()).post(); + } + @Test(expected = WebClientResponseException.class) // gh-6089 public void getTokenResponseWhenInvalidResponse() throws WebClientResponseException {