diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusReactiveAuthorizationCodeTokenResponseClient.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusReactiveAuthorizationCodeTokenResponseClient.java index f92c6cf36f..1f7833087b 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusReactiveAuthorizationCodeTokenResponseClient.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusReactiveAuthorizationCodeTokenResponseClient.java @@ -15,38 +15,21 @@ */ package org.springframework.security.oauth2.client.endpoint; -import static org.springframework.web.reactive.function.client.ExchangeFilterFunctions.Credentials.basicAuthenticationCredentials; - -import java.util.LinkedHashMap; -import java.util.LinkedHashSet; -import java.util.Map; -import java.util.Set; - -import org.springframework.core.ParameterizedTypeReference; import org.springframework.http.MediaType; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.core.AuthorizationGrantType; -import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; -import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse; -import org.springframework.util.CollectionUtils; import org.springframework.web.reactive.function.BodyInserters; import org.springframework.web.reactive.function.client.ExchangeFilterFunctions; import org.springframework.web.reactive.function.client.WebClient; - -import com.nimbusds.oauth2.sdk.AccessTokenResponse; -import com.nimbusds.oauth2.sdk.ErrorObject; -import com.nimbusds.oauth2.sdk.ParseException; -import com.nimbusds.oauth2.sdk.TokenErrorResponse; -import com.nimbusds.oauth2.sdk.TokenResponse; -import com.nimbusds.oauth2.sdk.token.AccessToken; - -import net.minidev.json.JSONObject; import reactor.core.publisher.Mono; +import static org.springframework.security.oauth2.core.web.reactive.function.OAuth2BodyExtractors.oauth2AccessTokenResponse; +import static org.springframework.web.reactive.function.client.ExchangeFilterFunctions.Credentials.basicAuthenticationCredentials; + /** * An implementation of an {@link ReactiveOAuth2AccessTokenResponseClient} that "exchanges" * an authorization code credential for an access token credential @@ -65,8 +48,6 @@ import reactor.core.publisher.Mono; * @see Section 4.1.4 Access Token Response (Authorization Code Grant) */ public class NimbusReactiveAuthorizationCodeTokenResponseClient implements ReactiveOAuth2AccessTokenResponseClient { - private static final String INVALID_TOKEN_RESPONSE_ERROR_CODE = "invalid_token_response"; - private WebClient webClient = WebClient.builder() .filter(ExchangeFilterFunctions.basicAuthentication()) .build(); @@ -87,52 +68,15 @@ public class NimbusReactiveAuthorizationCodeTokenResponseClient implements React .accept(MediaType.APPLICATION_JSON) .attributes(basicAuthenticationCredentials(clientRegistration.getClientId(), clientRegistration.getClientSecret())) .body(body) - .retrieve() - .onStatus(s -> false, response -> { - throw new IllegalStateException("Disabled Status Handlers"); - }) - .bodyToMono(new ParameterizedTypeReference>() {}) - .map(json -> parse(json)) - .flatMap(tokenResponse -> accessTokenResponse(tokenResponse)) - .map(accessTokenResponse -> { - AccessToken accessToken = accessTokenResponse.getTokens().getAccessToken(); - OAuth2AccessToken.TokenType accessTokenType = null; - if (OAuth2AccessToken.TokenType.BEARER.getValue().equalsIgnoreCase( - accessToken.getType().getValue())) { - accessTokenType = OAuth2AccessToken.TokenType.BEARER; - } - long expiresIn = accessToken.getLifetime(); - - // As per spec, in section 5.1 Successful Access Token Response - // https://tools.ietf.org/html/rfc6749#section-5.1 - // If AccessTokenResponse.scope is empty, then default to the scope - // originally requested by the client in the Authorization Request - Set scopes; - if (CollectionUtils.isEmpty( - accessToken.getScope())) { - scopes = new LinkedHashSet<>( - authorizationExchange.getAuthorizationRequest().getScopes()); - } - else { - scopes = new LinkedHashSet<>( - accessToken.getScope().toStringList()); - } - - String refreshToken = null; - if (accessTokenResponse.getTokens().getRefreshToken() != null) { - refreshToken = accessTokenResponse.getTokens().getRefreshToken().getValue(); - } - - Map additionalParameters = new LinkedHashMap<>( - accessTokenResponse.getCustomParameters()); - - return OAuth2AccessTokenResponse.withToken(accessToken.getValue()) - .tokenType(accessTokenType) - .expiresIn(expiresIn) - .scopes(scopes) - .refreshToken(refreshToken) - .additionalParameters(additionalParameters) + .exchange() + .flatMap(response -> response.body(oauth2AccessTokenResponse())) + .map(response -> { + if (response.getAccessToken().getScopes().isEmpty()) { + response = OAuth2AccessTokenResponse.withResponse(response) + .scopes(authorizationExchange.getAuthorizationRequest().getScopes()) .build(); + } + return response; }); }); } @@ -148,30 +92,4 @@ public class NimbusReactiveAuthorizationCodeTokenResponseClient implements React } return body; } - - private static Mono accessTokenResponse(TokenResponse tokenResponse) { - if (tokenResponse.indicatesSuccess()) { - return Mono.just(tokenResponse) - .cast(AccessTokenResponse.class); - } - TokenErrorResponse tokenErrorResponse = (TokenErrorResponse) tokenResponse; - ErrorObject errorObject = tokenErrorResponse.getErrorObject(); - OAuth2Error oauth2Error = new OAuth2Error(errorObject.getCode(), - errorObject.getDescription(), (errorObject.getURI() != null ? - errorObject.getURI().toString() : - null)); - - return Mono.error(new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString())); - } - - private static TokenResponse parse(Map json) { - try { - return TokenResponse.parse(new JSONObject(json)); - } - catch (ParseException pe) { - OAuth2Error oauth2Error = new OAuth2Error(INVALID_TOKEN_RESPONSE_ERROR_CODE, - "An error occurred parsing the Access Token response: " + pe.getMessage(), null); - throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString(), pe); - } - } } diff --git a/oauth2/oauth2-core/spring-security-oauth2-core.gradle b/oauth2/oauth2-core/spring-security-oauth2-core.gradle index 0a477bf7c3..bc66851194 100644 --- a/oauth2/oauth2-core/spring-security-oauth2-core.gradle +++ b/oauth2/oauth2-core/spring-security-oauth2-core.gradle @@ -4,5 +4,9 @@ dependencies { compile project(':spring-security-core') compile springCoreDependency + optional 'com.fasterxml.jackson.core:jackson-databind' + optional 'com.nimbusds:oauth2-oidc-sdk' + optional 'org.springframework:spring-webflux' + testCompile powerMock2Dependencies } diff --git a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/web/reactive/function/OAuth2AccessTokenResponseBodyExtractor.java b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/web/reactive/function/OAuth2AccessTokenResponseBodyExtractor.java new file mode 100644 index 0000000000..a14287eb12 --- /dev/null +++ b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/web/reactive/function/OAuth2AccessTokenResponseBodyExtractor.java @@ -0,0 +1,113 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.core.web.reactive.function; + +import com.nimbusds.oauth2.sdk.AccessTokenResponse; +import com.nimbusds.oauth2.sdk.ErrorObject; +import com.nimbusds.oauth2.sdk.ParseException; +import com.nimbusds.oauth2.sdk.TokenErrorResponse; +import com.nimbusds.oauth2.sdk.TokenResponse; +import com.nimbusds.oauth2.sdk.token.AccessToken; +import net.minidev.json.JSONObject; +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.http.ReactiveHttpInputMessage; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.web.reactive.function.BodyExtractor; +import org.springframework.web.reactive.function.BodyExtractors; +import reactor.core.publisher.Mono; + +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; +import java.util.Map; +import java.util.Set; + +/** + * Provides a way to create an {@link OAuth2AccessTokenResponse} from a {@link ReactiveHttpInputMessage} + * @author Rob Winch + * @since 5.1 + */ +class OAuth2AccessTokenResponseBodyExtractor + implements BodyExtractor, ReactiveHttpInputMessage> { + + private static final String INVALID_TOKEN_RESPONSE_ERROR_CODE = "invalid_token_response"; + + OAuth2AccessTokenResponseBodyExtractor() {} + + @Override + public Mono extract(ReactiveHttpInputMessage inputMessage, + Context context) { + ParameterizedTypeReference> type = new ParameterizedTypeReference>() {}; + BodyExtractor>, ReactiveHttpInputMessage> delegate = BodyExtractors.toMono(type); + return delegate.extract(inputMessage, context) + .map(json -> parse(json)) + .flatMap(OAuth2AccessTokenResponseBodyExtractor::oauth2AccessTokenResponse) + .map(OAuth2AccessTokenResponseBodyExtractor::oauth2AccessTokenResponse); + } + + private static TokenResponse parse(Map json) { + try { + return TokenResponse.parse(new JSONObject(json)); + } + catch (ParseException pe) { + OAuth2Error oauth2Error = new OAuth2Error(INVALID_TOKEN_RESPONSE_ERROR_CODE, + "An error occurred parsing the Access Token response: " + pe.getMessage(), null); + throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString(), pe); + } + } + + private static Mono oauth2AccessTokenResponse(TokenResponse tokenResponse) { + if (tokenResponse.indicatesSuccess()) { + return Mono.just(tokenResponse) + .cast(AccessTokenResponse.class); + } + TokenErrorResponse tokenErrorResponse = (TokenErrorResponse) tokenResponse; + ErrorObject errorObject = tokenErrorResponse.getErrorObject(); + OAuth2Error oauth2Error = new OAuth2Error(errorObject.getCode(), + errorObject.getDescription(), (errorObject.getURI() != null ? + errorObject.getURI().toString() : + null)); + + return Mono.error(new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString())); + } + + private static OAuth2AccessTokenResponse oauth2AccessTokenResponse(AccessTokenResponse accessTokenResponse) { + AccessToken accessToken = accessTokenResponse.getTokens().getAccessToken(); + OAuth2AccessToken.TokenType accessTokenType = null; + if (OAuth2AccessToken.TokenType.BEARER.getValue() + .equalsIgnoreCase(accessToken.getType().getValue())) { + accessTokenType = OAuth2AccessToken.TokenType.BEARER; + } + long expiresIn = accessToken.getLifetime(); + + Set scopes = accessToken.getScope() == null ? + Collections.emptySet() : new LinkedHashSet<>(accessToken.getScope().toStringList()); + + String refreshToken = null; + if (accessTokenResponse.getTokens().getRefreshToken() != null) { + refreshToken = accessTokenResponse.getTokens().getRefreshToken().getValue(); + } + + Map additionalParameters = new LinkedHashMap<>(accessTokenResponse.getCustomParameters()); + + return OAuth2AccessTokenResponse.withToken(accessToken.getValue()).tokenType(accessTokenType).expiresIn(expiresIn).scopes(scopes) + .refreshToken(refreshToken).additionalParameters(additionalParameters).build(); + } +} diff --git a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/web/reactive/function/OAuth2BodyExtractors.java b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/web/reactive/function/OAuth2BodyExtractors.java new file mode 100644 index 0000000000..fbffe082eb --- /dev/null +++ b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/web/reactive/function/OAuth2BodyExtractors.java @@ -0,0 +1,40 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.core.web.reactive.function; + +import org.springframework.http.ReactiveHttpInputMessage; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.web.reactive.function.BodyExtractor; +import reactor.core.publisher.Mono; + +/** + * Static factory methods for OAuth2 {@link BodyExtractor} implementations. + * @author Rob Winch + * @since 5.1 + */ +public abstract class OAuth2BodyExtractors { + + /** + * Extractor to decode an {@link OAuth2AccessTokenResponse} + * @return a BodyExtractor for {@link OAuth2AccessTokenResponse} + */ + public static BodyExtractor, ReactiveHttpInputMessage> oauth2AccessTokenResponse() { + return new OAuth2AccessTokenResponseBodyExtractor(); + } + + private OAuth2BodyExtractors() {} +} diff --git a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/web/reactive/function/OAuth2BodyExtractorsTests.java b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/web/reactive/function/OAuth2BodyExtractorsTests.java new file mode 100644 index 0000000000..0e46b5d1da --- /dev/null +++ b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/web/reactive/function/OAuth2BodyExtractorsTests.java @@ -0,0 +1,125 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.core.web.reactive.function; + +import org.junit.Before; +import org.junit.Test; +import org.springframework.core.codec.ByteBufferDecoder; +import org.springframework.core.codec.StringDecoder; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.ReactiveHttpInputMessage; +import org.springframework.http.codec.DecoderHttpMessageReader; +import org.springframework.http.codec.FormHttpMessageReader; +import org.springframework.http.codec.HttpMessageReader; +import org.springframework.http.codec.json.Jackson2JsonDecoder; +import org.springframework.http.codec.xml.Jaxb2XmlDecoder; +import org.springframework.http.server.reactive.ServerHttpResponse; +import org.springframework.mock.http.client.reactive.MockClientHttpResponse; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.web.reactive.function.BodyExtractor; +import reactor.core.publisher.Mono; + +import java.time.Instant; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; + +/** + * @author Rob Winch + * @since 5.1 + */ +public class OAuth2BodyExtractorsTests { + + private BodyExtractor.Context context; + + private Map hints; + + @Before + public void createContext() { + final List> messageReaders = new ArrayList<>(); + messageReaders.add(new DecoderHttpMessageReader<>(new ByteBufferDecoder())); + messageReaders.add(new DecoderHttpMessageReader<>(StringDecoder.allMimeTypes())); + messageReaders.add(new DecoderHttpMessageReader<>(new Jaxb2XmlDecoder())); + messageReaders.add(new DecoderHttpMessageReader<>(new Jackson2JsonDecoder())); + messageReaders.add(new FormHttpMessageReader()); + + this.hints = new HashMap(); + this.context = new BodyExtractor.Context() { + @Override + public List> messageReaders() { + return messageReaders; + } + + @Override + public Optional serverResponse() { + return Optional.empty(); + } + + @Override + public Map hints() { + return OAuth2BodyExtractorsTests.this.hints; + } + }; + } + + @Test + public void oauth2AccessTokenResponseWhenInvalidJsonThenException() { + BodyExtractor, ReactiveHttpInputMessage> extractor = OAuth2BodyExtractors + .oauth2AccessTokenResponse(); + + MockClientHttpResponse response = new MockClientHttpResponse(HttpStatus.OK); + response.getHeaders().setContentType(MediaType.APPLICATION_JSON); + response.setBody("{"); + + Mono result = extractor.extract(response, this.context); + + assertThatCode(() -> result.block()) + .isInstanceOf(RuntimeException.class); + } + + @Test + public void oauth2AccessTokenResponseWhenValidThenCreated() throws Exception { + BodyExtractor, ReactiveHttpInputMessage> extractor = OAuth2BodyExtractors + .oauth2AccessTokenResponse(); + + MockClientHttpResponse response = new MockClientHttpResponse(HttpStatus.OK); + response.getHeaders().setContentType(MediaType.APPLICATION_JSON); + response.setBody("{\n" + + " \"access_token\":\"2YotnFZFEjr1zCsicMWpAA\",\n" + + " \"token_type\":\"Bearer\",\n" + + " \"expires_in\":3600,\n" + + " \"refresh_token\":\"tGzv3JOkF0XG5Qx2TlKWIA\",\n" + + " \"example_parameter\":\"example_value\"\n" + + " }"); + + Instant now = Instant.now(); + OAuth2AccessTokenResponse result = extractor.extract(response, this.context).block(); + + assertThat(result.getAccessToken().getTokenValue()).isEqualTo("2YotnFZFEjr1zCsicMWpAA"); + assertThat(result.getAccessToken().getTokenType()).isEqualTo(OAuth2AccessToken.TokenType.BEARER); + assertThat(result.getAccessToken().getExpiresAt()).isBetween(now.plusSeconds(3600), now.plusSeconds(3600 + 2)); + assertThat(result.getRefreshToken().getTokenValue()).isEqualTo("tGzv3JOkF0XG5Qx2TlKWIA"); + assertThat(result.getAdditionalParameters()).containsEntry("example_parameter", "example_value"); + } +}