Browse Source
Implement R2dbcReactiveOuath2AuthorizedClientService which persists the Oauth2AuthorizedClient in a sql database R2dbcReactiveOuath2AuthorizedClientService is using the spring-r2dbc module to persist/load Oauth2AuthorizedClient to/from a sql database Add optional depedency to the spring-r2dbc module Add test compile dependencies to r2dbc-h2 and r2dbc-test Closes gh-7890pull/9147/head
3 changed files with 771 additions and 0 deletions
@ -0,0 +1,387 @@
@@ -0,0 +1,387 @@
|
||||
/* |
||||
* Copyright 2002-2020 the original author or authors. |
||||
* |
||||
* Licensed under the Apache License, Version 2.0 (the "License"); |
||||
* you may not use this file except in compliance with the License. |
||||
* You may obtain a copy of the License at |
||||
* |
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
* |
||||
* Unless required by applicable law or agreed to in writing, software |
||||
* distributed under the License is distributed on an "AS IS" BASIS, |
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
* See the License for the specific language governing permissions and |
||||
* limitations under the License. |
||||
*/ |
||||
|
||||
package org.springframework.security.oauth2.client; |
||||
|
||||
import java.nio.ByteBuffer; |
||||
import java.nio.charset.StandardCharsets; |
||||
import java.time.Instant; |
||||
import java.time.LocalDateTime; |
||||
import java.time.ZoneOffset; |
||||
import java.util.Collections; |
||||
import java.util.HashMap; |
||||
import java.util.Map; |
||||
import java.util.Map.Entry; |
||||
import java.util.Set; |
||||
import java.util.function.BiFunction; |
||||
import java.util.function.Function; |
||||
|
||||
import io.r2dbc.spi.Row; |
||||
import io.r2dbc.spi.RowMetadata; |
||||
import reactor.core.publisher.Mono; |
||||
|
||||
import org.springframework.dao.DataRetrievalFailureException; |
||||
import org.springframework.r2dbc.core.DatabaseClient; |
||||
import org.springframework.r2dbc.core.DatabaseClient.GenericExecuteSpec; |
||||
import org.springframework.r2dbc.core.Parameter; |
||||
import org.springframework.security.core.Authentication; |
||||
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; |
||||
import org.springframework.security.oauth2.core.OAuth2AccessToken; |
||||
import org.springframework.security.oauth2.core.OAuth2RefreshToken; |
||||
import org.springframework.util.Assert; |
||||
import org.springframework.util.CollectionUtils; |
||||
import org.springframework.util.StringUtils; |
||||
|
||||
/** |
||||
* A R2DBC implementation of {@link ReactiveOAuth2AuthorizedClientService} that uses a |
||||
* {@link DatabaseClient} for {@link OAuth2AuthorizedClient} persistence. |
||||
* |
||||
* <p> |
||||
* <b>NOTE:</b> This {@code ReactiveOAuth2AuthorizedClientService} depends on the table |
||||
* definition described in |
||||
* "classpath:org/springframework/security/oauth2/client/oauth2-client-schema.sql" and |
||||
* therefore MUST be defined in the database schema. |
||||
* |
||||
* @author Ovidiu Popa |
||||
* @since 5.5 |
||||
* @see ReactiveOAuth2AuthorizedClientService |
||||
* @see OAuth2AuthorizedClient |
||||
* @see DatabaseClient |
||||
* |
||||
*/ |
||||
public class R2dbcReactiveOAuth2AuthorizedClientService implements ReactiveOAuth2AuthorizedClientService { |
||||
|
||||
// @formatter:off
|
||||
private static final String COLUMN_NAMES = |
||||
"client_registration_id, " + |
||||
"principal_name, " + |
||||
"access_token_type, " + |
||||
"access_token_value, " + |
||||
"access_token_issued_at, " + |
||||
"access_token_expires_at, " + |
||||
"access_token_scopes, " + |
||||
"refresh_token_value, " + |
||||
"refresh_token_issued_at"; |
||||
// @formatter:on
|
||||
|
||||
private static final String TABLE_NAME = "oauth2_authorized_client"; |
||||
|
||||
private static final String PK_FILTER = "client_registration_id = :clientRegistrationId AND principal_name = :principalName"; |
||||
|
||||
// @formatter:off
|
||||
private static final String LOAD_AUTHORIZED_CLIENT_SQL = "SELECT " + COLUMN_NAMES + " FROM " + TABLE_NAME |
||||
+ " WHERE " + PK_FILTER; |
||||
// @formatter:on
|
||||
|
||||
// @formatter:off
|
||||
private static final String SAVE_AUTHORIZED_CLIENT_SQL = "INSERT INTO " + TABLE_NAME + " (" + COLUMN_NAMES + ")" + |
||||
"VALUES (:clientRegistrationId, :principalName, :accessTokenType, :accessTokenValue," + |
||||
" :accessTokenIssuedAt, :accessTokenExpiresAt, :accessTokenScopes, :refreshTokenValue," + |
||||
" :refreshTokenIssuedAt)"; |
||||
// @formatter:on
|
||||
|
||||
private static final String REMOVE_AUTHORIZED_CLIENT_SQL = "DELETE FROM " + TABLE_NAME + " WHERE " + PK_FILTER; |
||||
|
||||
// @formatter:off
|
||||
private static final String UPDATE_AUTHORIZED_CLIENT_SQL = "UPDATE " + TABLE_NAME + |
||||
" SET access_token_type = :accessTokenType, " + |
||||
" access_token_value = :accessTokenValue, " + |
||||
" access_token_issued_at = :accessTokenIssuedAt," + |
||||
" access_token_expires_at = :accessTokenExpiresAt, " + |
||||
" access_token_scopes = :accessTokenScopes," + |
||||
" refresh_token_value = :refreshTokenValue, " + |
||||
" refresh_token_issued_at = :refreshTokenIssuedAt" + |
||||
" WHERE " + |
||||
PK_FILTER; |
||||
// @formatter:on
|
||||
|
||||
protected final DatabaseClient databaseClient; |
||||
|
||||
protected final ReactiveClientRegistrationRepository clientRegistrationRepository; |
||||
|
||||
protected Function<OAuth2AuthorizedClientHolder, Map<String, Parameter>> authorizedClientParametersMapper; |
||||
|
||||
protected BiFunction<Row, RowMetadata, OAuth2AuthorizedClientHolder> authorizedClientRowMapper; |
||||
|
||||
/** |
||||
* Constructs a {@code R2dbcReactiveOAuth2AuthorizedClientService} using the provided |
||||
* parameters. |
||||
* @param databaseClient the database client |
||||
* @param clientRegistrationRepository the repository of client registrations |
||||
*/ |
||||
public R2dbcReactiveOAuth2AuthorizedClientService(DatabaseClient databaseClient, |
||||
ReactiveClientRegistrationRepository clientRegistrationRepository) { |
||||
Assert.notNull(databaseClient, "databaseClient cannot be null"); |
||||
Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null"); |
||||
this.databaseClient = databaseClient; |
||||
this.clientRegistrationRepository = clientRegistrationRepository; |
||||
this.authorizedClientParametersMapper = new OAuth2AuthorizedClientParametersMapper(); |
||||
this.authorizedClientRowMapper = new OAuth2AuthorizedClientRowMapper(); |
||||
} |
||||
|
||||
@Override |
||||
@SuppressWarnings("unchecked") |
||||
public <T extends OAuth2AuthorizedClient> Mono<T> loadAuthorizedClient(String clientRegistrationId, |
||||
String principalName) { |
||||
Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty"); |
||||
Assert.hasText(principalName, "principalName cannot be empty"); |
||||
|
||||
return (Mono<T>) this.databaseClient.sql(LOAD_AUTHORIZED_CLIENT_SQL) |
||||
.bind("clientRegistrationId", clientRegistrationId).bind("principalName", principalName) |
||||
.map(this.authorizedClientRowMapper).first().flatMap(this::getAuthorizedClient); |
||||
} |
||||
|
||||
private Mono<OAuth2AuthorizedClient> getAuthorizedClient(OAuth2AuthorizedClientHolder authorizedClientHolder) { |
||||
return this.clientRegistrationRepository.findByRegistrationId(authorizedClientHolder.getClientRegistrationId()) |
||||
.switchIfEmpty( |
||||
Mono.error(dataRetrievalFailureException(authorizedClientHolder.getClientRegistrationId()))) |
||||
.map((clientRegistration) -> new OAuth2AuthorizedClient(clientRegistration, |
||||
authorizedClientHolder.getPrincipalName(), authorizedClientHolder.getAccessToken(), |
||||
authorizedClientHolder.getRefreshToken())); |
||||
} |
||||
|
||||
private static Throwable dataRetrievalFailureException(String clientRegistrationId) { |
||||
return new DataRetrievalFailureException("The ClientRegistration with id '" + clientRegistrationId |
||||
+ "' exists in the data source, however, it was not found in the ReactiveClientRegistrationRepository."); |
||||
} |
||||
|
||||
@Override |
||||
public Mono<Void> saveAuthorizedClient(OAuth2AuthorizedClient authorizedClient, Authentication principal) { |
||||
Assert.notNull(authorizedClient, "authorizedClient cannot be null"); |
||||
Assert.notNull(principal, "principal cannot be null"); |
||||
return this |
||||
.loadAuthorizedClient(authorizedClient.getClientRegistration().getRegistrationId(), principal.getName()) |
||||
.flatMap((dbAuthorizedClient) -> updateAuthorizedClient(authorizedClient, principal)) |
||||
.switchIfEmpty(Mono.defer(() -> insertAuthorizedClient(authorizedClient, principal))).then(); |
||||
} |
||||
|
||||
private Mono<Integer> updateAuthorizedClient(OAuth2AuthorizedClient authorizedClient, Authentication principal) { |
||||
GenericExecuteSpec executeSpec = this.databaseClient.sql(UPDATE_AUTHORIZED_CLIENT_SQL); |
||||
for (Entry<String, Parameter> entry : this.authorizedClientParametersMapper |
||||
.apply(new OAuth2AuthorizedClientHolder(authorizedClient, principal)).entrySet()) { |
||||
executeSpec = executeSpec.bind(entry.getKey(), entry.getValue()); |
||||
} |
||||
return executeSpec.fetch().rowsUpdated(); |
||||
} |
||||
|
||||
private Mono<Integer> insertAuthorizedClient(OAuth2AuthorizedClient authorizedClient, Authentication principal) { |
||||
GenericExecuteSpec executeSpec = this.databaseClient.sql(SAVE_AUTHORIZED_CLIENT_SQL); |
||||
for (Entry<String, Parameter> entry : this.authorizedClientParametersMapper |
||||
.apply(new OAuth2AuthorizedClientHolder(authorizedClient, principal)).entrySet()) { |
||||
executeSpec = executeSpec.bind(entry.getKey(), entry.getValue()); |
||||
} |
||||
return executeSpec.fetch().rowsUpdated(); |
||||
} |
||||
|
||||
@Override |
||||
public Mono<Void> removeAuthorizedClient(String clientRegistrationId, String principalName) { |
||||
Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty"); |
||||
Assert.hasText(principalName, "principalName cannot be empty"); |
||||
return this.databaseClient.sql(REMOVE_AUTHORIZED_CLIENT_SQL).bind("clientRegistrationId", clientRegistrationId) |
||||
.bind("principalName", principalName).then(); |
||||
} |
||||
|
||||
/** |
||||
* Sets the {@code Function} used for mapping {@link OAuth2AuthorizedClientHolder} to |
||||
* a {@code Map} of {@link String} and {@link Parameter}. The default is |
||||
* {@link OAuth2AuthorizedClientParametersMapper}. |
||||
* @param authorizedClientParametersMapper the {@code Function} used for mapping |
||||
* {@link OAuth2AuthorizedClientHolder} to a {@code Map} of {@link String} and |
||||
* {@link Parameter} |
||||
*/ |
||||
public final void setAuthorizedClientParametersMapper( |
||||
Function<OAuth2AuthorizedClientHolder, Map<String, Parameter>> authorizedClientParametersMapper) { |
||||
Assert.notNull(authorizedClientParametersMapper, "authorizedClientParametersMapper cannot be null"); |
||||
this.authorizedClientParametersMapper = authorizedClientParametersMapper; |
||||
} |
||||
|
||||
/** |
||||
* Sets the {@link BiFunction} used for mapping the current {@code io.r2dbc.spi.Row} |
||||
* to {@link OAuth2AuthorizedClientHolder}. The default is |
||||
* {@link OAuth2AuthorizedClientRowMapper}. |
||||
* @param authorizedClientRowMapper the {@link BiFunction} used for mapping the |
||||
* current {@code io.r2dbc.spi.Row} to {@link OAuth2AuthorizedClientHolder} |
||||
*/ |
||||
public final void setAuthorizedClientRowMapper( |
||||
BiFunction<Row, RowMetadata, OAuth2AuthorizedClientHolder> authorizedClientRowMapper) { |
||||
Assert.notNull(authorizedClientRowMapper, "authorizedClientRowMapper cannot be null"); |
||||
this.authorizedClientRowMapper = authorizedClientRowMapper; |
||||
} |
||||
|
||||
/** |
||||
* A holder for {@link OAuth2AuthorizedClient} data and End-User |
||||
* {@link Authentication} (Resource Owner). |
||||
*/ |
||||
public static final class OAuth2AuthorizedClientHolder { |
||||
|
||||
private final String clientRegistrationId; |
||||
|
||||
private final String principalName; |
||||
|
||||
private final OAuth2AccessToken accessToken; |
||||
|
||||
private final OAuth2RefreshToken refreshToken; |
||||
|
||||
/** |
||||
* Constructs an {@code OAuth2AuthorizedClientHolder} using the provided |
||||
* parameters. |
||||
* @param authorizedClient the authorized client |
||||
* @param principal the End-User {@link Authentication} (Resource Owner) |
||||
*/ |
||||
public OAuth2AuthorizedClientHolder(OAuth2AuthorizedClient authorizedClient, Authentication principal) { |
||||
Assert.notNull(authorizedClient, "authorizedClient cannot be null"); |
||||
Assert.notNull(principal, "principal cannot be null"); |
||||
this.clientRegistrationId = authorizedClient.getClientRegistration().getRegistrationId(); |
||||
this.principalName = principal.getName(); |
||||
this.accessToken = authorizedClient.getAccessToken(); |
||||
this.refreshToken = authorizedClient.getRefreshToken(); |
||||
} |
||||
|
||||
/** |
||||
* Constructs an {@code OAuth2AuthorizedClientHolder} using the provided |
||||
* parameters. |
||||
* @param clientRegistrationId the client registration id |
||||
* @param principalName the principal name of the End-User (Resource Owner) |
||||
* @param accessToken the access token |
||||
* @param refreshToken the refresh token |
||||
*/ |
||||
public OAuth2AuthorizedClientHolder(String clientRegistrationId, String principalName, |
||||
OAuth2AccessToken accessToken, OAuth2RefreshToken refreshToken) { |
||||
Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty"); |
||||
Assert.hasText(principalName, "principalName cannot be empty"); |
||||
Assert.notNull(accessToken, "accessToken cannot be null"); |
||||
this.clientRegistrationId = clientRegistrationId; |
||||
this.principalName = principalName; |
||||
this.accessToken = accessToken; |
||||
this.refreshToken = refreshToken; |
||||
} |
||||
|
||||
public String getClientRegistrationId() { |
||||
return this.clientRegistrationId; |
||||
} |
||||
|
||||
public String getPrincipalName() { |
||||
return this.principalName; |
||||
} |
||||
|
||||
public OAuth2AccessToken getAccessToken() { |
||||
return this.accessToken; |
||||
} |
||||
|
||||
public OAuth2RefreshToken getRefreshToken() { |
||||
return this.refreshToken; |
||||
} |
||||
|
||||
} |
||||
|
||||
/** |
||||
* The default {@code Function} that maps {@link OAuth2AuthorizedClientHolder} to a |
||||
* {@code Map} of {@link String} and {@link Parameter}. |
||||
*/ |
||||
public static class OAuth2AuthorizedClientParametersMapper |
||||
implements Function<OAuth2AuthorizedClientHolder, Map<String, Parameter>> { |
||||
|
||||
@Override |
||||
public Map<String, Parameter> apply(OAuth2AuthorizedClientHolder authorizedClientHolder) { |
||||
|
||||
final Map<String, Parameter> parameters = new HashMap<>(); |
||||
|
||||
final OAuth2AccessToken accessToken = authorizedClientHolder.getAccessToken(); |
||||
final OAuth2RefreshToken refreshToken = authorizedClientHolder.getRefreshToken(); |
||||
|
||||
parameters.put("clientRegistrationId", |
||||
Parameter.fromOrEmpty(authorizedClientHolder.getClientRegistrationId(), String.class)); |
||||
parameters.put("principalName", |
||||
Parameter.fromOrEmpty(authorizedClientHolder.getPrincipalName(), String.class)); |
||||
parameters.put("accessTokenType", |
||||
Parameter.fromOrEmpty(accessToken.getTokenType().getValue(), String.class)); |
||||
parameters.put("accessTokenValue", Parameter.fromOrEmpty( |
||||
ByteBuffer.wrap(accessToken.getTokenValue().getBytes(StandardCharsets.UTF_8)), ByteBuffer.class)); |
||||
parameters.put("accessTokenIssuedAt", Parameter.fromOrEmpty( |
||||
LocalDateTime.ofInstant(accessToken.getIssuedAt(), ZoneOffset.UTC), LocalDateTime.class)); |
||||
parameters.put("accessTokenExpiresAt", Parameter.fromOrEmpty( |
||||
LocalDateTime.ofInstant(accessToken.getExpiresAt(), ZoneOffset.UTC), LocalDateTime.class)); |
||||
String accessTokenScopes = null; |
||||
if (!CollectionUtils.isEmpty(accessToken.getScopes())) { |
||||
accessTokenScopes = StringUtils.collectionToDelimitedString(accessToken.getScopes(), ","); |
||||
|
||||
} |
||||
parameters.put("accessTokenScopes", Parameter.fromOrEmpty(accessTokenScopes, String.class)); |
||||
ByteBuffer refreshTokenValue = null; |
||||
LocalDateTime refreshTokenIssuedAt = null; |
||||
if (refreshToken != null) { |
||||
refreshTokenValue = ByteBuffer.wrap(refreshToken.getTokenValue().getBytes(StandardCharsets.UTF_8)); |
||||
if (refreshToken.getIssuedAt() != null) { |
||||
refreshTokenIssuedAt = LocalDateTime.ofInstant(refreshToken.getIssuedAt(), ZoneOffset.UTC); |
||||
} |
||||
|
||||
} |
||||
|
||||
parameters.put("refreshTokenValue", Parameter.fromOrEmpty(refreshTokenValue, ByteBuffer.class)); |
||||
parameters.put("refreshTokenIssuedAt", Parameter.fromOrEmpty(refreshTokenIssuedAt, LocalDateTime.class)); |
||||
return parameters; |
||||
} |
||||
|
||||
} |
||||
|
||||
/** |
||||
* The default {@link BiFunction} that maps the current {@code io.r2dbc.spi.Row} to a |
||||
* {@link OAuth2AuthorizedClientHolder}. |
||||
*/ |
||||
public static class OAuth2AuthorizedClientRowMapper |
||||
implements BiFunction<Row, RowMetadata, OAuth2AuthorizedClientHolder> { |
||||
|
||||
@Override |
||||
public OAuth2AuthorizedClientHolder apply(Row row, RowMetadata rowMetadata) { |
||||
|
||||
String dbClientRegistrationId = row.get("client_registration_id", String.class); |
||||
OAuth2AccessToken.TokenType tokenType = null; |
||||
if (OAuth2AccessToken.TokenType.BEARER.getValue() |
||||
.equalsIgnoreCase(row.get("access_token_type", String.class))) { |
||||
tokenType = OAuth2AccessToken.TokenType.BEARER; |
||||
} |
||||
String tokenValue = new String(row.get("access_token_value", ByteBuffer.class).array(), |
||||
StandardCharsets.UTF_8); |
||||
Instant issuedAt = row.get("access_token_issued_at", LocalDateTime.class).toInstant(ZoneOffset.UTC); |
||||
Instant expiresAt = row.get("access_token_expires_at", LocalDateTime.class).toInstant(ZoneOffset.UTC); |
||||
|
||||
Set<String> scopes = Collections.emptySet(); |
||||
String accessTokenScopes = row.get("access_token_scopes", String.class); |
||||
if (accessTokenScopes != null) { |
||||
scopes = StringUtils.commaDelimitedListToSet(accessTokenScopes); |
||||
} |
||||
final OAuth2AccessToken accessToken = new OAuth2AccessToken(tokenType, tokenValue, issuedAt, expiresAt, |
||||
scopes); |
||||
|
||||
OAuth2RefreshToken refreshToken = null; |
||||
ByteBuffer refreshTokenValue = row.get("refresh_token_value", ByteBuffer.class); |
||||
if (refreshTokenValue != null) { |
||||
tokenValue = new String(refreshTokenValue.array(), StandardCharsets.UTF_8); |
||||
issuedAt = null; |
||||
LocalDateTime refreshTokenIssuedAt = row.get("refresh_token_issued_at", LocalDateTime.class); |
||||
if (refreshTokenIssuedAt != null) { |
||||
issuedAt = refreshTokenIssuedAt.toInstant(ZoneOffset.UTC); |
||||
} |
||||
refreshToken = new OAuth2RefreshToken(tokenValue, issuedAt); |
||||
} |
||||
|
||||
String dbPrincipalName = row.get("principal_name", String.class); |
||||
return new OAuth2AuthorizedClientHolder(dbClientRegistrationId, dbPrincipalName, accessToken, refreshToken); |
||||
} |
||||
|
||||
} |
||||
|
||||
} |
||||
@ -0,0 +1,381 @@
@@ -0,0 +1,381 @@
|
||||
/* |
||||
* Copyright 2002-2020 the original author or authors. |
||||
* |
||||
* Licensed under the Apache License, Version 2.0 (the "License"); |
||||
* you may not use this file except in compliance with the License. |
||||
* You may obtain a copy of the License at |
||||
* |
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
* |
||||
* Unless required by applicable law or agreed to in writing, software |
||||
* distributed under the License is distributed on an "AS IS" BASIS, |
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
* See the License for the specific language governing permissions and |
||||
* limitations under the License. |
||||
*/ |
||||
|
||||
package org.springframework.security.oauth2.client; |
||||
|
||||
import io.r2dbc.h2.H2ConnectionFactory; |
||||
import io.r2dbc.spi.ConnectionFactory; |
||||
import io.r2dbc.spi.Result; |
||||
import org.junit.Before; |
||||
import org.junit.Test; |
||||
import reactor.core.publisher.Flux; |
||||
import reactor.core.publisher.Mono; |
||||
import reactor.test.StepVerifier; |
||||
|
||||
import org.springframework.core.io.ClassPathResource; |
||||
import org.springframework.dao.DataRetrievalFailureException; |
||||
import org.springframework.r2dbc.connection.init.CompositeDatabasePopulator; |
||||
import org.springframework.r2dbc.connection.init.ConnectionFactoryInitializer; |
||||
import org.springframework.r2dbc.connection.init.ResourceDatabasePopulator; |
||||
import org.springframework.r2dbc.core.DatabaseClient; |
||||
import org.springframework.security.authentication.TestingAuthenticationToken; |
||||
import org.springframework.security.core.Authentication; |
||||
import org.springframework.security.oauth2.client.registration.ClientRegistration; |
||||
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; |
||||
import org.springframework.security.oauth2.client.registration.TestClientRegistrations; |
||||
import org.springframework.security.oauth2.core.OAuth2AccessToken; |
||||
import org.springframework.security.oauth2.core.OAuth2RefreshToken; |
||||
import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; |
||||
import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens; |
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat; |
||||
import static org.assertj.core.api.Assertions.assertThatExceptionOfType; |
||||
import static org.mockito.ArgumentMatchers.any; |
||||
import static org.mockito.ArgumentMatchers.anyString; |
||||
import static org.mockito.BDDMockito.given; |
||||
import static org.mockito.Mockito.mock; |
||||
|
||||
/** |
||||
* Tests for {@link R2dbcReactiveOAuth2AuthorizedClientService} |
||||
* |
||||
* @author Ovidiu Popa |
||||
* |
||||
*/ |
||||
public class R2dbcReactiveOAuth2AuthorizedClientServiceTests { |
||||
|
||||
private static final String OAUTH2_CLIENT_SCHEMA_SQL_RESOURCE = "org/springframework/security/oauth2/client/oauth2-client-schema.sql"; |
||||
|
||||
private ClientRegistration clientRegistration; |
||||
|
||||
private ReactiveClientRegistrationRepository clientRegistrationRepository; |
||||
|
||||
private DatabaseClient databaseClient; |
||||
|
||||
private static int principalId = 1000; |
||||
|
||||
private R2dbcReactiveOAuth2AuthorizedClientService authorizedClientService; |
||||
|
||||
@Before |
||||
public void setUp() { |
||||
final ConnectionFactory connectionFactory = createDb(); |
||||
this.clientRegistration = TestClientRegistrations.clientRegistration().build(); |
||||
this.clientRegistrationRepository = mock(ReactiveClientRegistrationRepository.class); |
||||
given(this.clientRegistrationRepository.findByRegistrationId(anyString())) |
||||
.willReturn(Mono.just(this.clientRegistration)); |
||||
this.databaseClient = DatabaseClient.create(connectionFactory); |
||||
this.authorizedClientService = new R2dbcReactiveOAuth2AuthorizedClientService(this.databaseClient, |
||||
this.clientRegistrationRepository); |
||||
} |
||||
|
||||
@Test |
||||
public void constructorWhenDatabaseClientIsNullThenThrowIllegalArgumentException() { |
||||
assertThatExceptionOfType(IllegalArgumentException.class) |
||||
.isThrownBy( |
||||
() -> new R2dbcReactiveOAuth2AuthorizedClientService(null, this.clientRegistrationRepository)) |
||||
.withMessageContaining("databaseClient cannot be null"); |
||||
} |
||||
|
||||
@Test |
||||
public void constructorWhenClientRegistrationRepositoryIsNullThenThrowIllegalArgumentException() { |
||||
assertThatExceptionOfType(IllegalArgumentException.class) |
||||
.isThrownBy(() -> new R2dbcReactiveOAuth2AuthorizedClientService(this.databaseClient, null)) |
||||
.withMessageContaining("clientRegistrationRepository cannot be null"); |
||||
} |
||||
|
||||
@Test |
||||
public void loadAuthorizedClientWhenClientRegistrationIdIsNullThenThrowIllegalArgumentException() { |
||||
assertThatExceptionOfType(IllegalArgumentException.class) |
||||
.isThrownBy(() -> this.authorizedClientService.loadAuthorizedClient(null, "principalName")) |
||||
.withMessageContaining("clientRegistrationId cannot be empty"); |
||||
} |
||||
|
||||
@Test |
||||
public void loadAuthorizedClientWhenPrincipalNameIsNullThenThrowIllegalArgumentException() { |
||||
assertThatExceptionOfType(IllegalArgumentException.class) |
||||
.isThrownBy(() -> this.authorizedClientService |
||||
.loadAuthorizedClient(this.clientRegistration.getRegistrationId(), null)) |
||||
.withMessageContaining("principalName cannot be empty"); |
||||
} |
||||
|
||||
@Test |
||||
public void loadAuthorizedClientWhenDoesNotExistThenReturnNull() { |
||||
this.authorizedClientService.loadAuthorizedClient("registration-not-found", "principalName") |
||||
.as(StepVerifier::create).expectNextCount(0).verifyComplete(); |
||||
} |
||||
|
||||
@Test |
||||
public void loadAuthorizedClientWhenExistsThenReturnAuthorizedClient() { |
||||
Authentication principal = createPrincipal(); |
||||
OAuth2AuthorizedClient expected = createAuthorizedClient(principal, this.clientRegistration); |
||||
this.authorizedClientService.saveAuthorizedClient(expected, principal).as(StepVerifier::create) |
||||
.verifyComplete(); |
||||
|
||||
this.authorizedClientService |
||||
.loadAuthorizedClient(this.clientRegistration.getRegistrationId(), principal.getName()) |
||||
.as(StepVerifier::create).assertNext((authorizedClient) -> { |
||||
assertThat(authorizedClient).isNotNull(); |
||||
assertThat(authorizedClient.getClientRegistration()).isEqualTo(expected.getClientRegistration()); |
||||
assertThat(authorizedClient.getPrincipalName()).isEqualTo(expected.getPrincipalName()); |
||||
assertThat(authorizedClient.getAccessToken().getTokenType()) |
||||
.isEqualTo(expected.getAccessToken().getTokenType()); |
||||
assertThat(authorizedClient.getAccessToken().getTokenValue()) |
||||
.isEqualTo(expected.getAccessToken().getTokenValue()); |
||||
assertThat(authorizedClient.getAccessToken().getIssuedAt()) |
||||
.isEqualTo(expected.getAccessToken().getIssuedAt()); |
||||
assertThat(authorizedClient.getAccessToken().getExpiresAt()) |
||||
.isEqualTo(expected.getAccessToken().getExpiresAt()); |
||||
assertThat(authorizedClient.getAccessToken().getScopes()) |
||||
.isEqualTo(expected.getAccessToken().getScopes()); |
||||
assertThat(authorizedClient.getRefreshToken().getTokenValue()) |
||||
.isEqualTo(expected.getRefreshToken().getTokenValue()); |
||||
assertThat(authorizedClient.getRefreshToken().getIssuedAt()) |
||||
.isEqualTo(expected.getRefreshToken().getIssuedAt()); |
||||
}).verifyComplete(); |
||||
} |
||||
|
||||
@Test |
||||
public void loadAuthorizedClientWhenExistsButNotFoundInClientRegistrationRepositoryThenThrowDataRetrievalFailureException() { |
||||
given(this.clientRegistrationRepository.findByRegistrationId(any())).willReturn(Mono.empty()); |
||||
Authentication principal = createPrincipal(); |
||||
OAuth2AuthorizedClient expected = createAuthorizedClient(principal, this.clientRegistration); |
||||
|
||||
this.authorizedClientService.saveAuthorizedClient(expected, principal).as(StepVerifier::create) |
||||
.verifyComplete(); |
||||
|
||||
this.authorizedClientService |
||||
.loadAuthorizedClient(this.clientRegistration.getRegistrationId(), principal.getName()) |
||||
.as(StepVerifier::create) |
||||
.verifyErrorSatisfies((exception) -> assertThat(exception) |
||||
.isInstanceOf(DataRetrievalFailureException.class) |
||||
.hasMessage("The ClientRegistration with id '" + this.clientRegistration.getRegistrationId() |
||||
+ "' exists in the data source, however, it was not found in the ReactiveClientRegistrationRepository.")); |
||||
} |
||||
|
||||
@Test |
||||
public void saveAuthorizedClientWhenAuthorizedClientIsNullThenThrowIllegalArgumentException() { |
||||
Authentication principal = createPrincipal(); |
||||
|
||||
assertThatExceptionOfType(IllegalArgumentException.class) |
||||
.isThrownBy(() -> this.authorizedClientService.saveAuthorizedClient(null, principal)) |
||||
.withMessageContaining("authorizedClient cannot be null"); |
||||
} |
||||
|
||||
@Test |
||||
public void saveAuthorizedClientWhenPrincipalIsNullThenThrowIllegalArgumentException() { |
||||
Authentication principal = createPrincipal(); |
||||
OAuth2AuthorizedClient authorizedClient = createAuthorizedClient(principal, this.clientRegistration); |
||||
assertThatExceptionOfType(IllegalArgumentException.class) |
||||
.isThrownBy(() -> this.authorizedClientService.saveAuthorizedClient(authorizedClient, null)) |
||||
.withMessageContaining("principal cannot be null"); |
||||
} |
||||
|
||||
@Test |
||||
public void saveAuthorizedClientWhenSaveThenLoadReturnsSaved() { |
||||
Authentication principal = createPrincipal(); |
||||
final OAuth2AuthorizedClient expected = createAuthorizedClient(principal, this.clientRegistration); |
||||
|
||||
this.authorizedClientService.saveAuthorizedClient(expected, principal).as(StepVerifier::create) |
||||
.verifyComplete(); |
||||
|
||||
this.authorizedClientService |
||||
.loadAuthorizedClient(this.clientRegistration.getRegistrationId(), principal.getName()) |
||||
.as(StepVerifier::create).assertNext((authorizedClient) -> { |
||||
assertThat(authorizedClient).isNotNull(); |
||||
assertThat(authorizedClient.getClientRegistration()).isEqualTo(expected.getClientRegistration()); |
||||
assertThat(authorizedClient.getPrincipalName()).isEqualTo(expected.getPrincipalName()); |
||||
assertThat(authorizedClient.getAccessToken().getTokenType()) |
||||
.isEqualTo(expected.getAccessToken().getTokenType()); |
||||
assertThat(authorizedClient.getAccessToken().getTokenValue()) |
||||
.isEqualTo(expected.getAccessToken().getTokenValue()); |
||||
assertThat(authorizedClient.getAccessToken().getIssuedAt()) |
||||
.isEqualTo(expected.getAccessToken().getIssuedAt()); |
||||
assertThat(authorizedClient.getAccessToken().getExpiresAt()) |
||||
.isEqualTo(expected.getAccessToken().getExpiresAt()); |
||||
assertThat(authorizedClient.getAccessToken().getScopes()) |
||||
.isEqualTo(expected.getAccessToken().getScopes()); |
||||
assertThat(authorizedClient.getRefreshToken().getTokenValue()) |
||||
.isEqualTo(expected.getRefreshToken().getTokenValue()); |
||||
assertThat(authorizedClient.getRefreshToken().getIssuedAt()) |
||||
.isEqualTo(expected.getRefreshToken().getIssuedAt()); |
||||
}).verifyComplete(); |
||||
|
||||
// Test save/load of NOT NULL attributes only
|
||||
principal = createPrincipal(); |
||||
OAuth2AuthorizedClient updatedExpectedPrincipal = createAuthorizedClient(principal, this.clientRegistration, |
||||
true); |
||||
this.authorizedClientService.saveAuthorizedClient(updatedExpectedPrincipal, principal).as(StepVerifier::create) |
||||
.verifyComplete(); |
||||
|
||||
this.authorizedClientService |
||||
.loadAuthorizedClient(this.clientRegistration.getRegistrationId(), principal.getName()) |
||||
.as(StepVerifier::create).assertNext((authorizedClient) -> { |
||||
assertThat(authorizedClient).isNotNull(); |
||||
assertThat(authorizedClient.getClientRegistration()) |
||||
.isEqualTo(updatedExpectedPrincipal.getClientRegistration()); |
||||
assertThat(authorizedClient.getPrincipalName()) |
||||
.isEqualTo(updatedExpectedPrincipal.getPrincipalName()); |
||||
assertThat(authorizedClient.getAccessToken().getTokenType()) |
||||
.isEqualTo(updatedExpectedPrincipal.getAccessToken().getTokenType()); |
||||
assertThat(authorizedClient.getAccessToken().getTokenValue()) |
||||
.isEqualTo(updatedExpectedPrincipal.getAccessToken().getTokenValue()); |
||||
assertThat(authorizedClient.getAccessToken().getIssuedAt()) |
||||
.isEqualTo(updatedExpectedPrincipal.getAccessToken().getIssuedAt()); |
||||
assertThat(authorizedClient.getAccessToken().getExpiresAt()) |
||||
.isEqualTo(updatedExpectedPrincipal.getAccessToken().getExpiresAt()); |
||||
assertThat(authorizedClient.getAccessToken().getScopes()).isEmpty(); |
||||
assertThat(authorizedClient.getRefreshToken()).isNull(); |
||||
}).verifyComplete(); |
||||
} |
||||
|
||||
@Test |
||||
public void saveAuthorizedClientWhenSaveClientWithExistingPrimaryKeyThenUpdate() { |
||||
// Given a saved authorized client
|
||||
Authentication principal = createPrincipal(); |
||||
OAuth2AuthorizedClient authorizedClient = createAuthorizedClient(principal, this.clientRegistration); |
||||
this.authorizedClientService.saveAuthorizedClient(authorizedClient, principal).as(StepVerifier::create) |
||||
.verifyComplete(); |
||||
|
||||
// When a client with the same principal and registration id is saved
|
||||
OAuth2AuthorizedClient updatedAuthorizedClient = createAuthorizedClient(principal, this.clientRegistration); |
||||
this.authorizedClientService.saveAuthorizedClient(updatedAuthorizedClient, principal).as(StepVerifier::create) |
||||
.verifyComplete(); |
||||
|
||||
// Then the saved client is updated
|
||||
this.authorizedClientService |
||||
.loadAuthorizedClient(this.clientRegistration.getRegistrationId(), principal.getName()) |
||||
.as(StepVerifier::create).assertNext((savedClient) -> { |
||||
assertThat(savedClient).isNotNull(); |
||||
assertThat(savedClient.getClientRegistration()) |
||||
.isEqualTo(updatedAuthorizedClient.getClientRegistration()); |
||||
assertThat(savedClient.getPrincipalName()).isEqualTo(updatedAuthorizedClient.getPrincipalName()); |
||||
assertThat(savedClient.getAccessToken().getTokenType()) |
||||
.isEqualTo(updatedAuthorizedClient.getAccessToken().getTokenType()); |
||||
assertThat(savedClient.getAccessToken().getTokenValue()) |
||||
.isEqualTo(updatedAuthorizedClient.getAccessToken().getTokenValue()); |
||||
assertThat(savedClient.getAccessToken().getIssuedAt()) |
||||
.isEqualTo(updatedAuthorizedClient.getAccessToken().getIssuedAt()); |
||||
assertThat(savedClient.getAccessToken().getExpiresAt()) |
||||
.isEqualTo(updatedAuthorizedClient.getAccessToken().getExpiresAt()); |
||||
assertThat(savedClient.getAccessToken().getScopes()) |
||||
.isEqualTo(updatedAuthorizedClient.getAccessToken().getScopes()); |
||||
assertThat(savedClient.getRefreshToken().getTokenValue()) |
||||
.isEqualTo(updatedAuthorizedClient.getRefreshToken().getTokenValue()); |
||||
assertThat(savedClient.getRefreshToken().getIssuedAt()) |
||||
.isEqualTo(updatedAuthorizedClient.getRefreshToken().getIssuedAt()); |
||||
}); |
||||
} |
||||
|
||||
@Test |
||||
public void removeAuthorizedClientWhenClientRegistrationIdIsNullThenThrowIllegalArgumentException() { |
||||
assertThatExceptionOfType(IllegalArgumentException.class) |
||||
.isThrownBy(() -> this.authorizedClientService.removeAuthorizedClient(null, "principalName")) |
||||
.withMessageContaining("clientRegistrationId cannot be empty"); |
||||
} |
||||
|
||||
@Test |
||||
public void removeAuthorizedClientWhenPrincipalNameIsNullThenThrowIllegalArgumentException() { |
||||
assertThatExceptionOfType(IllegalArgumentException.class) |
||||
.isThrownBy(() -> this.authorizedClientService |
||||
.removeAuthorizedClient(this.clientRegistration.getRegistrationId(), null)) |
||||
.withMessageContaining("principalName cannot be empty"); |
||||
} |
||||
|
||||
@Test |
||||
public void removeAuthorizedClientWhenExistsThenRemoved() { |
||||
Authentication principal = createPrincipal(); |
||||
OAuth2AuthorizedClient authorizedClient = createAuthorizedClient(principal, this.clientRegistration); |
||||
|
||||
this.authorizedClientService.saveAuthorizedClient(authorizedClient, principal).as(StepVerifier::create) |
||||
.verifyComplete(); |
||||
|
||||
this.authorizedClientService |
||||
.loadAuthorizedClient(this.clientRegistration.getRegistrationId(), principal.getName()) |
||||
.as(StepVerifier::create).assertNext((dbAuthorizedClient) -> assertThat(dbAuthorizedClient).isNotNull()) |
||||
.verifyComplete(); |
||||
|
||||
this.authorizedClientService |
||||
.removeAuthorizedClient(this.clientRegistration.getRegistrationId(), principal.getName()) |
||||
.as(StepVerifier::create).verifyComplete(); |
||||
|
||||
this.authorizedClientService |
||||
.loadAuthorizedClient(this.clientRegistration.getRegistrationId(), principal.getName()) |
||||
.as(StepVerifier::create).expectNextCount(0).verifyComplete(); |
||||
} |
||||
|
||||
@Test |
||||
public void setAuthorizedClientRowMapperWhenNullThenThrowIllegalArgumentException() { |
||||
assertThatExceptionOfType(IllegalArgumentException.class) |
||||
.isThrownBy(() -> this.authorizedClientService.setAuthorizedClientRowMapper(null)) |
||||
.withMessageContaining("authorizedClientRowMapper cannot be nul"); |
||||
} |
||||
|
||||
@Test |
||||
public void setAuthorizedClientParametersMapperWhenNullThenThrowIllegalArgumentException() { |
||||
assertThatExceptionOfType(IllegalArgumentException.class) |
||||
.isThrownBy(() -> this.authorizedClientService.setAuthorizedClientParametersMapper(null)) |
||||
.withMessageContaining("authorizedClientParametersMapper cannot be nul"); |
||||
} |
||||
|
||||
private static ConnectionFactory createDb() { |
||||
ConnectionFactory connectionFactory = H2ConnectionFactory.inMemory("oauth-test"); |
||||
|
||||
Mono.from(connectionFactory.create()) |
||||
.flatMapMany((connection) -> Flux |
||||
.from(connection.createStatement("drop table oauth2_authorized_client").execute()) |
||||
.flatMap(Result::getRowsUpdated).onErrorResume((e) -> Mono.empty()) |
||||
.thenMany(connection.close())) |
||||
.as(StepVerifier::create).verifyComplete(); |
||||
ConnectionFactoryInitializer createDb = createDb(OAUTH2_CLIENT_SCHEMA_SQL_RESOURCE); |
||||
createDb.setConnectionFactory(connectionFactory); |
||||
createDb.afterPropertiesSet(); |
||||
return connectionFactory; |
||||
} |
||||
|
||||
private static ConnectionFactoryInitializer createDb(String schema) { |
||||
ConnectionFactoryInitializer initializer = new ConnectionFactoryInitializer(); |
||||
|
||||
CompositeDatabasePopulator populator = new CompositeDatabasePopulator(); |
||||
populator.addPopulators(new ResourceDatabasePopulator(new ClassPathResource(schema))); |
||||
initializer.setDatabasePopulator(populator); |
||||
return initializer; |
||||
} |
||||
|
||||
private static Authentication createPrincipal() { |
||||
return new TestingAuthenticationToken("principal-" + principalId++, "password"); |
||||
} |
||||
|
||||
private static OAuth2AuthorizedClient createAuthorizedClient(Authentication principal, |
||||
ClientRegistration clientRegistration) { |
||||
return createAuthorizedClient(principal, clientRegistration, false); |
||||
} |
||||
|
||||
private static OAuth2AuthorizedClient createAuthorizedClient(Authentication principal, |
||||
ClientRegistration clientRegistration, boolean requiredAttributesOnly) { |
||||
OAuth2AccessToken accessToken; |
||||
if (!requiredAttributesOnly) { |
||||
accessToken = TestOAuth2AccessTokens.scopes("read", "write"); |
||||
} |
||||
else { |
||||
accessToken = TestOAuth2AccessTokens.noScopes(); |
||||
} |
||||
OAuth2RefreshToken refreshToken = null; |
||||
if (!requiredAttributesOnly) { |
||||
refreshToken = TestOAuth2RefreshTokens.refreshToken(); |
||||
} |
||||
return new OAuth2AuthorizedClient(clientRegistration, principal.getName(), accessToken, refreshToken); |
||||
} |
||||
|
||||
} |
||||
Loading…
Reference in new issue