5 changed files with 815 additions and 0 deletions
@ -0,0 +1,312 @@
@@ -0,0 +1,312 @@
|
||||
/* |
||||
* Copyright 2002-2020 the original author or authors. |
||||
* |
||||
* Licensed under the Apache License, Version 2.0 (the "License"); |
||||
* you may not use this file except in compliance with the License. |
||||
* You may obtain a copy of the License at |
||||
* |
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
* |
||||
* Unless required by applicable law or agreed to in writing, software |
||||
* distributed under the License is distributed on an "AS IS" BASIS, |
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
* See the License for the specific language governing permissions and |
||||
* limitations under the License. |
||||
*/ |
||||
package org.springframework.security.oauth2.client; |
||||
|
||||
import org.springframework.dao.DataRetrievalFailureException; |
||||
import org.springframework.jdbc.core.ArgumentPreparedStatementSetter; |
||||
import org.springframework.jdbc.core.JdbcOperations; |
||||
import org.springframework.jdbc.core.PreparedStatementSetter; |
||||
import org.springframework.jdbc.core.RowMapper; |
||||
import org.springframework.jdbc.core.SqlParameterValue; |
||||
import org.springframework.security.core.Authentication; |
||||
import org.springframework.security.oauth2.client.registration.ClientRegistration; |
||||
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; |
||||
import org.springframework.security.oauth2.core.OAuth2AccessToken; |
||||
import org.springframework.security.oauth2.core.OAuth2RefreshToken; |
||||
import org.springframework.util.Assert; |
||||
import org.springframework.util.CollectionUtils; |
||||
import org.springframework.util.StringUtils; |
||||
|
||||
import java.nio.charset.StandardCharsets; |
||||
import java.sql.ResultSet; |
||||
import java.sql.SQLException; |
||||
import java.sql.Timestamp; |
||||
import java.sql.Types; |
||||
import java.time.Instant; |
||||
import java.util.ArrayList; |
||||
import java.util.Collections; |
||||
import java.util.List; |
||||
import java.util.Set; |
||||
import java.util.function.Function; |
||||
|
||||
/** |
||||
* A JDBC implementation of an {@link OAuth2AuthorizedClientService} |
||||
* that uses a {@link JdbcOperations} for {@link OAuth2AuthorizedClient} persistence. |
||||
* |
||||
* <p> |
||||
* <b>NOTE:</b> This {@code OAuth2AuthorizedClientService} depends on the table definition |
||||
* described in "classpath:org/springframework/security/oauth2/client/oauth2-client-schema.sql" |
||||
* and therefore MUST be defined in the database schema. |
||||
* |
||||
* @author Joe Grandja |
||||
* @since 5.3 |
||||
* @see OAuth2AuthorizedClientService |
||||
* @see OAuth2AuthorizedClient |
||||
* @see JdbcOperations |
||||
* @see RowMapper |
||||
*/ |
||||
public class JdbcOAuth2AuthorizedClientService implements OAuth2AuthorizedClientService { |
||||
private static final String COLUMN_NAMES = |
||||
"client_registration_id, " + |
||||
"principal_name, " + |
||||
"access_token_type, " + |
||||
"access_token_value, " + |
||||
"access_token_issued_at, " + |
||||
"access_token_expires_at, " + |
||||
"access_token_scopes, " + |
||||
"refresh_token_value, " + |
||||
"refresh_token_issued_at"; |
||||
private static final String TABLE_NAME = "oauth2_authorized_client"; |
||||
private static final String PK_FILTER = "client_registration_id = ? AND principal_name = ?"; |
||||
private static final String LOAD_AUTHORIZED_CLIENT_SQL = "SELECT " + COLUMN_NAMES + |
||||
" FROM " + TABLE_NAME + " WHERE " + PK_FILTER; |
||||
private static final String SAVE_AUTHORIZED_CLIENT_SQL = "INSERT INTO " + TABLE_NAME + |
||||
" (" + COLUMN_NAMES + ") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)"; |
||||
private static final String REMOVE_AUTHORIZED_CLIENT_SQL = "DELETE FROM " + TABLE_NAME + |
||||
" WHERE " + PK_FILTER; |
||||
protected final JdbcOperations jdbcOperations; |
||||
protected RowMapper<OAuth2AuthorizedClient> authorizedClientRowMapper; |
||||
protected Function<OAuth2AuthorizedClientHolder, List<SqlParameterValue>> authorizedClientParametersMapper; |
||||
|
||||
/** |
||||
* Constructs a {@code JdbcOAuth2AuthorizedClientService} using the provided parameters. |
||||
* |
||||
* @param jdbcOperations the JDBC operations |
||||
* @param clientRegistrationRepository the repository of client registrations |
||||
*/ |
||||
public JdbcOAuth2AuthorizedClientService( |
||||
JdbcOperations jdbcOperations, ClientRegistrationRepository clientRegistrationRepository) { |
||||
|
||||
Assert.notNull(jdbcOperations, "jdbcOperations cannot be null"); |
||||
Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null"); |
||||
this.jdbcOperations = jdbcOperations; |
||||
this.authorizedClientRowMapper = new OAuth2AuthorizedClientRowMapper(clientRegistrationRepository); |
||||
this.authorizedClientParametersMapper = new OAuth2AuthorizedClientParametersMapper(); |
||||
} |
||||
|
||||
@Override |
||||
@SuppressWarnings("unchecked") |
||||
public <T extends OAuth2AuthorizedClient> T loadAuthorizedClient(String clientRegistrationId, String principalName) { |
||||
Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty"); |
||||
Assert.hasText(principalName, "principalName cannot be empty"); |
||||
|
||||
SqlParameterValue[] parameters = new SqlParameterValue[] { |
||||
new SqlParameterValue(Types.VARCHAR, clientRegistrationId), |
||||
new SqlParameterValue(Types.VARCHAR, principalName) |
||||
}; |
||||
PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters); |
||||
|
||||
List<OAuth2AuthorizedClient> result = this.jdbcOperations.query( |
||||
LOAD_AUTHORIZED_CLIENT_SQL, pss, this.authorizedClientRowMapper); |
||||
|
||||
return !result.isEmpty() ? (T) result.get(0) : null; |
||||
} |
||||
|
||||
@Override |
||||
public void saveAuthorizedClient(OAuth2AuthorizedClient authorizedClient, Authentication principal) { |
||||
Assert.notNull(authorizedClient, "authorizedClient cannot be null"); |
||||
Assert.notNull(principal, "principal cannot be null"); |
||||
|
||||
List<SqlParameterValue> parameters = this.authorizedClientParametersMapper.apply( |
||||
new OAuth2AuthorizedClientHolder(authorizedClient, principal)); |
||||
PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters.toArray()); |
||||
|
||||
this.jdbcOperations.update(SAVE_AUTHORIZED_CLIENT_SQL, pss); |
||||
} |
||||
|
||||
@Override |
||||
public void removeAuthorizedClient(String clientRegistrationId, String principalName) { |
||||
Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty"); |
||||
Assert.hasText(principalName, "principalName cannot be empty"); |
||||
|
||||
SqlParameterValue[] parameters = new SqlParameterValue[] { |
||||
new SqlParameterValue(Types.VARCHAR, clientRegistrationId), |
||||
new SqlParameterValue(Types.VARCHAR, principalName) |
||||
}; |
||||
PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters); |
||||
|
||||
this.jdbcOperations.update(REMOVE_AUTHORIZED_CLIENT_SQL, pss); |
||||
} |
||||
|
||||
/** |
||||
* Sets the {@link RowMapper} used for mapping the current row in {@code java.sql.ResultSet} to {@link OAuth2AuthorizedClient}. |
||||
* The default is {@link OAuth2AuthorizedClientRowMapper}. |
||||
* |
||||
* @param authorizedClientRowMapper the {@link RowMapper} used for mapping the current row in {@code java.sql.ResultSet} to {@link OAuth2AuthorizedClient} |
||||
*/ |
||||
public final void setAuthorizedClientRowMapper(RowMapper<OAuth2AuthorizedClient> authorizedClientRowMapper) { |
||||
Assert.notNull(authorizedClientRowMapper, "authorizedClientRowMapper cannot be null"); |
||||
this.authorizedClientRowMapper = authorizedClientRowMapper; |
||||
} |
||||
|
||||
/** |
||||
* Sets the {@code Function} used for mapping {@link OAuth2AuthorizedClientHolder} to a {@code List} of {@link SqlParameterValue}. |
||||
* The default is {@link OAuth2AuthorizedClientParametersMapper}. |
||||
* |
||||
* @param authorizedClientParametersMapper the {@code Function} used for mapping {@link OAuth2AuthorizedClientHolder} to a {@code List} of {@link SqlParameterValue} |
||||
*/ |
||||
public final void setAuthorizedClientParametersMapper(Function<OAuth2AuthorizedClientHolder, List<SqlParameterValue>> authorizedClientParametersMapper) { |
||||
Assert.notNull(authorizedClientParametersMapper, "authorizedClientParametersMapper cannot be null"); |
||||
this.authorizedClientParametersMapper = authorizedClientParametersMapper; |
||||
} |
||||
|
||||
/** |
||||
* The default {@link RowMapper} that maps the current row |
||||
* in {@code java.sql.ResultSet} to {@link OAuth2AuthorizedClient}. |
||||
*/ |
||||
public static class OAuth2AuthorizedClientRowMapper implements RowMapper<OAuth2AuthorizedClient> { |
||||
protected final ClientRegistrationRepository clientRegistrationRepository; |
||||
|
||||
public OAuth2AuthorizedClientRowMapper(ClientRegistrationRepository clientRegistrationRepository) { |
||||
Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null"); |
||||
this.clientRegistrationRepository = clientRegistrationRepository; |
||||
} |
||||
|
||||
@Override |
||||
public OAuth2AuthorizedClient mapRow(ResultSet rs, int rowNum) throws SQLException { |
||||
String clientRegistrationId = rs.getString("client_registration_id"); |
||||
ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId( |
||||
clientRegistrationId); |
||||
if (clientRegistration == null) { |
||||
throw new DataRetrievalFailureException("The ClientRegistration with id '" + |
||||
clientRegistrationId + "' exists in the data source, " + |
||||
"however, it was not found in the ClientRegistrationRepository."); |
||||
} |
||||
|
||||
OAuth2AccessToken.TokenType tokenType = null; |
||||
if (OAuth2AccessToken.TokenType.BEARER.getValue().equalsIgnoreCase( |
||||
rs.getString("access_token_type"))) { |
||||
tokenType = OAuth2AccessToken.TokenType.BEARER; |
||||
} |
||||
String tokenValue = new String(rs.getBytes("access_token_value"), StandardCharsets.UTF_8); |
||||
Instant issuedAt = rs.getTimestamp("access_token_issued_at").toInstant(); |
||||
Instant expiresAt = rs.getTimestamp("access_token_expires_at").toInstant(); |
||||
Set<String> scopes = Collections.emptySet(); |
||||
String accessTokenScopes = rs.getString("access_token_scopes"); |
||||
if (accessTokenScopes != null) { |
||||
scopes = StringUtils.commaDelimitedListToSet(accessTokenScopes); |
||||
} |
||||
OAuth2AccessToken accessToken = new OAuth2AccessToken( |
||||
tokenType, tokenValue, issuedAt, expiresAt, scopes); |
||||
|
||||
OAuth2RefreshToken refreshToken = null; |
||||
byte[] refreshTokenValue = rs.getBytes("refresh_token_value"); |
||||
if (refreshTokenValue != null) { |
||||
tokenValue = new String(refreshTokenValue, StandardCharsets.UTF_8); |
||||
issuedAt = null; |
||||
Timestamp refreshTokenIssuedAt = rs.getTimestamp("refresh_token_issued_at"); |
||||
if (refreshTokenIssuedAt != null) { |
||||
issuedAt = refreshTokenIssuedAt.toInstant(); |
||||
} |
||||
refreshToken = new OAuth2RefreshToken(tokenValue, issuedAt); |
||||
} |
||||
|
||||
String principalName = rs.getString("principal_name"); |
||||
|
||||
return new OAuth2AuthorizedClient( |
||||
clientRegistration, principalName, accessToken, refreshToken); |
||||
} |
||||
} |
||||
|
||||
/** |
||||
* The default {@code Function} that maps {@link OAuth2AuthorizedClientHolder} |
||||
* to a {@code List} of {@link SqlParameterValue}. |
||||
*/ |
||||
public static class OAuth2AuthorizedClientParametersMapper implements Function<OAuth2AuthorizedClientHolder, List<SqlParameterValue>> { |
||||
|
||||
@Override |
||||
public List<SqlParameterValue> apply(OAuth2AuthorizedClientHolder authorizedClientHolder) { |
||||
OAuth2AuthorizedClient authorizedClient = authorizedClientHolder.getAuthorizedClient(); |
||||
Authentication principal = authorizedClientHolder.getPrincipal(); |
||||
ClientRegistration clientRegistration = authorizedClient.getClientRegistration(); |
||||
OAuth2AccessToken accessToken = authorizedClient.getAccessToken(); |
||||
OAuth2RefreshToken refreshToken = authorizedClient.getRefreshToken(); |
||||
|
||||
List<SqlParameterValue> parameters = new ArrayList<>(); |
||||
parameters.add(new SqlParameterValue( |
||||
Types.VARCHAR, clientRegistration.getRegistrationId())); |
||||
parameters.add(new SqlParameterValue( |
||||
Types.VARCHAR, principal.getName())); |
||||
parameters.add(new SqlParameterValue( |
||||
Types.VARCHAR, accessToken.getTokenType().getValue())); |
||||
parameters.add(new SqlParameterValue( |
||||
Types.BLOB, accessToken.getTokenValue().getBytes(StandardCharsets.UTF_8))); |
||||
parameters.add(new SqlParameterValue( |
||||
Types.TIMESTAMP, Timestamp.from(accessToken.getIssuedAt()))); |
||||
parameters.add(new SqlParameterValue( |
||||
Types.TIMESTAMP, Timestamp.from(accessToken.getExpiresAt()))); |
||||
String accessTokenScopes = null; |
||||
if (!CollectionUtils.isEmpty(accessToken.getScopes())) { |
||||
accessTokenScopes = StringUtils.collectionToDelimitedString(accessToken.getScopes(), ","); |
||||
} |
||||
parameters.add(new SqlParameterValue( |
||||
Types.VARCHAR, accessTokenScopes)); |
||||
byte[] refreshTokenValue = null; |
||||
Timestamp refreshTokenIssuedAt = null; |
||||
if (refreshToken != null) { |
||||
refreshTokenValue = refreshToken.getTokenValue().getBytes(StandardCharsets.UTF_8); |
||||
if (refreshToken.getIssuedAt() != null) { |
||||
refreshTokenIssuedAt = Timestamp.from(refreshToken.getIssuedAt()); |
||||
} |
||||
} |
||||
parameters.add(new SqlParameterValue( |
||||
Types.BLOB, refreshTokenValue)); |
||||
parameters.add(new SqlParameterValue( |
||||
Types.TIMESTAMP, refreshTokenIssuedAt)); |
||||
|
||||
return parameters; |
||||
} |
||||
} |
||||
|
||||
/** |
||||
* A holder for an {@link OAuth2AuthorizedClient} and End-User {@link Authentication} (Resource Owner). |
||||
*/ |
||||
public static final class OAuth2AuthorizedClientHolder { |
||||
private final OAuth2AuthorizedClient authorizedClient; |
||||
private final Authentication principal; |
||||
|
||||
/** |
||||
* Constructs an {@code OAuth2AuthorizedClientHolder} using the provided parameters. |
||||
* |
||||
* @param authorizedClient the authorized client |
||||
* @param principal the End-User {@link Authentication} (Resource Owner) |
||||
*/ |
||||
public OAuth2AuthorizedClientHolder(OAuth2AuthorizedClient authorizedClient, Authentication principal) { |
||||
Assert.notNull(authorizedClient, "authorizedClient cannot be null"); |
||||
Assert.notNull(principal, "principal cannot be null"); |
||||
this.authorizedClient = authorizedClient; |
||||
this.principal = principal; |
||||
} |
||||
|
||||
/** |
||||
* Returns the {@link OAuth2AuthorizedClient}. |
||||
* |
||||
* @return the {@link OAuth2AuthorizedClient} |
||||
*/ |
||||
public OAuth2AuthorizedClient getAuthorizedClient() { |
||||
return this.authorizedClient; |
||||
} |
||||
|
||||
/** |
||||
* Returns the End-User {@link Authentication} (Resource Owner). |
||||
* |
||||
* @return the End-User {@link Authentication} (Resource Owner) |
||||
*/ |
||||
public Authentication getPrincipal() { |
||||
return this.principal; |
||||
} |
||||
} |
||||
} |
||||
@ -0,0 +1,13 @@
@@ -0,0 +1,13 @@
|
||||
CREATE TABLE oauth2_authorized_client ( |
||||
client_registration_id varchar(100) NOT NULL, |
||||
principal_name varchar(200) NOT NULL, |
||||
access_token_type varchar(100) NOT NULL, |
||||
access_token_value blob NOT NULL, |
||||
access_token_issued_at timestamp NOT NULL, |
||||
access_token_expires_at timestamp NOT NULL, |
||||
access_token_scopes varchar(1000) DEFAULT NULL, |
||||
refresh_token_value blob DEFAULT NULL, |
||||
refresh_token_issued_at timestamp DEFAULT NULL, |
||||
created_at timestamp DEFAULT CURRENT_TIMESTAMP NOT NULL, |
||||
PRIMARY KEY (client_registration_id, principal_name) |
||||
); |
||||
@ -0,0 +1,474 @@
@@ -0,0 +1,474 @@
|
||||
/* |
||||
* Copyright 2002-2020 the original author or authors. |
||||
* |
||||
* Licensed under the Apache License, Version 2.0 (the "License"); |
||||
* you may not use this file except in compliance with the License. |
||||
* You may obtain a copy of the License at |
||||
* |
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
* |
||||
* Unless required by applicable law or agreed to in writing, software |
||||
* distributed under the License is distributed on an "AS IS" BASIS, |
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
* See the License for the specific language governing permissions and |
||||
* limitations under the License. |
||||
*/ |
||||
package org.springframework.security.oauth2.client; |
||||
|
||||
import org.junit.After; |
||||
import org.junit.Before; |
||||
import org.junit.Test; |
||||
import org.springframework.dao.DataRetrievalFailureException; |
||||
import org.springframework.dao.DuplicateKeyException; |
||||
import org.springframework.jdbc.core.ArgumentPreparedStatementSetter; |
||||
import org.springframework.jdbc.core.JdbcOperations; |
||||
import org.springframework.jdbc.core.JdbcTemplate; |
||||
import org.springframework.jdbc.core.PreparedStatementSetter; |
||||
import org.springframework.jdbc.core.RowMapper; |
||||
import org.springframework.jdbc.core.SqlParameterValue; |
||||
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabase; |
||||
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder; |
||||
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseType; |
||||
import org.springframework.security.authentication.TestingAuthenticationToken; |
||||
import org.springframework.security.core.Authentication; |
||||
import org.springframework.security.oauth2.client.registration.ClientRegistration; |
||||
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; |
||||
import org.springframework.security.oauth2.client.registration.TestClientRegistrations; |
||||
import org.springframework.security.oauth2.core.OAuth2AccessToken; |
||||
import org.springframework.security.oauth2.core.OAuth2RefreshToken; |
||||
import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; |
||||
import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens; |
||||
import org.springframework.util.Assert; |
||||
import org.springframework.util.StringUtils; |
||||
|
||||
import java.nio.charset.StandardCharsets; |
||||
import java.sql.ResultSet; |
||||
import java.sql.SQLException; |
||||
import java.sql.Timestamp; |
||||
import java.sql.Types; |
||||
import java.time.Instant; |
||||
import java.util.Collections; |
||||
import java.util.List; |
||||
import java.util.Set; |
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat; |
||||
import static org.assertj.core.api.Assertions.assertThatThrownBy; |
||||
import static org.mockito.ArgumentMatchers.any; |
||||
import static org.mockito.ArgumentMatchers.anyInt; |
||||
import static org.mockito.Mockito.mock; |
||||
import static org.mockito.Mockito.spy; |
||||
import static org.mockito.Mockito.verify; |
||||
import static org.mockito.Mockito.when; |
||||
|
||||
/** |
||||
* Tests for {@link JdbcOAuth2AuthorizedClientService}. |
||||
* |
||||
* @author Joe Grandja |
||||
*/ |
||||
public class JdbcOAuth2AuthorizedClientServiceTests { |
||||
private static final String OAUTH2_CLIENT_SCHEMA_SQL_RESOURCE = "org/springframework/security/oauth2/client/oauth2-client-schema.sql"; |
||||
private static int principalId = 1000; |
||||
private ClientRegistration clientRegistration; |
||||
private ClientRegistrationRepository clientRegistrationRepository; |
||||
private EmbeddedDatabase db; |
||||
private JdbcOperations jdbcOperations; |
||||
private JdbcOAuth2AuthorizedClientService authorizedClientService; |
||||
|
||||
@Before |
||||
public void setUp() { |
||||
this.clientRegistration = TestClientRegistrations.clientRegistration().build(); |
||||
this.clientRegistrationRepository = mock(ClientRegistrationRepository.class); |
||||
when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(this.clientRegistration); |
||||
this.db = createDb(); |
||||
this.jdbcOperations = new JdbcTemplate(this.db); |
||||
this.authorizedClientService = new JdbcOAuth2AuthorizedClientService( |
||||
this.jdbcOperations, this.clientRegistrationRepository); |
||||
} |
||||
|
||||
@After |
||||
public void tearDown() { |
||||
this.db.shutdown(); |
||||
} |
||||
|
||||
@Test |
||||
public void constructorWhenJdbcOperationsIsNullThenThrowIllegalArgumentException() { |
||||
assertThatThrownBy(() -> new JdbcOAuth2AuthorizedClientService(null, this.clientRegistrationRepository)) |
||||
.isInstanceOf(IllegalArgumentException.class) |
||||
.hasMessage("jdbcOperations cannot be null"); |
||||
} |
||||
|
||||
@Test |
||||
public void constructorWhenClientRegistrationRepositoryIsNullThenThrowIllegalArgumentException() { |
||||
assertThatThrownBy(() -> new JdbcOAuth2AuthorizedClientService(this.jdbcOperations, null)) |
||||
.isInstanceOf(IllegalArgumentException.class) |
||||
.hasMessage("clientRegistrationRepository cannot be null"); |
||||
} |
||||
|
||||
@Test |
||||
public void setAuthorizedClientRowMapperWhenNullThenThrowIllegalArgumentException() { |
||||
assertThatThrownBy(() -> this.authorizedClientService.setAuthorizedClientRowMapper(null)) |
||||
.isInstanceOf(IllegalArgumentException.class) |
||||
.hasMessage("authorizedClientRowMapper cannot be null"); |
||||
} |
||||
|
||||
@Test |
||||
public void setAuthorizedClientParametersMapperWhenNullThenThrowIllegalArgumentException() { |
||||
assertThatThrownBy(() -> this.authorizedClientService.setAuthorizedClientParametersMapper(null)) |
||||
.isInstanceOf(IllegalArgumentException.class) |
||||
.hasMessage("authorizedClientParametersMapper cannot be null"); |
||||
} |
||||
|
||||
@Test |
||||
public void loadAuthorizedClientWhenClientRegistrationIdIsNullThenThrowIllegalArgumentException() { |
||||
assertThatThrownBy(() -> this.authorizedClientService.loadAuthorizedClient(null, "principalName")) |
||||
.isInstanceOf(IllegalArgumentException.class) |
||||
.hasMessage("clientRegistrationId cannot be empty"); |
||||
} |
||||
|
||||
@Test |
||||
public void loadAuthorizedClientWhenPrincipalNameIsNullThenThrowIllegalArgumentException() { |
||||
assertThatThrownBy(() -> this.authorizedClientService.loadAuthorizedClient(this.clientRegistration.getRegistrationId(), null)) |
||||
.isInstanceOf(IllegalArgumentException.class) |
||||
.hasMessage("principalName cannot be empty"); |
||||
} |
||||
|
||||
@Test |
||||
public void loadAuthorizedClientWhenDoesNotExistThenReturnNull() { |
||||
OAuth2AuthorizedClient authorizedClient = this.authorizedClientService.loadAuthorizedClient( |
||||
"registration-not-found", "principalName"); |
||||
assertThat(authorizedClient).isNull(); |
||||
} |
||||
|
||||
@Test |
||||
public void loadAuthorizedClientWhenExistsThenReturnAuthorizedClient() { |
||||
Authentication principal = createPrincipal(); |
||||
OAuth2AuthorizedClient expected = createAuthorizedClient(principal, this.clientRegistration); |
||||
|
||||
this.authorizedClientService.saveAuthorizedClient(expected, principal); |
||||
|
||||
OAuth2AuthorizedClient authorizedClient = this.authorizedClientService.loadAuthorizedClient( |
||||
this.clientRegistration.getRegistrationId(), principal.getName()); |
||||
|
||||
assertThat(authorizedClient).isNotNull(); |
||||
assertThat(authorizedClient.getClientRegistration()).isEqualTo(expected.getClientRegistration()); |
||||
assertThat(authorizedClient.getPrincipalName()).isEqualTo(expected.getPrincipalName()); |
||||
assertThat(authorizedClient.getAccessToken().getTokenType()).isEqualTo(expected.getAccessToken().getTokenType()); |
||||
assertThat(authorizedClient.getAccessToken().getTokenValue()).isEqualTo(expected.getAccessToken().getTokenValue()); |
||||
assertThat(authorizedClient.getAccessToken().getIssuedAt()).isEqualTo(expected.getAccessToken().getIssuedAt()); |
||||
assertThat(authorizedClient.getAccessToken().getExpiresAt()).isEqualTo(expected.getAccessToken().getExpiresAt()); |
||||
assertThat(authorizedClient.getAccessToken().getScopes()).isEqualTo(expected.getAccessToken().getScopes()); |
||||
assertThat(authorizedClient.getRefreshToken().getTokenValue()).isEqualTo(expected.getRefreshToken().getTokenValue()); |
||||
assertThat(authorizedClient.getRefreshToken().getIssuedAt()).isEqualTo(expected.getRefreshToken().getIssuedAt()); |
||||
} |
||||
|
||||
@Test |
||||
public void loadAuthorizedClientWhenExistsButNotFoundInClientRegistrationRepositoryThenThrowDataRetrievalFailureException() { |
||||
when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(null); |
||||
Authentication principal = createPrincipal(); |
||||
OAuth2AuthorizedClient expected = createAuthorizedClient(principal, this.clientRegistration); |
||||
|
||||
this.authorizedClientService.saveAuthorizedClient(expected, principal); |
||||
|
||||
assertThatThrownBy(() -> this.authorizedClientService.loadAuthorizedClient(this.clientRegistration.getRegistrationId(), principal.getName())) |
||||
.isInstanceOf(DataRetrievalFailureException.class) |
||||
.hasMessage("The ClientRegistration with id '" + this.clientRegistration.getRegistrationId() + |
||||
"' exists in the data source, however, it was not found in the ClientRegistrationRepository."); |
||||
} |
||||
|
||||
@Test |
||||
public void saveAuthorizedClientWhenAuthorizedClientIsNullThenThrowIllegalArgumentException() { |
||||
Authentication principal = createPrincipal(); |
||||
|
||||
assertThatThrownBy(() -> this.authorizedClientService.saveAuthorizedClient(null, principal)) |
||||
.isInstanceOf(IllegalArgumentException.class) |
||||
.hasMessage("authorizedClient cannot be null"); |
||||
} |
||||
|
||||
@Test |
||||
public void saveAuthorizedClientWhenPrincipalIsNullThenThrowIllegalArgumentException() { |
||||
Authentication principal = createPrincipal(); |
||||
OAuth2AuthorizedClient authorizedClient = createAuthorizedClient(principal, this.clientRegistration); |
||||
|
||||
assertThatThrownBy(() -> this.authorizedClientService.saveAuthorizedClient(authorizedClient, null)) |
||||
.isInstanceOf(IllegalArgumentException.class) |
||||
.hasMessage("principal cannot be null"); |
||||
} |
||||
|
||||
@Test |
||||
public void saveAuthorizedClientWhenSaveThenLoadReturnsSaved() { |
||||
Authentication principal = createPrincipal(); |
||||
OAuth2AuthorizedClient expected = createAuthorizedClient(principal, this.clientRegistration); |
||||
|
||||
this.authorizedClientService.saveAuthorizedClient(expected, principal); |
||||
|
||||
OAuth2AuthorizedClient authorizedClient = this.authorizedClientService.loadAuthorizedClient( |
||||
this.clientRegistration.getRegistrationId(), principal.getName()); |
||||
|
||||
assertThat(authorizedClient).isNotNull(); |
||||
assertThat(authorizedClient.getClientRegistration()).isEqualTo(expected.getClientRegistration()); |
||||
assertThat(authorizedClient.getPrincipalName()).isEqualTo(expected.getPrincipalName()); |
||||
assertThat(authorizedClient.getAccessToken().getTokenType()).isEqualTo(expected.getAccessToken().getTokenType()); |
||||
assertThat(authorizedClient.getAccessToken().getTokenValue()).isEqualTo(expected.getAccessToken().getTokenValue()); |
||||
assertThat(authorizedClient.getAccessToken().getIssuedAt()).isEqualTo(expected.getAccessToken().getIssuedAt()); |
||||
assertThat(authorizedClient.getAccessToken().getExpiresAt()).isEqualTo(expected.getAccessToken().getExpiresAt()); |
||||
assertThat(authorizedClient.getAccessToken().getScopes()).isEqualTo(expected.getAccessToken().getScopes()); |
||||
assertThat(authorizedClient.getRefreshToken().getTokenValue()).isEqualTo(expected.getRefreshToken().getTokenValue()); |
||||
assertThat(authorizedClient.getRefreshToken().getIssuedAt()).isEqualTo(expected.getRefreshToken().getIssuedAt()); |
||||
|
||||
// Test save/load of NOT NULL attributes only
|
||||
principal = createPrincipal(); |
||||
expected = createAuthorizedClient(principal, this.clientRegistration, true); |
||||
|
||||
this.authorizedClientService.saveAuthorizedClient(expected, principal); |
||||
|
||||
authorizedClient = this.authorizedClientService.loadAuthorizedClient( |
||||
this.clientRegistration.getRegistrationId(), principal.getName()); |
||||
|
||||
assertThat(authorizedClient).isNotNull(); |
||||
assertThat(authorizedClient.getClientRegistration()).isEqualTo(expected.getClientRegistration()); |
||||
assertThat(authorizedClient.getPrincipalName()).isEqualTo(expected.getPrincipalName()); |
||||
assertThat(authorizedClient.getAccessToken().getTokenType()).isEqualTo(expected.getAccessToken().getTokenType()); |
||||
assertThat(authorizedClient.getAccessToken().getTokenValue()).isEqualTo(expected.getAccessToken().getTokenValue()); |
||||
assertThat(authorizedClient.getAccessToken().getIssuedAt()).isEqualTo(expected.getAccessToken().getIssuedAt()); |
||||
assertThat(authorizedClient.getAccessToken().getExpiresAt()).isEqualTo(expected.getAccessToken().getExpiresAt()); |
||||
assertThat(authorizedClient.getAccessToken().getScopes()).isEmpty(); |
||||
assertThat(authorizedClient.getRefreshToken()).isNull(); |
||||
} |
||||
|
||||
@Test |
||||
public void saveAuthorizedClientWhenSaveDuplicateThenThrowDuplicateKeyException() { |
||||
Authentication principal = createPrincipal(); |
||||
OAuth2AuthorizedClient authorizedClient = createAuthorizedClient(principal, this.clientRegistration); |
||||
|
||||
this.authorizedClientService.saveAuthorizedClient(authorizedClient, principal); |
||||
|
||||
assertThatThrownBy(() -> this.authorizedClientService.saveAuthorizedClient(authorizedClient, principal)) |
||||
.isInstanceOf(DuplicateKeyException.class); |
||||
} |
||||
|
||||
@Test |
||||
public void saveLoadAuthorizedClientWhenCustomStrategiesSetThenCalled() throws Exception { |
||||
JdbcOAuth2AuthorizedClientService.OAuth2AuthorizedClientRowMapper authorizedClientRowMapper = |
||||
spy(new JdbcOAuth2AuthorizedClientService.OAuth2AuthorizedClientRowMapper(this.clientRegistrationRepository)); |
||||
this.authorizedClientService.setAuthorizedClientRowMapper(authorizedClientRowMapper); |
||||
JdbcOAuth2AuthorizedClientService.OAuth2AuthorizedClientParametersMapper authorizedClientParametersMapper = |
||||
spy(new JdbcOAuth2AuthorizedClientService.OAuth2AuthorizedClientParametersMapper()); |
||||
this.authorizedClientService.setAuthorizedClientParametersMapper(authorizedClientParametersMapper); |
||||
|
||||
Authentication principal = createPrincipal(); |
||||
OAuth2AuthorizedClient authorizedClient = createAuthorizedClient(principal, this.clientRegistration); |
||||
|
||||
this.authorizedClientService.saveAuthorizedClient(authorizedClient, principal); |
||||
this.authorizedClientService.loadAuthorizedClient( |
||||
this.clientRegistration.getRegistrationId(), principal.getName()); |
||||
|
||||
verify(authorizedClientRowMapper).mapRow(any(), anyInt()); |
||||
verify(authorizedClientParametersMapper).apply(any()); |
||||
} |
||||
|
||||
@Test |
||||
public void removeAuthorizedClientWhenClientRegistrationIdIsNullThenThrowIllegalArgumentException() { |
||||
assertThatThrownBy(() -> this.authorizedClientService.removeAuthorizedClient(null, "principalName")) |
||||
.isInstanceOf(IllegalArgumentException.class) |
||||
.hasMessage("clientRegistrationId cannot be empty"); |
||||
} |
||||
|
||||
@Test |
||||
public void removeAuthorizedClientWhenPrincipalNameIsNullThenThrowIllegalArgumentException() { |
||||
assertThatThrownBy(() -> this.authorizedClientService.removeAuthorizedClient(this.clientRegistration.getRegistrationId(), null)) |
||||
.isInstanceOf(IllegalArgumentException.class) |
||||
.hasMessage("principalName cannot be empty"); |
||||
} |
||||
|
||||
@Test |
||||
public void removeAuthorizedClientWhenExistsThenRemoved() { |
||||
Authentication principal = createPrincipal(); |
||||
OAuth2AuthorizedClient authorizedClient = createAuthorizedClient(principal, this.clientRegistration); |
||||
|
||||
this.authorizedClientService.saveAuthorizedClient(authorizedClient, principal); |
||||
|
||||
authorizedClient = this.authorizedClientService.loadAuthorizedClient( |
||||
this.clientRegistration.getRegistrationId(), principal.getName()); |
||||
assertThat(authorizedClient).isNotNull(); |
||||
|
||||
this.authorizedClientService.removeAuthorizedClient( |
||||
this.clientRegistration.getRegistrationId(), principal.getName()); |
||||
|
||||
authorizedClient = this.authorizedClientService.loadAuthorizedClient( |
||||
this.clientRegistration.getRegistrationId(), principal.getName()); |
||||
assertThat(authorizedClient).isNull(); |
||||
} |
||||
|
||||
@Test |
||||
public void tableDefinitionWhenCustomThenAbleToOverride() { |
||||
CustomTableDefinitionJdbcOAuth2AuthorizedClientService customAuthorizedClientService = |
||||
new CustomTableDefinitionJdbcOAuth2AuthorizedClientService( |
||||
new JdbcTemplate(createDb("custom-oauth2-client-schema.sql")), |
||||
this.clientRegistrationRepository); |
||||
|
||||
Authentication principal = createPrincipal(); |
||||
OAuth2AuthorizedClient authorizedClient = createAuthorizedClient(principal, this.clientRegistration); |
||||
|
||||
customAuthorizedClientService.saveAuthorizedClient(authorizedClient, principal); |
||||
|
||||
authorizedClient = customAuthorizedClientService.loadAuthorizedClient( |
||||
this.clientRegistration.getRegistrationId(), principal.getName()); |
||||
assertThat(authorizedClient).isNotNull(); |
||||
|
||||
customAuthorizedClientService.removeAuthorizedClient( |
||||
this.clientRegistration.getRegistrationId(), principal.getName()); |
||||
|
||||
authorizedClient = customAuthorizedClientService.loadAuthorizedClient( |
||||
this.clientRegistration.getRegistrationId(), principal.getName()); |
||||
assertThat(authorizedClient).isNull(); |
||||
} |
||||
|
||||
private static EmbeddedDatabase createDb() { |
||||
return createDb(OAUTH2_CLIENT_SCHEMA_SQL_RESOURCE); |
||||
} |
||||
|
||||
private static EmbeddedDatabase createDb(String schema) { |
||||
return new EmbeddedDatabaseBuilder() |
||||
.generateUniqueName(true) |
||||
.setType(EmbeddedDatabaseType.HSQL) |
||||
.setScriptEncoding("UTF-8") |
||||
.addScript(schema) |
||||
.build(); |
||||
} |
||||
|
||||
private static Authentication createPrincipal() { |
||||
return new TestingAuthenticationToken("principal-" + principalId++, "password"); |
||||
} |
||||
|
||||
private static OAuth2AuthorizedClient createAuthorizedClient(Authentication principal, ClientRegistration clientRegistration) { |
||||
return createAuthorizedClient(principal, clientRegistration, false); |
||||
} |
||||
|
||||
private static OAuth2AuthorizedClient createAuthorizedClient(Authentication principal, |
||||
ClientRegistration clientRegistration, boolean requiredAttributesOnly) { |
||||
OAuth2AccessToken accessToken; |
||||
if (!requiredAttributesOnly) { |
||||
accessToken = TestOAuth2AccessTokens.scopes("read", "write"); |
||||
} else { |
||||
accessToken = TestOAuth2AccessTokens.noScopes(); |
||||
} |
||||
OAuth2RefreshToken refreshToken = null; |
||||
if (!requiredAttributesOnly) { |
||||
refreshToken = TestOAuth2RefreshTokens.refreshToken(); |
||||
} |
||||
return new OAuth2AuthorizedClient( |
||||
clientRegistration, principal.getName(), accessToken, refreshToken); |
||||
} |
||||
|
||||
private static class CustomTableDefinitionJdbcOAuth2AuthorizedClientService extends JdbcOAuth2AuthorizedClientService { |
||||
private static final String COLUMN_NAMES = |
||||
"clientRegistrationId, " + |
||||
"principalName, " + |
||||
"accessTokenType, " + |
||||
"accessTokenValue, " + |
||||
"accessTokenIssuedAt, " + |
||||
"accessTokenExpiresAt, " + |
||||
"accessTokenScopes, " + |
||||
"refreshTokenValue, " + |
||||
"refreshTokenIssuedAt"; |
||||
private static final String TABLE_NAME = "oauth2AuthorizedClient"; |
||||
private static final String PK_FILTER = "clientRegistrationId = ? AND principalName = ?"; |
||||
private static final String LOAD_AUTHORIZED_CLIENT_SQL = "SELECT " + COLUMN_NAMES + |
||||
" FROM " + TABLE_NAME + " WHERE " + PK_FILTER; |
||||
private static final String SAVE_AUTHORIZED_CLIENT_SQL = "INSERT INTO " + TABLE_NAME + |
||||
" (" + COLUMN_NAMES + ") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)"; |
||||
private static final String REMOVE_AUTHORIZED_CLIENT_SQL = "DELETE FROM " + TABLE_NAME + |
||||
" WHERE " + PK_FILTER; |
||||
|
||||
private CustomTableDefinitionJdbcOAuth2AuthorizedClientService( |
||||
JdbcOperations jdbcOperations, ClientRegistrationRepository clientRegistrationRepository) { |
||||
super(jdbcOperations, clientRegistrationRepository); |
||||
setAuthorizedClientRowMapper(new OAuth2AuthorizedClientRowMapper(clientRegistrationRepository)); |
||||
} |
||||
|
||||
@Override |
||||
@SuppressWarnings("unchecked") |
||||
public <T extends OAuth2AuthorizedClient> T loadAuthorizedClient(String clientRegistrationId, String principalName) { |
||||
SqlParameterValue[] parameters = new SqlParameterValue[] { |
||||
new SqlParameterValue(Types.VARCHAR, clientRegistrationId), |
||||
new SqlParameterValue(Types.VARCHAR, principalName) |
||||
}; |
||||
PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters); |
||||
List<OAuth2AuthorizedClient> result = this.jdbcOperations.query( |
||||
LOAD_AUTHORIZED_CLIENT_SQL, pss, this.authorizedClientRowMapper); |
||||
return !result.isEmpty() ? (T) result.get(0) : null; |
||||
} |
||||
|
||||
@Override |
||||
public void saveAuthorizedClient(OAuth2AuthorizedClient authorizedClient, Authentication principal) { |
||||
List<SqlParameterValue> parameters = this.authorizedClientParametersMapper.apply( |
||||
new OAuth2AuthorizedClientHolder(authorizedClient, principal)); |
||||
PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters.toArray()); |
||||
this.jdbcOperations.update(SAVE_AUTHORIZED_CLIENT_SQL, pss); |
||||
} |
||||
|
||||
@Override |
||||
public void removeAuthorizedClient(String clientRegistrationId, String principalName) { |
||||
SqlParameterValue[] parameters = new SqlParameterValue[] { |
||||
new SqlParameterValue(Types.VARCHAR, clientRegistrationId), |
||||
new SqlParameterValue(Types.VARCHAR, principalName) |
||||
}; |
||||
PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters); |
||||
this.jdbcOperations.update(REMOVE_AUTHORIZED_CLIENT_SQL, pss); |
||||
} |
||||
|
||||
private static class OAuth2AuthorizedClientRowMapper implements RowMapper<OAuth2AuthorizedClient> { |
||||
private final ClientRegistrationRepository clientRegistrationRepository; |
||||
|
||||
private OAuth2AuthorizedClientRowMapper(ClientRegistrationRepository clientRegistrationRepository) { |
||||
Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null"); |
||||
this.clientRegistrationRepository = clientRegistrationRepository; |
||||
} |
||||
|
||||
@Override |
||||
public OAuth2AuthorizedClient mapRow(ResultSet rs, int rowNum) throws SQLException { |
||||
String clientRegistrationId = rs.getString("clientRegistrationId"); |
||||
ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId( |
||||
clientRegistrationId); |
||||
if (clientRegistration == null) { |
||||
throw new DataRetrievalFailureException("The ClientRegistration with id '" + |
||||
clientRegistrationId + "' exists in the data source, " + |
||||
"however, it was not found in the ClientRegistrationRepository."); |
||||
} |
||||
|
||||
OAuth2AccessToken.TokenType tokenType = null; |
||||
if (OAuth2AccessToken.TokenType.BEARER.getValue().equalsIgnoreCase( |
||||
rs.getString("accessTokenType"))) { |
||||
tokenType = OAuth2AccessToken.TokenType.BEARER; |
||||
} |
||||
String tokenValue = new String(rs.getBytes("accessTokenValue"), StandardCharsets.UTF_8); |
||||
Instant issuedAt = rs.getTimestamp("accessTokenIssuedAt").toInstant(); |
||||
Instant expiresAt = rs.getTimestamp("accessTokenExpiresAt").toInstant(); |
||||
Set<String> scopes = Collections.emptySet(); |
||||
String accessTokenScopes = rs.getString("accessTokenScopes"); |
||||
if (accessTokenScopes != null) { |
||||
scopes = StringUtils.commaDelimitedListToSet(accessTokenScopes); |
||||
} |
||||
OAuth2AccessToken accessToken = new OAuth2AccessToken( |
||||
tokenType, tokenValue, issuedAt, expiresAt, scopes); |
||||
|
||||
OAuth2RefreshToken refreshToken = null; |
||||
byte[] refreshTokenValue = rs.getBytes("refreshTokenValue"); |
||||
if (refreshTokenValue != null) { |
||||
tokenValue = new String(refreshTokenValue, StandardCharsets.UTF_8); |
||||
issuedAt = null; |
||||
Timestamp refreshTokenIssuedAt = rs.getTimestamp("refreshTokenIssuedAt"); |
||||
if (refreshTokenIssuedAt != null) { |
||||
issuedAt = refreshTokenIssuedAt.toInstant(); |
||||
} |
||||
refreshToken = new OAuth2RefreshToken(tokenValue, issuedAt); |
||||
} |
||||
|
||||
String principalName = rs.getString("principalName"); |
||||
|
||||
return new OAuth2AuthorizedClient( |
||||
clientRegistration, principalName, accessToken, refreshToken); |
||||
} |
||||
} |
||||
} |
||||
} |
||||
@ -0,0 +1,13 @@
@@ -0,0 +1,13 @@
|
||||
CREATE TABLE oauth2AuthorizedClient ( |
||||
clientRegistrationId varchar(100) NOT NULL, |
||||
principalName varchar(200) NOT NULL, |
||||
accessTokenType varchar(100) NOT NULL, |
||||
accessTokenValue blob NOT NULL, |
||||
accessTokenIssuedAt timestamp NOT NULL, |
||||
accessTokenExpiresAt timestamp NOT NULL, |
||||
accessTokenScopes varchar(1000) DEFAULT NULL, |
||||
refreshTokenValue blob DEFAULT NULL, |
||||
refreshTokenIssuedAt timestamp DEFAULT NULL, |
||||
createdAt timestamp DEFAULT CURRENT_TIMESTAMP NOT NULL, |
||||
PRIMARY KEY (clientRegistrationId, principalName) |
||||
); |
||||
Loading…
Reference in new issue