From 769cf8fac72aa5ed40fddfae82f6cf4f834237b2 Mon Sep 17 00:00:00 2001 From: Rafal Lewczuk Date: Tue, 30 Mar 2021 13:05:52 +0200 Subject: [PATCH] JDBC implementation of RegisteredClientRepository Closes gh-265 --- .../JdbcRegisteredClientRepository.java | 335 ++++++++++++++++++ .../client/oauth2_registered_client.sql | 14 + .../JdbcRegisteredClientRepositoryTests.java | 291 +++++++++++++++ 3 files changed, 640 insertions(+) create mode 100644 oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/client/JdbcRegisteredClientRepository.java create mode 100644 oauth2-authorization-server/src/main/resources/org/springframework/security/oauth2/server/authorization/client/oauth2_registered_client.sql create mode 100644 oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/client/JdbcRegisteredClientRepositoryTests.java diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/client/JdbcRegisteredClientRepository.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/client/JdbcRegisteredClientRepository.java new file mode 100644 index 00000000..d461e9c7 --- /dev/null +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/client/JdbcRegisteredClientRepository.java @@ -0,0 +1,335 @@ +/* + * Copyright 2020-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.client; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.springframework.jdbc.core.*; +import org.springframework.jdbc.support.lob.DefaultLobHandler; +import org.springframework.jdbc.support.lob.LobCreator; +import org.springframework.jdbc.support.lob.LobHandler; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; +import org.springframework.security.oauth2.server.authorization.config.ClientSettings; +import org.springframework.security.oauth2.server.authorization.config.TokenSettings; +import org.springframework.util.Assert; + +import java.nio.charset.StandardCharsets; +import java.sql.*; +import java.time.Duration; +import java.time.Instant; +import java.util.*; +import java.util.function.Function; +import java.util.stream.Collectors; + +/** + * JDBC-backed registered client repository + * + * @author Rafal Lewczuk + * @since 0.1.2 + */ +public class JdbcRegisteredClientRepository implements RegisteredClientRepository { + + private static final Map AUTHORIZATION_GRANT_TYPE_MAP; + private static final Map CLIENT_AUTHENTICATION_METHOD_MAP; + + private static final String COLUMN_NAMES = "id, " + + "client_id, " + + "client_id_issued_at, " + + "client_secret, " + + "client_secret_expires_at, " + + "client_name, " + + "client_authentication_methods, " + + "authorization_grant_types, " + + "redirect_uris, " + + "scopes, " + + "client_settings," + + "token_settings"; + + private static final String TABLE_NAME = "oauth2_registered_client"; + + private static final String LOAD_REGISTERED_CLIENT_SQL = "SELECT " + COLUMN_NAMES + " FROM " + TABLE_NAME + " WHERE "; + + private static final String INSERT_REGISTERED_CLIENT_SQL = "INSERT INTO " + TABLE_NAME + + "(" + COLUMN_NAMES + ") values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"; + + private RowMapper registeredClientRowMapper; + + private Function> registeredClientParametersMapper; + + private final JdbcOperations jdbcOperations; + + private final LobHandler lobHandler = new DefaultLobHandler(); + + private final ObjectMapper objectMapper; + + public JdbcRegisteredClientRepository(JdbcOperations jdbcOperations, ObjectMapper objectMapper) { + Assert.notNull(jdbcOperations, "jdbcOperations cannot be null"); + Assert.notNull(objectMapper, "objectMapper cannot be null"); + this.jdbcOperations = jdbcOperations; + this.objectMapper = objectMapper; + this.registeredClientRowMapper = new DefaultRegisteredClientRowMapper(); + this.registeredClientParametersMapper = new DefaultRegisteredClientParametersMapper(); + } + + /** + * Allows changing of {@link RegisteredClient} row mapper implementation + * + * @param registeredClientRowMapper mapper implementation + */ + public void setRegisteredClientRowMapper(RowMapper registeredClientRowMapper) { + Assert.notNull(registeredClientRowMapper, "registeredClientRowMapper cannot be null"); + this.registeredClientRowMapper = registeredClientRowMapper; + } + + /** + * Allows changing of SQL parameter mapper for {@link RegisteredClient} + * + * @param registeredClientParametersMapper mapper implementation + */ + public void setRegisteredClientParametersMapper(Function> registeredClientParametersMapper) { + Assert.notNull(registeredClientParametersMapper, "registeredClientParameterMapper cannot be null"); + this.registeredClientParametersMapper = registeredClientParametersMapper; + } + + @Override + public void save(RegisteredClient registeredClient) { + Assert.notNull(registeredClient, "registeredClient cannot be null"); + RegisteredClient foundClient = this.findBy("id = ? OR client_id = ? OR client_secret = ?", + registeredClient.getId(), registeredClient.getClientId(), + registeredClient.getClientSecret().getBytes(StandardCharsets.UTF_8)); + + if (null != foundClient) { + Assert.isTrue(!foundClient.getId().equals(registeredClient.getId()), + "Registered client must be unique. Found duplicate identifier: " + registeredClient.getId()); + Assert.isTrue(!foundClient.getClientId().equals(registeredClient.getClientId()), + "Registered client must be unique. Found duplicate client identifier: " + registeredClient.getClientId()); + Assert.isTrue(!foundClient.getClientSecret().equals(registeredClient.getClientSecret()), + "Registered client must be unique. Found duplicate client secret for identifier: " + registeredClient.getId()); + } + + List parameters = this.registeredClientParametersMapper.apply(registeredClient); + + try (LobCreator lobCreator = this.lobHandler.getLobCreator()) { + PreparedStatementSetter pss = new LobCreatorArgumentPreparedStatementSetter(lobCreator, parameters.toArray()); + jdbcOperations.update(INSERT_REGISTERED_CLIENT_SQL, pss); + } + } + + @Override + public RegisteredClient findById(String id) { + Assert.hasText(id, "id cannot be empty"); + return findBy("id = ?", id); + } + + @Override + public RegisteredClient findByClientId(String clientId) { + Assert.hasText(clientId, "clientId cannot be empty"); + return findBy("client_id = ?", clientId); + } + + private RegisteredClient findBy(String condStr, Object...args) { + List lst = jdbcOperations.query( + LOAD_REGISTERED_CLIENT_SQL + condStr, + registeredClientRowMapper, args); + return !lst.isEmpty() ? lst.get(0) : null; + } + + private class DefaultRegisteredClientRowMapper implements RowMapper { + + private final LobHandler lobHandler = new DefaultLobHandler(); + + private Collection parseList(String s) { + return s != null ? Arrays.asList(s.split("\\|")) : Collections.emptyList(); + } + + @Override + @SuppressWarnings("unchecked") + public RegisteredClient mapRow(ResultSet rs, int rowNum) throws SQLException { + Collection scopes = parseList(rs.getString("scopes")); + List authGrantTypes = parseList(rs.getString("authorization_grant_types")) + .stream().map(AUTHORIZATION_GRANT_TYPE_MAP::get).collect(Collectors.toList()); + List clientAuthMethods = parseList(rs.getString("client_authentication_methods")) + .stream().map(CLIENT_AUTHENTICATION_METHOD_MAP::get).collect(Collectors.toList()); + Collection redirectUris = parseList(rs.getString("redirect_uris")); + Timestamp clientIssuedAt = rs.getTimestamp("client_id_issued_at"); + Timestamp clientSecretExpiresAt = rs.getTimestamp("client_secret_expires_at"); + byte[] clientSecretBytes = this.lobHandler.getBlobAsBytes(rs, "client_secret"); + String clientSecret = clientSecretBytes != null ? new String(clientSecretBytes, StandardCharsets.UTF_8) : null; + RegisteredClient.Builder builder = RegisteredClient + .withId(rs.getString("id")) + .clientId(rs.getString("client_id")) + .clientIdIssuedAt(clientIssuedAt != null ? clientIssuedAt.toInstant() : null) + .clientSecret(clientSecret) + .clientSecretExpiresAt(clientSecretExpiresAt != null ? clientSecretExpiresAt.toInstant() : null) + .clientName(rs.getString("client_name")) + .clientAuthenticationMethods(coll -> coll.addAll(clientAuthMethods)) + .authorizationGrantTypes(coll -> coll.addAll(authGrantTypes)) + .redirectUris(coll -> coll.addAll(redirectUris)) + .scopes(coll -> coll.addAll(scopes)); + + RegisteredClient rc = builder.build(); + + TokenSettings ts = rc.getTokenSettings(); + ClientSettings cs = rc.getClientSettings(); + + try { + String tokenSettingsJson = rs.getString("token_settings"); + if (tokenSettingsJson != null) { + + Map m = JdbcRegisteredClientRepository.this.objectMapper.readValue(tokenSettingsJson, Map.class); + + Number accessTokenTTL = (Number) m.get("access_token_ttl"); + if (accessTokenTTL != null) { + ts.accessTokenTimeToLive(Duration.ofMillis(accessTokenTTL.longValue())); + } + + Number refreshTokenTTL = (Number) m.get("refresh_token_ttl"); + if (refreshTokenTTL != null) { + ts.refreshTokenTimeToLive(Duration.ofMillis(refreshTokenTTL.longValue())); + } + + Boolean reuseRefreshTokens = (Boolean) m.get("reuse_refresh_tokens"); + if (reuseRefreshTokens != null) { + ts.reuseRefreshTokens(reuseRefreshTokens); + } + } + + String clientSettingsJson = rs.getString("client_settings"); + if (clientSettingsJson != null) { + + Map m = JdbcRegisteredClientRepository.this.objectMapper.readValue(clientSettingsJson, Map.class); + + Boolean requireProofKey = (Boolean) m.get("require_proof_key"); + if (requireProofKey != null) { + cs.requireProofKey(requireProofKey); + } + + Boolean requireUserConsent = (Boolean) m.get("require_user_consent"); + if (requireUserConsent != null) { + cs.requireUserConsent(requireUserConsent); + } + } + + + } catch (JsonProcessingException e) { + throw new IllegalArgumentException(e.getMessage(), e); + } + + return rc; + } + } + + private class DefaultRegisteredClientParametersMapper implements Function> { + @Override + public List apply(RegisteredClient registeredClient) { + try { + List clientAuthenticationMethodNames = new ArrayList<>(registeredClient.getClientAuthenticationMethods().size()); + for (ClientAuthenticationMethod clientAuthenticationMethod : registeredClient.getClientAuthenticationMethods()) { + clientAuthenticationMethodNames.add(clientAuthenticationMethod.getValue()); + } + + List authorizationGrantTypeNames = new ArrayList<>(registeredClient.getAuthorizationGrantTypes().size()); + for (AuthorizationGrantType authorizationGrantType : registeredClient.getAuthorizationGrantTypes()) { + authorizationGrantTypeNames.add(authorizationGrantType.getValue()); + } + + Instant issuedAt = registeredClient.getClientIdIssuedAt() != null ? + registeredClient.getClientIdIssuedAt() : Instant.now(); + + Timestamp clientSecretExpiresAt = registeredClient.getClientSecretExpiresAt() != null ? + Timestamp.from(registeredClient.getClientSecretExpiresAt()) : null; + + Map clientSettings = new HashMap<>(); + clientSettings.put("require_proof_key", registeredClient.getClientSettings().requireProofKey()); + clientSettings.put("require_user_consent", registeredClient.getClientSettings().requireUserConsent()); + String clientSettingsJson = JdbcRegisteredClientRepository.this.objectMapper.writeValueAsString(clientSettings); + + Map tokenSettings = new HashMap<>(); + tokenSettings.put("access_token_ttl", registeredClient.getTokenSettings().accessTokenTimeToLive().toMillis()); + tokenSettings.put("reuse_refresh_tokens", registeredClient.getTokenSettings().reuseRefreshTokens()); + tokenSettings.put("refresh_token_ttl", registeredClient.getTokenSettings().refreshTokenTimeToLive().toMillis()); + String tokenSettingsJson = JdbcRegisteredClientRepository.this.objectMapper.writeValueAsString(tokenSettings); + + return Arrays.asList( + new SqlParameterValue(Types.VARCHAR, registeredClient.getId()), + new SqlParameterValue(Types.VARCHAR, registeredClient.getClientId()), + new SqlParameterValue(Types.TIMESTAMP, Timestamp.from(issuedAt)), + new SqlParameterValue(Types.BLOB, registeredClient.getClientSecret().getBytes(StandardCharsets.UTF_8)), + new SqlParameterValue(Types.TIMESTAMP, clientSecretExpiresAt), + new SqlParameterValue(Types.VARCHAR, registeredClient.getClientName()), + new SqlParameterValue(Types.VARCHAR, String.join("|", clientAuthenticationMethodNames)), + new SqlParameterValue(Types.VARCHAR, String.join("|", authorizationGrantTypeNames)), + new SqlParameterValue(Types.VARCHAR, String.join("|", registeredClient.getRedirectUris())), + new SqlParameterValue(Types.VARCHAR, String.join("|", registeredClient.getScopes())), + new SqlParameterValue(Types.VARCHAR, clientSettingsJson), + new SqlParameterValue(Types.VARCHAR, tokenSettingsJson)); + } catch (JsonProcessingException e) { + throw new IllegalArgumentException(e.getMessage(), e); + } + } + } + + private static final class LobCreatorArgumentPreparedStatementSetter extends ArgumentPreparedStatementSetter { + + protected final LobCreator lobCreator; + + private LobCreatorArgumentPreparedStatementSetter(LobCreator lobCreator, Object[] args) { + super(args); + this.lobCreator = lobCreator; + } + + @Override + protected void doSetValue(PreparedStatement ps, int parameterPosition, Object argValue) throws SQLException { + if (argValue instanceof SqlParameterValue) { + SqlParameterValue paramValue = (SqlParameterValue) argValue; + if (paramValue.getSqlType() == Types.BLOB) { + if (paramValue.getValue() != null) { + Assert.isInstanceOf(byte[].class, paramValue.getValue(), + "Value of blob parameter must be byte[]"); + } + byte[] valueBytes = (byte[]) paramValue.getValue(); + this.lobCreator.setBlobAsBytes(ps, parameterPosition, valueBytes); + return; + } + } + super.doSetValue(ps, parameterPosition, argValue); + } + + } + + static { + Map am = new HashMap<>(); + for (AuthorizationGrantType a : Arrays.asList( + AuthorizationGrantType.AUTHORIZATION_CODE, + AuthorizationGrantType.REFRESH_TOKEN, + AuthorizationGrantType.CLIENT_CREDENTIALS, + AuthorizationGrantType.PASSWORD, + AuthorizationGrantType.IMPLICIT)) { + am.put(a.getValue(), a); + } + AUTHORIZATION_GRANT_TYPE_MAP = Collections.unmodifiableMap(am); + + Map cm = new HashMap<>(); + for (ClientAuthenticationMethod c : Arrays.asList( + ClientAuthenticationMethod.NONE, + ClientAuthenticationMethod.BASIC, + ClientAuthenticationMethod.POST)) { + cm.put(c.getValue(), c); + } + CLIENT_AUTHENTICATION_METHOD_MAP = Collections.unmodifiableMap(cm); + } +} diff --git a/oauth2-authorization-server/src/main/resources/org/springframework/security/oauth2/server/authorization/client/oauth2_registered_client.sql b/oauth2-authorization-server/src/main/resources/org/springframework/security/oauth2/server/authorization/client/oauth2_registered_client.sql new file mode 100644 index 00000000..e0f201b7 --- /dev/null +++ b/oauth2-authorization-server/src/main/resources/org/springframework/security/oauth2/server/authorization/client/oauth2_registered_client.sql @@ -0,0 +1,14 @@ +CREATE TABLE oauth2_registered_client ( + id varchar(100) NOT NULL, + client_id varchar(100) NOT NULL, + client_id_issued_at timestamp DEFAULT CURRENT_TIMESTAMP NOT NULL, + client_secret blob NOT NULL, + client_secret_expires_at timestamp DEFAULT NULL, + client_name varchar(200), + client_authentication_methods varchar(1000) NOT NULL, + authorization_grant_types varchar(1000) NOT NULL, + redirect_uris varchar(1000) NOT NULL, + scopes varchar(1000) NOT NULL, + client_settings varchar(1000) DEFAULT NULL, + token_settings varchar(1000) DEFAULT NULL, + PRIMARY KEY (id)); diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/client/JdbcRegisteredClientRepositoryTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/client/JdbcRegisteredClientRepositoryTests.java new file mode 100644 index 00000000..a17d2ee4 --- /dev/null +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/client/JdbcRegisteredClientRepositoryTests.java @@ -0,0 +1,291 @@ +/* + * Copyright 2020-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.client; + +import com.fasterxml.jackson.databind.ObjectMapper; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.springframework.jdbc.core.JdbcTemplate; +import org.springframework.jdbc.datasource.DriverManagerDataSource; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; +import org.springframework.util.StreamUtils; + +import java.io.InputStream; +import java.nio.charset.Charset; +import java.time.Duration; +import java.time.Instant; + +import static org.assertj.core.api.Assertions.*; + +/** + * JDBC-backed registered client repository tests + * + * @author Rafal Lewczuk + * @since 0.1.2 + */ +public class JdbcRegisteredClientRepositoryTests { + + private final String SCRIPT = "/org/springframework/security/oauth2/server/authorization/client/oauth2_registered_client.sql"; + + private DriverManagerDataSource dataSource; + + private JdbcRegisteredClientRepository clients; + + private RegisteredClient registration; + + private JdbcTemplate jdbc; + + @Before + public void setup() throws Exception { + this.dataSource = new DriverManagerDataSource(); + this.dataSource.setDriverClassName("org.hsqldb.jdbcDriver"); + this.dataSource.setUrl("jdbc:hsqldb:mem:oauthtest"); + this.dataSource.setUsername("sa"); + this.dataSource.setPassword(""); + + this.jdbc = new JdbcTemplate(this.dataSource); + + // execute scripts + try (InputStream is = JdbcRegisteredClientRepositoryTests.class.getResourceAsStream(SCRIPT)) { + assertThat(is).isNotNull().describedAs("Cannot open resource file: " + SCRIPT); + String ddls = StreamUtils.copyToString(is, Charset.defaultCharset()); + for (String ddl : ddls.split(";\n")) { + if (!ddl.trim().isEmpty()) { + this.jdbc.execute(ddl.trim()); + } + } + } + + this.clients = new JdbcRegisteredClientRepository(this.jdbc, new ObjectMapper()); + this.registration = TestRegisteredClients.registeredClient().build(); + + this.clients.save(this.registration); + } + + @After + public void destroyDatabase() { + this.jdbc.update("truncate table oauth2_registered_client"); + new JdbcTemplate(this.dataSource).execute("SHUTDOWN"); + } + + @Test + public void whenJdbcOperationsNullThenThrow() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> new JdbcRegisteredClientRepository(null, new ObjectMapper())) + .withMessage("jdbcOperations cannot be null"); + // @formatter:on + } + + @Test + public void whenObjectMapperNullThenThrow() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> new JdbcRegisteredClientRepository(this.jdbc, null)) + .withMessage("objectMapper cannot be null"); + // @formatter:on + } + + @Test + public void whenSetNullRegisteredClientRowMapperThenThrow() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.clients.setRegisteredClientRowMapper(null)) + .withMessage("registeredClientRowMapper cannot be null"); + // @formatter:on + } + + @Test + public void whenSetNullRegisteredClientParameterMapperThenThrow() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.clients.setRegisteredClientParametersMapper(null)) + .withMessage("registeredClientParameterMapper cannot be null"); + // @formatter:on + } + + @Test + public void findByIdWhenFoundThenFound() { + String id = this.registration.getId(); + assertRegisteredClientIsEqualTo(this.clients.findById(id), this.registration); + } + + @Test + public void findByIdWhenNotFoundThenNull() { + RegisteredClient client = this.clients.findById("noooope"); + assertThat(client).isNull(); + } + + @Test + public void findByIdWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.clients.findById(null)) + .withMessage("id cannot be empty"); + // @formatter:on + } + + @Test + public void findByClientIdWhenFoundThenFound() { + String id = this.registration.getClientId(); + assertRegisteredClientIsEqualTo(this.clients.findByClientId(id), this.registration); + } + + @Test + public void findByClientIdWhenNotFoundThenNull() { + RegisteredClient client = this.clients.findByClientId("noooope"); + assertThat(client).isNull(); + } + + @Test + public void findByClientIdWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.clients.findByClientId(null)) + .withMessage("clientId cannot be empty"); + // @formatter:on + } + + @Test + public void saveWhenNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException() + .isThrownBy(() -> this.clients.save(null)) + .withMessageContaining("registeredClient cannot be null"); + } + + @Test + public void saveWhenExistingIdThenThrowIllegalArgumentException() { + RegisteredClient registeredClient = createRegisteredClient( + this.registration.getId(), "client-id-2", "client-secret-2"); + assertThatIllegalArgumentException() + .isThrownBy(() -> this.clients.save(registeredClient)) + .withMessage("Registered client must be unique. Found duplicate identifier: " + registeredClient.getId()); + } + + @Test + public void saveWhenExistingClientIdThenThrowIllegalArgumentException() { + RegisteredClient registeredClient = createRegisteredClient( + "client-2", this.registration.getClientId(), "client-secret-2"); + assertThatIllegalArgumentException() + .isThrownBy(() -> this.clients.save(registeredClient)) + .withMessage("Registered client must be unique. Found duplicate client identifier: " + registeredClient.getClientId()); + } + + @Test + public void saveWhenExistingClientSecretThenThrowIllegalArgumentException() { + RegisteredClient registeredClient = createRegisteredClient( + "client-2", "client-id-2", this.registration.getClientSecret()); + assertThatIllegalArgumentException() + .isThrownBy(() -> this.clients.save(registeredClient)) + .withMessage("Registered client must be unique. Found duplicate client secret for identifier: " + registeredClient.getId()); + } + + @Test + public void saveWhenSavedAndFindByIdThenFound() { + RegisteredClient registeredClient = createRegisteredClient(); + this.clients.save(registeredClient); + RegisteredClient savedClient = this.clients.findById(registeredClient.getId()); + assertRegisteredClientIsEqualTo(savedClient, registeredClient); + } + + @Test + public void saveWhenSavedAndFindByClientIdThenFound() { + RegisteredClient registeredClient = createRegisteredClient(); + this.clients.save(registeredClient); + RegisteredClient savedClient = this.clients.findByClientId(registeredClient.getClientId()); + assertRegisteredClientIsEqualTo(savedClient, registeredClient); + } + + @Test + public void whenSaveRegistrationWithAllAttrsThenSaved() { + Instant issuedAt = Instant.now(), expiresAt = issuedAt.plus(Duration.ofDays(30)); + RegisteredClient client = TestRegisteredClients.registeredClient2() + .clientIdIssuedAt(issuedAt) + .clientSecretExpiresAt(expiresAt) + .clientSecret("secret2") + .clientName("some_client_name") + .redirectUri("https://example2.com") + .clientSettings(cs -> { + cs.requireProofKey(true); + cs.requireUserConsent(true); + }) + .tokenSettings(ts -> { + ts.accessTokenTimeToLive(Duration.ofMinutes(3)); + ts.reuseRefreshTokens(true); + ts.refreshTokenTimeToLive(Duration.ofMinutes(300)); + }) + .build(); + + this.clients.save(client); + + RegisteredClient retrievedClient = this.clients.findById(client.getId()); + + assertRegisteredClientIsEqualTo(retrievedClient, client); + } + + private void assertRegisteredClientIsEqualTo(RegisteredClient rc, RegisteredClient ref) { + assertThat(rc).isNotNull(); + assertThat(rc.getId()).isEqualTo(ref.getId()); + assertThat(rc.getClientId()).isEqualTo(ref.getClientId()); + + if (ref.getClientIdIssuedAt() != null) { + // This can be set to default value + Instant inst = ref.getClientIdIssuedAt(); + assertThat(rc.getClientIdIssuedAt()).isBetween(inst.minusMillis(1), inst.plusMillis(1)); + } + + assertThat(rc.getClientSecret()).isEqualTo(ref.getClientSecret()); + + if (ref.getClientSecretExpiresAt() != null) { + Instant inst = ref.getClientSecretExpiresAt(); + assertThat(rc.getClientSecretExpiresAt()).isBetween(inst.minusMillis(1), inst.plusMillis(1)); + } else { + assertThat(rc.getClientSecretExpiresAt()).isNull(); + } + + assertThat(rc.getClientName()).isEqualTo(ref.getClientName()); + assertThat(rc.getClientAuthenticationMethods()).isEqualTo(ref.getClientAuthenticationMethods()); + assertThat(rc.getAuthorizationGrantTypes()).isEqualTo(ref.getAuthorizationGrantTypes()); + assertThat(rc.getRedirectUris()).isEqualTo(ref.getRedirectUris()); + assertThat(rc.getScopes()).isEqualTo(ref.getScopes()); + assertThat(rc.getClientSettings().requireUserConsent()).isEqualTo(ref.getClientSettings().requireUserConsent()); + assertThat(rc.getClientSettings().requireProofKey()).isEqualTo(ref.getClientSettings().requireProofKey()); + assertThat(rc.getTokenSettings().reuseRefreshTokens()).isEqualTo(ref.getTokenSettings().reuseRefreshTokens()); + assertThat(rc.getTokenSettings().accessTokenTimeToLive()).isEqualTo(ref.getTokenSettings().accessTokenTimeToLive()); + assertThat(rc.getTokenSettings().refreshTokenTimeToLive()).isEqualTo(ref.getTokenSettings().refreshTokenTimeToLive()); + } + + private static RegisteredClient createRegisteredClient() { + return createRegisteredClient("client-2", "client-id-2", "client-secret-2"); + } + + + private static RegisteredClient createRegisteredClient(String id, String clientId, String clientSecret) { + // @formatter:off + return RegisteredClient.withId(id) + .clientId(clientId) + .clientSecret(clientSecret) + .authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS) + .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) + .redirectUri("https://client.example.com") + .scope("scope1") + .build(); + // @formatter:on + } + +}