diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/AbstractWebClientReactiveOAuth2AccessTokenResponseClient.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/AbstractWebClientReactiveOAuth2AccessTokenResponseClient.java index efc22ae7c6..ebbe9c8ec3 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/AbstractWebClientReactiveOAuth2AccessTokenResponseClient.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/AbstractWebClientReactiveOAuth2AccessTokenResponseClient.java @@ -27,6 +27,7 @@ import reactor.core.publisher.Mono; import org.springframework.core.convert.converter.Converter; import org.springframework.http.HttpHeaders; import org.springframework.http.MediaType; +import org.springframework.http.ReactiveHttpInputMessage; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; @@ -35,6 +36,7 @@ import org.springframework.security.oauth2.core.web.reactive.function.OAuth2Body import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; +import org.springframework.web.reactive.function.BodyExtractor; import org.springframework.web.reactive.function.BodyInserters; import org.springframework.web.reactive.function.client.ClientResponse; import org.springframework.web.reactive.function.client.WebClient; @@ -68,6 +70,9 @@ public abstract class AbstractWebClientReactiveOAuth2AccessTokenResponseClient headersConverter = this::populateTokenRequestHeaders; + private BodyExtractor, ReactiveHttpInputMessage> bodyExtractor = OAuth2BodyExtractors + .oauth2AccessTokenResponse(); + AbstractWebClientReactiveOAuth2AccessTokenResponseClient() { } @@ -204,7 +209,7 @@ public abstract class AbstractWebClientReactiveOAuth2AccessTokenResponseClient readTokenResponse(T grantRequest, ClientResponse response) { - return response.body(OAuth2BodyExtractors.oauth2AccessTokenResponse()) + return response.body(this.bodyExtractor) .map((tokenResponse) -> populateTokenResponse(grantRequest, tokenResponse)); } @@ -291,4 +296,17 @@ public abstract class AbstractWebClientReactiveOAuth2AccessTokenResponseClient, ReactiveHttpInputMessage> bodyExtractor) { + Assert.notNull(bodyExtractor, "bodyExtractor cannot be null"); + this.bodyExtractor = bodyExtractor; + } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveAuthorizationCodeTokenResponseClientTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveAuthorizationCodeTokenResponseClientTests.java index 37d3768a8c..34640d9a35 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveAuthorizationCodeTokenResponseClientTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveAuthorizationCodeTokenResponseClientTests.java @@ -27,11 +27,13 @@ import okhttp3.mockwebserver.RecordedRequest; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; import org.springframework.core.convert.converter.Converter; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; +import org.springframework.http.ReactiveHttpInputMessage; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.core.OAuth2AccessToken; @@ -41,11 +43,14 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExch import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse; import org.springframework.security.oauth2.core.endpoint.PkceParameterNames; +import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses; +import org.springframework.web.reactive.function.BodyExtractor; import org.springframework.web.reactive.function.client.WebClient; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.mock; @@ -409,4 +414,25 @@ public class WebClientReactiveAuthorizationCodeTokenResponseClientTests { .isEqualTo("Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ="); } + // gh-10260 + @Test + public void getTokenResponseWhenSuccessCustomResponseThenReturnAccessTokenResponse() { + + String accessTokenSuccessResponse = "{}"; + + WebClientReactiveAuthorizationCodeTokenResponseClient customClient = new WebClientReactiveAuthorizationCodeTokenResponseClient(); + + BodyExtractor, ReactiveHttpInputMessage> extractor = mock(BodyExtractor.class); + OAuth2AccessTokenResponse response = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); + given(extractor.extract(any(), any())).willReturn(Mono.just(response)); + + customClient.setBodyExtractor(extractor); + + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + OAuth2AccessTokenResponse accessTokenResponse = customClient.getTokenResponse(authorizationCodeGrantRequest()) + .block(); + assertThat(accessTokenResponse.getAccessToken()).isNotNull(); + + } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveClientCredentialsTokenResponseClientTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveClientCredentialsTokenResponseClientTests.java index 28acf6bdfc..a2a3fc0f7f 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveClientCredentialsTokenResponseClientTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveClientCredentialsTokenResponseClientTests.java @@ -27,21 +27,26 @@ import okhttp3.mockwebserver.RecordedRequest; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; import org.springframework.core.convert.converter.Converter; import org.springframework.http.HttpHeaders; import org.springframework.http.MediaType; +import org.springframework.http.ReactiveHttpInputMessage; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.OAuth2AuthorizationException; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses; +import org.springframework.web.reactive.function.BodyExtractor; import org.springframework.web.reactive.function.client.WebClient; import org.springframework.web.reactive.function.client.WebClientResponseException; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.mock; @@ -280,4 +285,26 @@ public class WebClientReactiveClientCredentialsTokenResponseClientTests { .isEqualTo("Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ="); } + // gh-10260 + @Test + public void getTokenResponseWhenSuccessCustomResponseThenReturnAccessTokenResponse() { + + enqueueJson("{}"); + + WebClientReactiveClientCredentialsTokenResponseClient customClient = new WebClientReactiveClientCredentialsTokenResponseClient(); + + BodyExtractor, ReactiveHttpInputMessage> extractor = mock(BodyExtractor.class); + OAuth2AccessTokenResponse response = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); + given(extractor.extract(any(), any())).willReturn(Mono.just(response)); + + customClient.setBodyExtractor(extractor); + + OAuth2ClientCredentialsGrantRequest request = new OAuth2ClientCredentialsGrantRequest( + this.clientRegistration.build()); + + OAuth2AccessTokenResponse accessTokenResponse = customClient.getTokenResponse(request).block(); + assertThat(accessTokenResponse.getAccessToken()).isNotNull(); + + } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactivePasswordTokenResponseClientTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactivePasswordTokenResponseClientTests.java index aabd789bc4..5b0118ebf9 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactivePasswordTokenResponseClientTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactivePasswordTokenResponseClientTests.java @@ -25,21 +25,26 @@ import okhttp3.mockwebserver.RecordedRequest; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; import org.springframework.core.convert.converter.Converter; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; +import org.springframework.http.ReactiveHttpInputMessage; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AuthorizationException; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses; +import org.springframework.web.reactive.function.BodyExtractor; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; @@ -286,4 +291,29 @@ public class WebClientReactivePasswordTokenResponseClientTests { .isEqualTo("Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ="); } + // gh-10260 + @Test + public void getTokenResponseWhenSuccessCustomResponseThenReturnAccessTokenResponse() { + + String accessTokenSuccessResponse = "{}"; + + WebClientReactivePasswordTokenResponseClient customClient = new WebClientReactivePasswordTokenResponseClient(); + + BodyExtractor, ReactiveHttpInputMessage> extractor = mock(BodyExtractor.class); + OAuth2AccessTokenResponse response = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); + given(extractor.extract(any(), any())).willReturn(Mono.just(response)); + + customClient.setBodyExtractor(extractor); + + ClientRegistration clientRegistration = this.clientRegistrationBuilder.build(); + OAuth2PasswordGrantRequest passwordGrantRequest = new OAuth2PasswordGrantRequest(clientRegistration, + this.username, this.password); + + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + + OAuth2AccessTokenResponse accessTokenResponse = customClient.getTokenResponse(passwordGrantRequest).block(); + assertThat(accessTokenResponse.getAccessToken()).isNotNull(); + + } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveRefreshTokenTokenResponseClientTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveRefreshTokenTokenResponseClientTests.java index e94e1124ce..01f36f0fb8 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveRefreshTokenTokenResponseClientTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveRefreshTokenTokenResponseClientTests.java @@ -25,11 +25,13 @@ import okhttp3.mockwebserver.RecordedRequest; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; import org.springframework.core.convert.converter.Converter; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; +import org.springframework.http.ReactiveHttpInputMessage; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.core.ClientAuthenticationMethod; @@ -39,10 +41,13 @@ import org.springframework.security.oauth2.core.OAuth2RefreshToken; import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses; +import org.springframework.web.reactive.function.BodyExtractor; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; @@ -289,4 +294,27 @@ public class WebClientReactiveRefreshTokenTokenResponseClientTests { .isEqualTo("Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ="); } + // gh-10260 + @Test + public void getTokenResponseWhenSuccessCustomResponseThenReturnAccessTokenResponse() { + + String accessTokenSuccessResponse = "{}"; + + WebClientReactiveRefreshTokenTokenResponseClient customClient = new WebClientReactiveRefreshTokenTokenResponseClient(); + + BodyExtractor, ReactiveHttpInputMessage> extractor = mock(BodyExtractor.class); + OAuth2AccessTokenResponse response = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); + given(extractor.extract(any(), any())).willReturn(Mono.just(response)); + + customClient.setBodyExtractor(extractor); + + OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest( + this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken); + + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + OAuth2AccessTokenResponse accessTokenResponse = customClient.getTokenResponse(refreshTokenGrantRequest).block(); + assertThat(accessTokenResponse.getAccessToken()).isNotNull(); + + } + }