From 8e615d0feed958c7bfd7aeb8cbb55ff2ce04f6db Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Mon, 27 Aug 2018 10:45:32 -0400 Subject: [PATCH] Re-factor DefaultClientCredentialsTokenResponseClient Fixes gh-5735 --- ...tClientCredentialsTokenResponseClient.java | 218 +++--------------- ...zationCodeGrantRequestEntityConverter.java | 27 +-- ...2AuthorizationGrantRequestEntityUtils.java | 59 +++++ ...redentialsGrantRequestEntityConverter.java | 87 +++++++ ...ntCredentialsTokenResponseClientTests.java | 39 ++-- ...tialsGrantRequestEntityConverterTests.java | 77 +++++++ 6 files changed, 276 insertions(+), 231 deletions(-) create mode 100644 oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2AuthorizationGrantRequestEntityUtils.java create mode 100644 oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2ClientCredentialsGrantRequestEntityConverter.java create mode 100644 oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2ClientCredentialsGrantRequestEntityConverterTests.java diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultClientCredentialsTokenResponseClient.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultClientCredentialsTokenResponseClient.java index f99409a276..c6fac6c7ad 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultClientCredentialsTokenResponseClient.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultClientCredentialsTokenResponseClient.java @@ -15,41 +15,25 @@ */ package org.springframework.security.oauth2.client.endpoint; -import org.springframework.core.ParameterizedTypeReference; -import org.springframework.http.HttpHeaders; -import org.springframework.http.HttpMethod; -import org.springframework.http.MediaType; +import org.springframework.core.convert.converter.Converter; import org.springframework.http.RequestEntity; import org.springframework.http.ResponseEntity; -import org.springframework.http.client.ClientHttpResponse; -import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.http.converter.FormHttpMessageConverter; +import org.springframework.http.converter.HttpMessageConverter; +import org.springframework.security.oauth2.client.http.OAuth2ErrorResponseErrorHandler; import org.springframework.security.oauth2.core.AuthorizationGrantType; -import org.springframework.security.oauth2.core.ClientAuthenticationMethod; -import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2Error; -import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; -import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.core.http.converter.OAuth2AccessTokenResponseHttpMessageConverter; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; -import org.springframework.util.LinkedMultiValueMap; -import org.springframework.util.MultiValueMap; -import org.springframework.util.StringUtils; import org.springframework.web.client.ResponseErrorHandler; +import org.springframework.web.client.RestClientException; import org.springframework.web.client.RestOperations; import org.springframework.web.client.RestTemplate; -import org.springframework.web.util.UriComponentsBuilder; -import java.io.IOException; -import java.net.URI; import java.util.Arrays; -import java.util.Collections; -import java.util.LinkedHashMap; -import java.util.Map; -import java.util.Set; -import java.util.stream.Collectors; -import java.util.stream.Stream; /** * The default implementation of an {@link OAuth2AccessTokenResponseClient} @@ -65,25 +49,18 @@ import java.util.stream.Stream; * @see Section 4.4.2 Access Token Request (Client Credentials Grant) * @see Section 4.4.3 Access Token Response (Client Credentials Grant) */ -public class DefaultClientCredentialsTokenResponseClient implements OAuth2AccessTokenResponseClient { - private static final String INVALID_TOKEN_REQUEST_ERROR_CODE = "invalid_token_request"; - +public final class DefaultClientCredentialsTokenResponseClient implements OAuth2AccessTokenResponseClient { private static final String INVALID_TOKEN_RESPONSE_ERROR_CODE = "invalid_token_response"; - private static final String[] TOKEN_RESPONSE_PARAMETER_NAMES = { - OAuth2ParameterNames.ACCESS_TOKEN, - OAuth2ParameterNames.TOKEN_TYPE, - OAuth2ParameterNames.EXPIRES_IN, - OAuth2ParameterNames.SCOPE, - OAuth2ParameterNames.REFRESH_TOKEN - }; + private Converter> requestEntityConverter = + new OAuth2ClientCredentialsGrantRequestEntityConverter(); private RestOperations restOperations; public DefaultClientCredentialsTokenResponseClient() { - RestTemplate restTemplate = new RestTemplate(); - // Disable the ResponseErrorHandler as errors are handled directly within this class - restTemplate.setErrorHandler(new NoOpResponseErrorHandler()); + RestTemplate restTemplate = new RestTemplate(Arrays.asList( + new FormHttpMessageConverter(), new OAuth2AccessTokenResponseHttpMessageConverter())); + restTemplate.setErrorHandler(new OAuth2ErrorResponseErrorHandler()); this.restOperations = restTemplate; } @@ -93,50 +70,18 @@ public class DefaultClientCredentialsTokenResponseClient implements OAuth2Access Assert.notNull(clientCredentialsGrantRequest, "clientCredentialsGrantRequest cannot be null"); - // Build request - RequestEntity> request = this.buildRequest(clientCredentialsGrantRequest); + RequestEntity request = this.requestEntityConverter.convert(clientCredentialsGrantRequest); - // Exchange - ResponseEntity> response; + ResponseEntity response; try { - response = this.restOperations.exchange( - request, new ParameterizedTypeReference>() {}); - } catch (Exception ex) { - OAuth2Error oauth2Error = new OAuth2Error(INVALID_TOKEN_REQUEST_ERROR_CODE, - "An error occurred while sending the Access Token Request: " + ex.getMessage(), null); - throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString(), ex); - } - - Map responseParameters = response.getBody(); - - // Check for Error Response - if (response.getStatusCodeValue() != 200) { - OAuth2Error oauth2Error = this.parseErrorResponse(responseParameters); - if (oauth2Error == null) { - oauth2Error = new OAuth2Error(OAuth2ErrorCodes.SERVER_ERROR); - } - throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); - } - - // Success Response - OAuth2AccessTokenResponse tokenResponse; - try { - tokenResponse = this.parseTokenResponse(responseParameters); - } catch (Exception ex) { + response = this.restOperations.exchange(request, OAuth2AccessTokenResponse.class); + } catch (RestClientException ex) { OAuth2Error oauth2Error = new OAuth2Error(INVALID_TOKEN_RESPONSE_ERROR_CODE, - "An error occurred parsing the Access Token response (200 OK): " + ex.getMessage(), null); + "An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response: " + ex.getMessage(), null); throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString(), ex); } - if (tokenResponse == null) { - // This should never happen as long as the provider - // implements a Successful Response as defined in Section 5.1 - // https://tools.ietf.org/html/rfc6749#section-5.1 - OAuth2Error oauth2Error = new OAuth2Error(INVALID_TOKEN_RESPONSE_ERROR_CODE, - "An error occurred parsing the Access Token response (200 OK). " + - "Missing required parameters: access_token and/or token_type", null); - throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); - } + OAuth2AccessTokenResponse tokenResponse = response.getBody(); if (CollectionUtils.isEmpty(tokenResponse.getAccessToken().getScopes())) { // As per spec, in Section 5.1 Successful Access Token Response @@ -151,120 +96,31 @@ public class DefaultClientCredentialsTokenResponseClient implements OAuth2Access return tokenResponse; } - private RequestEntity> buildRequest(OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest) { - HttpHeaders headers = this.buildHeaders(clientCredentialsGrantRequest); - MultiValueMap formParameters = this.buildFormParameters(clientCredentialsGrantRequest); - URI uri = UriComponentsBuilder.fromUriString(clientCredentialsGrantRequest.getClientRegistration().getProviderDetails().getTokenUri()) - .build() - .toUri(); - - return new RequestEntity<>(formParameters, headers, HttpMethod.POST, uri); - } - - private HttpHeaders buildHeaders(OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest) { - ClientRegistration clientRegistration = clientCredentialsGrantRequest.getClientRegistration(); - - HttpHeaders headers = new HttpHeaders(); - headers.setAccept(Collections.singletonList(MediaType.APPLICATION_JSON)); - headers.setContentType(MediaType.APPLICATION_FORM_URLENCODED); - if (ClientAuthenticationMethod.BASIC.equals(clientRegistration.getClientAuthenticationMethod())) { - headers.setBasicAuth(clientRegistration.getClientId(), clientRegistration.getClientSecret()); - } - - return headers; - } - - private MultiValueMap buildFormParameters(OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest) { - ClientRegistration clientRegistration = clientCredentialsGrantRequest.getClientRegistration(); - - MultiValueMap formParameters = new LinkedMultiValueMap<>(); - formParameters.add(OAuth2ParameterNames.GRANT_TYPE, clientCredentialsGrantRequest.getGrantType().getValue()); - if (!CollectionUtils.isEmpty(clientRegistration.getScopes())) { - formParameters.add(OAuth2ParameterNames.SCOPE, - StringUtils.collectionToDelimitedString(clientRegistration.getScopes(), " ")); - } - if (ClientAuthenticationMethod.POST.equals(clientRegistration.getClientAuthenticationMethod())) { - formParameters.add(OAuth2ParameterNames.CLIENT_ID, clientRegistration.getClientId()); - formParameters.add(OAuth2ParameterNames.CLIENT_SECRET, clientRegistration.getClientSecret()); - } - - return formParameters; - } - - private OAuth2Error parseErrorResponse(Map responseParameters) { - if (CollectionUtils.isEmpty(responseParameters) || - !responseParameters.containsKey(OAuth2ParameterNames.ERROR)) { - return null; - } - - String errorCode = responseParameters.get(OAuth2ParameterNames.ERROR); - String errorDescription = responseParameters.get(OAuth2ParameterNames.ERROR_DESCRIPTION); - String errorUri = responseParameters.get(OAuth2ParameterNames.ERROR_URI); - - return new OAuth2Error(errorCode, errorDescription, errorUri); - } - - private OAuth2AccessTokenResponse parseTokenResponse(Map responseParameters) { - if (CollectionUtils.isEmpty(responseParameters) || - !responseParameters.containsKey(OAuth2ParameterNames.ACCESS_TOKEN) || - !responseParameters.containsKey(OAuth2ParameterNames.TOKEN_TYPE)) { - return null; - } - - String accessToken = responseParameters.get(OAuth2ParameterNames.ACCESS_TOKEN); - - OAuth2AccessToken.TokenType accessTokenType = null; - if (OAuth2AccessToken.TokenType.BEARER.getValue().equalsIgnoreCase( - responseParameters.get(OAuth2ParameterNames.TOKEN_TYPE))) { - accessTokenType = OAuth2AccessToken.TokenType.BEARER; - } - - long expiresIn = 0; - if (responseParameters.containsKey(OAuth2ParameterNames.EXPIRES_IN)) { - try { - expiresIn = Long.valueOf(responseParameters.get(OAuth2ParameterNames.EXPIRES_IN)); - } catch (NumberFormatException ex) { } - } - - Set scopes = Collections.emptySet(); - if (responseParameters.containsKey(OAuth2ParameterNames.SCOPE)) { - String scope = responseParameters.get(OAuth2ParameterNames.SCOPE); - scopes = Arrays.stream(StringUtils.delimitedListToStringArray(scope, " ")).collect(Collectors.toSet()); - } - - Map additionalParameters = new LinkedHashMap<>(); - Set tokenResponseParameterNames = Stream.of(TOKEN_RESPONSE_PARAMETER_NAMES).collect(Collectors.toSet()); - responseParameters.entrySet().stream() - .filter(e -> !tokenResponseParameterNames.contains(e.getKey())) - .forEach(e -> additionalParameters.put(e.getKey(), e.getValue())); - - return OAuth2AccessTokenResponse.withToken(accessToken) - .tokenType(accessTokenType) - .expiresIn(expiresIn) - .scopes(scopes) - .additionalParameters(additionalParameters) - .build(); + /** + * Sets the {@link Converter} used for converting the {@link OAuth2ClientCredentialsGrantRequest} + * to a {@link RequestEntity} representation of the OAuth 2.0 Access Token Request. + * + * @param requestEntityConverter the {@link Converter} used for converting to a {@link RequestEntity} representation of the Access Token Request + */ + public void setRequestEntityConverter(Converter> requestEntityConverter) { + Assert.notNull(requestEntityConverter, "requestEntityConverter cannot be null"); + this.requestEntityConverter = requestEntityConverter; } /** - * Sets the {@link RestOperations} used when requesting the access token response. + * Sets the {@link RestOperations} used when requesting the OAuth 2.0 Access Token Response. + * + *

+ * NOTE: At a minimum, the supplied {@code restOperations} must be configured with the following: + *

    + *
  1. {@link HttpMessageConverter}'s - {@link FormHttpMessageConverter} and {@link OAuth2AccessTokenResponseHttpMessageConverter}
  2. + *
  3. {@link ResponseErrorHandler} - {@link OAuth2ErrorResponseErrorHandler}
  4. + *
* - * @param restOperations the {@link RestOperations} used when requesting the access token response + * @param restOperations the {@link RestOperations} used when requesting the Access Token Response */ - public final void setRestOperations(RestOperations restOperations) { + public void setRestOperations(RestOperations restOperations) { Assert.notNull(restOperations, "restOperations cannot be null"); this.restOperations = restOperations; } - - private static class NoOpResponseErrorHandler implements ResponseErrorHandler { - - @Override - public boolean hasError(ClientHttpResponse response) throws IOException { - return false; - } - - @Override - public void handleError(ClientHttpResponse response) throws IOException { - } - } } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2AuthorizationCodeGrantRequestEntityConverter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2AuthorizationCodeGrantRequestEntityConverter.java index db65cd9e95..c54045727d 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2AuthorizationCodeGrantRequestEntityConverter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2AuthorizationCodeGrantRequestEntityConverter.java @@ -18,9 +18,7 @@ package org.springframework.security.oauth2.client.endpoint; import org.springframework.core.convert.converter.Converter; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; -import org.springframework.http.MediaType; import org.springframework.http.RequestEntity; -import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange; @@ -30,9 +28,6 @@ import org.springframework.util.MultiValueMap; import org.springframework.web.util.UriComponentsBuilder; import java.net.URI; -import java.util.Collections; - -import static org.springframework.http.MediaType.APPLICATION_FORM_URLENCODED_VALUE; /** * A {@link Converter} that converts the provided {@link OAuth2AuthorizationCodeGrantRequest} @@ -57,7 +52,7 @@ public class OAuth2AuthorizationCodeGrantRequestEntityConverter implements Conve public RequestEntity convert(OAuth2AuthorizationCodeGrantRequest authorizationCodeGrantRequest) { ClientRegistration clientRegistration = authorizationCodeGrantRequest.getClientRegistration(); - HttpHeaders headers = this.buildHeaders(authorizationCodeGrantRequest); + HttpHeaders headers = OAuth2AuthorizationGrantRequestEntityUtils.getTokenRequestHeaders(clientRegistration); MultiValueMap formParameters = this.buildFormParameters(authorizationCodeGrantRequest); URI uri = UriComponentsBuilder.fromUriString(clientRegistration.getProviderDetails().getTokenUri()) .build() @@ -66,26 +61,6 @@ public class OAuth2AuthorizationCodeGrantRequestEntityConverter implements Conve return new RequestEntity<>(formParameters, headers, HttpMethod.POST, uri); } - /** - * Returns the {@link HttpHeaders} used for the Access Token Request. - * - * @param authorizationCodeGrantRequest the authorization code grant request - * @return the {@link HttpHeaders} used for the Access Token Request - */ - private HttpHeaders buildHeaders(OAuth2AuthorizationCodeGrantRequest authorizationCodeGrantRequest) { - ClientRegistration clientRegistration = authorizationCodeGrantRequest.getClientRegistration(); - - HttpHeaders headers = new HttpHeaders(); - headers.setAccept(Collections.singletonList(MediaType.APPLICATION_JSON_UTF8)); - final MediaType contentType = MediaType.valueOf(APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8"); - headers.setContentType(contentType); - if (ClientAuthenticationMethod.BASIC.equals(clientRegistration.getClientAuthenticationMethod())) { - headers.setBasicAuth(clientRegistration.getClientId(), clientRegistration.getClientSecret()); - } - - return headers; - } - /** * Returns a {@link MultiValueMap} of the form parameters used for the Access Token Request body. * diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2AuthorizationGrantRequestEntityUtils.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2AuthorizationGrantRequestEntityUtils.java new file mode 100644 index 0000000000..a9b53c9ce5 --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2AuthorizationGrantRequestEntityUtils.java @@ -0,0 +1,59 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client.endpoint; + +import org.springframework.core.convert.converter.Converter; +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; +import org.springframework.http.RequestEntity; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; + +import java.util.Collections; + +import static org.springframework.http.MediaType.APPLICATION_FORM_URLENCODED_VALUE; + +/** + * Utility methods used by the {@link Converter}'s that convert + * from an implementation of an {@link AbstractOAuth2AuthorizationGrantRequest} + * to a {@link RequestEntity} representation of an OAuth 2.0 Access Token Request + * for the specific Authorization Grant. + * + * @author Joe Grandja + * @since 5.1 + * @see OAuth2AuthorizationCodeGrantRequestEntityConverter + * @see OAuth2ClientCredentialsGrantRequestEntityConverter + */ +final class OAuth2AuthorizationGrantRequestEntityUtils { + private static HttpHeaders DEFAULT_TOKEN_REQUEST_HEADERS = getDefaultTokenRequestHeaders(); + + static HttpHeaders getTokenRequestHeaders(ClientRegistration clientRegistration) { + HttpHeaders headers = new HttpHeaders(); + headers.addAll(DEFAULT_TOKEN_REQUEST_HEADERS); + if (ClientAuthenticationMethod.BASIC.equals(clientRegistration.getClientAuthenticationMethod())) { + headers.setBasicAuth(clientRegistration.getClientId(), clientRegistration.getClientSecret()); + } + return headers; + } + + private static HttpHeaders getDefaultTokenRequestHeaders() { + HttpHeaders headers = new HttpHeaders(); + headers.setAccept(Collections.singletonList(MediaType.APPLICATION_JSON_UTF8)); + final MediaType contentType = MediaType.valueOf(APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8"); + headers.setContentType(contentType); + return headers; + } +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2ClientCredentialsGrantRequestEntityConverter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2ClientCredentialsGrantRequestEntityConverter.java new file mode 100644 index 0000000000..9ee2d68d17 --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2ClientCredentialsGrantRequestEntityConverter.java @@ -0,0 +1,87 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client.endpoint; + +import org.springframework.core.convert.converter.Converter; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.RequestEntity; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.util.CollectionUtils; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.util.StringUtils; +import org.springframework.web.util.UriComponentsBuilder; + +import java.net.URI; + +/** + * A {@link Converter} that converts the provided {@link OAuth2ClientCredentialsGrantRequest} + * to a {@link RequestEntity} representation of an OAuth 2.0 Access Token Request + * for the Client Credentials Grant. + * + * @author Joe Grandja + * @since 5.1 + * @see Converter + * @see OAuth2ClientCredentialsGrantRequest + * @see RequestEntity + */ +public class OAuth2ClientCredentialsGrantRequestEntityConverter implements Converter> { + + /** + * Returns the {@link RequestEntity} used for the Access Token Request. + * + * @param clientCredentialsGrantRequest the client credentials grant request + * @return the {@link RequestEntity} used for the Access Token Request + */ + @Override + public RequestEntity convert(OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest) { + ClientRegistration clientRegistration = clientCredentialsGrantRequest.getClientRegistration(); + + HttpHeaders headers = OAuth2AuthorizationGrantRequestEntityUtils.getTokenRequestHeaders(clientRegistration); + MultiValueMap formParameters = this.buildFormParameters(clientCredentialsGrantRequest); + URI uri = UriComponentsBuilder.fromUriString(clientRegistration.getProviderDetails().getTokenUri()) + .build() + .toUri(); + + return new RequestEntity<>(formParameters, headers, HttpMethod.POST, uri); + } + + /** + * Returns a {@link MultiValueMap} of the form parameters used for the Access Token Request body. + * + * @param clientCredentialsGrantRequest the client credentials grant request + * @return a {@link MultiValueMap} of the form parameters used for the Access Token Request body + */ + private MultiValueMap buildFormParameters(OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest) { + ClientRegistration clientRegistration = clientCredentialsGrantRequest.getClientRegistration(); + + MultiValueMap formParameters = new LinkedMultiValueMap<>(); + formParameters.add(OAuth2ParameterNames.GRANT_TYPE, clientCredentialsGrantRequest.getGrantType().getValue()); + if (!CollectionUtils.isEmpty(clientRegistration.getScopes())) { + formParameters.add(OAuth2ParameterNames.SCOPE, + StringUtils.collectionToDelimitedString(clientRegistration.getScopes(), " ")); + } + if (ClientAuthenticationMethod.POST.equals(clientRegistration.getClientAuthenticationMethod())) { + formParameters.add(OAuth2ParameterNames.CLIENT_ID, clientRegistration.getClientId()); + formParameters.add(OAuth2ParameterNames.CLIENT_SECRET, clientRegistration.getClientSecret()); + } + + return formParameters; + } +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/DefaultClientCredentialsTokenResponseClientTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/DefaultClientCredentialsTokenResponseClientTests.java index 117a17cb34..07a563d710 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/DefaultClientCredentialsTokenResponseClientTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/DefaultClientCredentialsTokenResponseClientTests.java @@ -50,9 +50,7 @@ public class DefaultClientCredentialsTokenResponseClientTests { public void setup() throws Exception { this.server = new MockWebServer(); this.server.start(); - String tokenUri = this.server.url("/oauth2/token").toString(); - this.clientRegistration = ClientRegistration.withRegistrationId("registration-1") .clientId("client-1") .clientSecret("secret") @@ -68,6 +66,12 @@ public class DefaultClientCredentialsTokenResponseClientTests { this.server.shutdown(); } + @Test + public void setRequestEntityConverterWhenConverterIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.tokenResponseClient.setRequestEntityConverter(null)) + .isInstanceOf(IllegalArgumentException.class); + } + @Test public void setRestOperationsWhenRestOperationsIsNullThenThrowIllegalArgumentException() { assertThatThrownBy(() -> this.tokenResponseClient.setRestOperations(null)) @@ -103,8 +107,8 @@ public class DefaultClientCredentialsTokenResponseClientTests { RecordedRequest recordedRequest = this.server.takeRequest(); assertThat(recordedRequest.getMethod()).isEqualTo(HttpMethod.POST.toString()); - assertThat(recordedRequest.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON.toString()); - assertThat(recordedRequest.getHeader(HttpHeaders.CONTENT_TYPE)).startsWith(MediaType.APPLICATION_FORM_URLENCODED.toString()); + assertThat(recordedRequest.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_UTF8_VALUE); + assertThat(recordedRequest.getHeader(HttpHeaders.CONTENT_TYPE)).isEqualTo(MediaType.APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8"); String formParameters = recordedRequest.getBody().readUtf8(); assertThat(formParameters).contains("grant_type=client_credentials"); @@ -178,7 +182,8 @@ public class DefaultClientCredentialsTokenResponseClientTests { assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(clientCredentialsGrantRequest)) .isInstanceOf(OAuth2AuthenticationException.class) - .hasMessageContaining("[invalid_token_response] An error occurred parsing the Access Token response (200 OK): tokenType cannot be null"); + .hasMessageContaining("[invalid_token_response] An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response") + .hasMessageContaining("tokenType cannot be null"); } @Test @@ -193,7 +198,8 @@ public class DefaultClientCredentialsTokenResponseClientTests { assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(clientCredentialsGrantRequest)) .isInstanceOf(OAuth2AuthenticationException.class) - .hasMessageContaining("[invalid_token_response] An error occurred parsing the Access Token response (200 OK). Missing required parameters: access_token and/or token_type"); + .hasMessageContaining("[invalid_token_response] An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response") + .hasMessageContaining("tokenType cannot be null"); } @Test @@ -231,21 +237,6 @@ public class DefaultClientCredentialsTokenResponseClientTests { assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactly("read", "write"); } - @Test - public void getTokenResponseWhenTokenUriMalformedThenThrowOAuth2AuthenticationException() { - String malformedTokenUri = "http:\\provider.com\\oauth2\\token"; - ClientRegistration clientRegistration = this.from(this.clientRegistration) - .tokenUri(malformedTokenUri) - .build(); - - OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = - new OAuth2ClientCredentialsGrantRequest(clientRegistration); - - assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(clientCredentialsGrantRequest)) - .isInstanceOf(OAuth2AuthenticationException.class) - .hasMessageContaining("[invalid_token_request] An error occurred while sending the Access Token Request:"); - } - @Test public void getTokenResponseWhenTokenUriInvalidThenThrowOAuth2AuthenticationException() { String invalidTokenUri = "http://invalid-provider.com/oauth2/token"; @@ -258,7 +249,7 @@ public class DefaultClientCredentialsTokenResponseClientTests { assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(clientCredentialsGrantRequest)) .isInstanceOf(OAuth2AuthenticationException.class) - .hasMessageContaining("[invalid_token_request] An error occurred while sending the Access Token Request:"); + .hasMessageContaining("[invalid_token_response] An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response"); } @Test @@ -278,7 +269,7 @@ public class DefaultClientCredentialsTokenResponseClientTests { assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(clientCredentialsGrantRequest)) .isInstanceOf(OAuth2AuthenticationException.class) - .hasMessageContaining("[invalid_token_request] An error occurred while sending the Access Token Request:"); + .hasMessageContaining("[invalid_token_response] An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response"); } @Test @@ -305,7 +296,7 @@ public class DefaultClientCredentialsTokenResponseClientTests { assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(clientCredentialsGrantRequest)) .isInstanceOf(OAuth2AuthenticationException.class) - .hasMessageContaining("[server_error]"); + .hasMessage("[invalid_token_response] An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response: 500 Server Error"); } private MockResponse jsonResponse(String json) { diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2ClientCredentialsGrantRequestEntityConverterTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2ClientCredentialsGrantRequestEntityConverterTests.java new file mode 100644 index 0000000000..5c69692093 --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2ClientCredentialsGrantRequestEntityConverterTests.java @@ -0,0 +1,77 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client.endpoint; + +import org.junit.Before; +import org.junit.Test; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.MediaType; +import org.springframework.http.RequestEntity; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.util.MultiValueMap; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.springframework.http.MediaType.APPLICATION_FORM_URLENCODED_VALUE; + +/** + * Tests for {@link OAuth2ClientCredentialsGrantRequestEntityConverter}. + * + * @author Joe Grandja + */ +public class OAuth2ClientCredentialsGrantRequestEntityConverterTests { + private OAuth2ClientCredentialsGrantRequestEntityConverter converter = new OAuth2ClientCredentialsGrantRequestEntityConverter(); + private OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest; + + @Before + public void setup() { + ClientRegistration clientRegistration = ClientRegistration.withRegistrationId("registration-1") + .clientId("client-1") + .clientSecret("secret") + .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) + .authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS) + .scope("read", "write") + .tokenUri("https://provider.com/oauth2/token") + .build(); + this.clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest(clientRegistration); + } + + @SuppressWarnings("unchecked") + @Test + public void convertWhenGrantRequestValidThenConverts() { + RequestEntity requestEntity = this.converter.convert(this.clientCredentialsGrantRequest); + + ClientRegistration clientRegistration = this.clientCredentialsGrantRequest.getClientRegistration(); + + assertThat(requestEntity.getMethod()).isEqualTo(HttpMethod.POST); + assertThat(requestEntity.getUrl().toASCIIString()).isEqualTo( + clientRegistration.getProviderDetails().getTokenUri()); + + HttpHeaders headers = requestEntity.getHeaders(); + assertThat(headers.getAccept()).contains(MediaType.APPLICATION_JSON_UTF8); + assertThat(headers.getContentType()).isEqualTo( + MediaType.valueOf(APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8")); + assertThat(headers.getFirst(HttpHeaders.AUTHORIZATION)).startsWith("Basic "); + + MultiValueMap formParameters = (MultiValueMap) requestEntity.getBody(); + assertThat(formParameters.getFirst(OAuth2ParameterNames.GRANT_TYPE)).isEqualTo( + AuthorizationGrantType.CLIENT_CREDENTIALS.getValue()); + assertThat(formParameters.getFirst(OAuth2ParameterNames.SCOPE)).isEqualTo("read write"); + } +}