Browse Source

Polish gh-291

pull/334/head
Steve Riesenberg 5 years ago
parent
commit
3318874da1
  1. 261
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/client/JdbcRegisteredClientRepository.java
  2. 12
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/client/JdbcRegisteredClientRepositoryTests.java

261
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/client/JdbcRegisteredClientRepository.java

@ -30,7 +30,7 @@ import java.util.Map; @@ -30,7 +30,7 @@ import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.springframework.jdbc.core.ArgumentPreparedStatementSetter;
@ -88,21 +88,10 @@ public class JdbcRegisteredClientRepository implements RegisteredClientRepositor @@ -88,21 +88,10 @@ public class JdbcRegisteredClientRepository implements RegisteredClientRepositor
* @param jdbcOperations the JDBC operations
*/
public JdbcRegisteredClientRepository(JdbcOperations jdbcOperations) {
this(jdbcOperations, new ObjectMapper());
}
/**
* Constructs a {@code JdbcRegisteredClientRepository} using the provided parameters.
*
* @param jdbcOperations the JDBC operations
* @param objectMapper the object mapper
*/
public JdbcRegisteredClientRepository(JdbcOperations jdbcOperations, ObjectMapper objectMapper) {
Assert.notNull(jdbcOperations, "jdbcOperations cannot be null");
Assert.notNull(objectMapper, "objectMapper cannot be null");
this.jdbcOperations = jdbcOperations;
this.registeredClientRowMapper = new DefaultRegisteredClientRowMapper(objectMapper);
this.registeredClientParametersMapper = new DefaultRegisteredClientParametersMapper(objectMapper);
this.registeredClientRowMapper = new DefaultRegisteredClientRowMapper();
this.registeredClientParametersMapper = new DefaultRegisteredClientParametersMapper();
}
/**
@ -110,7 +99,7 @@ public class JdbcRegisteredClientRepository implements RegisteredClientRepositor @@ -110,7 +99,7 @@ public class JdbcRegisteredClientRepository implements RegisteredClientRepositor
*
* @param registeredClientRowMapper mapper implementation
*/
public void setRegisteredClientRowMapper(RowMapper<RegisteredClient> registeredClientRowMapper) {
public final void setRegisteredClientRowMapper(RowMapper<RegisteredClient> registeredClientRowMapper) {
Assert.notNull(registeredClientRowMapper, "registeredClientRowMapper cannot be null");
this.registeredClientRowMapper = registeredClientRowMapper;
}
@ -120,18 +109,30 @@ public class JdbcRegisteredClientRepository implements RegisteredClientRepositor @@ -120,18 +109,30 @@ public class JdbcRegisteredClientRepository implements RegisteredClientRepositor
*
* @param registeredClientParametersMapper mapper implementation
*/
public void setRegisteredClientParametersMapper(Function<RegisteredClient, List<SqlParameterValue>> registeredClientParametersMapper) {
public final void setRegisteredClientParametersMapper(Function<RegisteredClient, List<SqlParameterValue>> registeredClientParametersMapper) {
Assert.notNull(registeredClientParametersMapper, "registeredClientParameterMapper cannot be null");
this.registeredClientParametersMapper = registeredClientParametersMapper;
}
protected final JdbcOperations getJdbcOperations() {
return this.jdbcOperations;
}
protected final RowMapper<RegisteredClient> getRegisteredClientRowMapper() {
return this.registeredClientRowMapper;
}
protected final Function<RegisteredClient, List<SqlParameterValue>> getRegisteredClientParametersMapper() {
return this.registeredClientParametersMapper;
}
@Override
public void save(RegisteredClient registeredClient) {
Assert.notNull(registeredClient, "registeredClient cannot be null");
RegisteredClient foundClient = this.findBy("id = ? OR client_id = ?",
RegisteredClient foundClient = findBy("id = ? OR client_id = ?",
registeredClient.getId(), registeredClient.getClientId());
if (null != foundClient) {
if (foundClient != null) {
Assert.isTrue(!foundClient.getId().equals(registeredClient.getId()),
"Registered client must be unique. Found duplicate identifier: " + registeredClient.getId());
Assert.isTrue(!foundClient.getClientId().equals(registeredClient.getClientId()),
@ -155,29 +156,20 @@ public class JdbcRegisteredClientRepository implements RegisteredClientRepositor @@ -155,29 +156,20 @@ public class JdbcRegisteredClientRepository implements RegisteredClientRepositor
return findBy("client_id = ?", clientId);
}
private RegisteredClient findBy(String condStr, Object...args) {
List<RegisteredClient> lst = this.jdbcOperations.query(
private RegisteredClient findBy(String condStr, Object... args) {
List<RegisteredClient> result = this.jdbcOperations.query(
LOAD_REGISTERED_CLIENT_SQL + condStr,
registeredClientRowMapper, args);
return !lst.isEmpty() ? lst.get(0) : null;
this.registeredClientRowMapper, args);
return !result.isEmpty() ? result.get(0) : null;
}
public static class DefaultRegisteredClientRowMapper implements RowMapper<RegisteredClient> {
private final ObjectMapper objectMapper;
public DefaultRegisteredClientRowMapper(ObjectMapper objectMapper) {
this.objectMapper = objectMapper;
}
private Set<String> parseList(String s) {
return s != null ? StringUtils.commaDelimitedListToSet(s) : Collections.emptySet();
}
private ObjectMapper objectMapper = new ObjectMapper();
@Override
@SuppressWarnings("unchecked")
public RegisteredClient mapRow(ResultSet rs, int rowNum) throws SQLException {
Set<String> scopes = parseList(rs.getString("scopes"));
Set<String> clientScopes = parseList(rs.getString("scopes"));
Set<String> authGrantTypes = parseList(rs.getString("authorization_grant_types"));
Set<String> clientAuthMethods = parseList(rs.getString("client_authentication_methods"));
Set<String> redirectUris = parseList(rs.getString("redirect_uris"));
@ -191,115 +183,140 @@ public class JdbcRegisteredClientRepository implements RegisteredClientRepositor @@ -191,115 +183,140 @@ public class JdbcRegisteredClientRepository implements RegisteredClientRepositor
.clientSecret(clientSecret)
.clientSecretExpiresAt(clientSecretExpiresAt != null ? clientSecretExpiresAt.toInstant() : null)
.clientName(rs.getString("client_name"))
.authorizationGrantTypes(coll -> authGrantTypes.forEach(authGrantType ->
coll.add(AUTHORIZATION_GRANT_TYPE_MAP.get(authGrantType))))
.clientAuthenticationMethods(coll -> clientAuthMethods.forEach(clientAuthMethod ->
coll.add(CLIENT_AUTHENTICATION_METHOD_MAP.get(clientAuthMethod))))
.redirectUris(coll -> coll.addAll(redirectUris))
.scopes(coll -> coll.addAll(scopes));
RegisteredClient rc = builder.build();
.authorizationGrantTypes((grantTypes) -> authGrantTypes.forEach(authGrantType ->
grantTypes.add(AUTHORIZATION_GRANT_TYPE_MAP.get(authGrantType))))
.clientAuthenticationMethods((authenticationMethods) -> clientAuthMethods.forEach(clientAuthMethod ->
authenticationMethods.add(CLIENT_AUTHENTICATION_METHOD_MAP.get(clientAuthMethod))))
.redirectUris((uris) -> uris.addAll(redirectUris))
.scopes((scopes) -> scopes.addAll(clientScopes));
RegisteredClient registeredClient = builder.build();
String tokenSettingsJson = rs.getString("token_settings");
if (tokenSettingsJson != null) {
Map<String, Object> settings = parseMap(tokenSettingsJson);
TokenSettings tokenSettings = registeredClient.getTokenSettings();
Number accessTokenTTL = (Number) settings.get("access_token_ttl");
if (accessTokenTTL != null) {
tokenSettings.accessTokenTimeToLive(Duration.ofMillis(accessTokenTTL.longValue()));
}
TokenSettings ts = rc.getTokenSettings();
ClientSettings cs = rc.getClientSettings();
Number refreshTokenTTL = (Number) settings.get("refresh_token_ttl");
if (refreshTokenTTL != null) {
tokenSettings.refreshTokenTimeToLive(Duration.ofMillis(refreshTokenTTL.longValue()));
}
try {
String tokenSettingsJson = rs.getString("token_settings");
if (tokenSettingsJson != null) {
Map<String, Object> m = this.objectMapper.readValue(tokenSettingsJson, Map.class);
Number accessTokenTTL = (Number) m.get("access_token_ttl");
if (accessTokenTTL != null) {
ts.accessTokenTimeToLive(Duration.ofMillis(accessTokenTTL.longValue()));
}
Number refreshTokenTTL = (Number) m.get("refresh_token_ttl");
if (refreshTokenTTL != null) {
ts.refreshTokenTimeToLive(Duration.ofMillis(refreshTokenTTL.longValue()));
}
Boolean reuseRefreshTokens = (Boolean) m.get("reuse_refresh_tokens");
if (reuseRefreshTokens != null) {
ts.reuseRefreshTokens(reuseRefreshTokens);
}
Boolean reuseRefreshTokens = (Boolean) settings.get("reuse_refresh_tokens");
if (reuseRefreshTokens != null) {
tokenSettings.reuseRefreshTokens(reuseRefreshTokens);
}
}
String clientSettingsJson = rs.getString("client_settings");
if (clientSettingsJson != null) {
Map<String, Object> m = this.objectMapper.readValue(clientSettingsJson, Map.class);
String clientSettingsJson = rs.getString("client_settings");
if (clientSettingsJson != null) {
Map<String, Object> settings = parseMap(clientSettingsJson);
ClientSettings clientSettings = registeredClient.getClientSettings();
Boolean requireProofKey = (Boolean) m.get("require_proof_key");
if (requireProofKey != null) {
cs.requireProofKey(requireProofKey);
}
Boolean requireProofKey = (Boolean) settings.get("require_proof_key");
if (requireProofKey != null) {
clientSettings.requireProofKey(requireProofKey);
}
Boolean requireUserConsent = (Boolean) m.get("require_user_consent");
if (requireUserConsent != null) {
cs.requireUserConsent(requireUserConsent);
}
Boolean requireUserConsent = (Boolean) settings.get("require_user_consent");
if (requireUserConsent != null) {
clientSettings.requireUserConsent(requireUserConsent);
}
} catch (JsonProcessingException e) {
throw new IllegalArgumentException(e.getMessage(), e);
}
return rc;
return registeredClient;
}
}
public final void setObjectMapper(ObjectMapper objectMapper) {
Assert.notNull(objectMapper, "objectMapper cannot be null");
this.objectMapper = objectMapper;
}
public static class DefaultRegisteredClientParametersMapper implements Function<RegisteredClient, List<SqlParameterValue>> {
protected final ObjectMapper getObjectMapper() {
return this.objectMapper;
}
private final ObjectMapper objectMapper;
private Set<String> parseList(String s) {
return s != null ? StringUtils.commaDelimitedListToSet(s) : Collections.emptySet();
}
private DefaultRegisteredClientParametersMapper(ObjectMapper objectMapper) {
this.objectMapper = objectMapper;
private Map<String, Object> parseMap(String data) {
try {
return this.objectMapper.readValue(data, new TypeReference<Map<String, Object>>() {});
} catch (Exception ex) {
throw new IllegalArgumentException(ex.getMessage(), ex);
}
}
}
public static class DefaultRegisteredClientParametersMapper implements Function<RegisteredClient, List<SqlParameterValue>> {
private ObjectMapper objectMapper = new ObjectMapper();
@Override
public List<SqlParameterValue> apply(RegisteredClient registeredClient) {
try {
List<String> clientAuthenticationMethodNames = new ArrayList<>(registeredClient.getClientAuthenticationMethods().size());
for (ClientAuthenticationMethod clientAuthenticationMethod : registeredClient.getClientAuthenticationMethods()) {
clientAuthenticationMethodNames.add(clientAuthenticationMethod.getValue());
}
List<String> clientAuthenticationMethodNames = new ArrayList<>(registeredClient.getClientAuthenticationMethods().size());
for (ClientAuthenticationMethod clientAuthenticationMethod : registeredClient.getClientAuthenticationMethods()) {
clientAuthenticationMethodNames.add(clientAuthenticationMethod.getValue());
}
List<String> authorizationGrantTypeNames = new ArrayList<>(registeredClient.getAuthorizationGrantTypes().size());
for (AuthorizationGrantType authorizationGrantType : registeredClient.getAuthorizationGrantTypes()) {
authorizationGrantTypeNames.add(authorizationGrantType.getValue());
}
List<String> authorizationGrantTypeNames = new ArrayList<>(registeredClient.getAuthorizationGrantTypes().size());
for (AuthorizationGrantType authorizationGrantType : registeredClient.getAuthorizationGrantTypes()) {
authorizationGrantTypeNames.add(authorizationGrantType.getValue());
}
Instant issuedAt = registeredClient.getClientIdIssuedAt() != null ?
registeredClient.getClientIdIssuedAt() : Instant.now();
Timestamp clientSecretExpiresAt = registeredClient.getClientSecretExpiresAt() != null ?
Timestamp.from(registeredClient.getClientSecretExpiresAt()) : null;
Map<String, Object> clientSettings = new HashMap<>();
clientSettings.put("require_proof_key", registeredClient.getClientSettings().requireProofKey());
clientSettings.put("require_user_consent", registeredClient.getClientSettings().requireUserConsent());
String clientSettingsJson = writeMap(clientSettings);
Map<String, Object> tokenSettings = new HashMap<>();
tokenSettings.put("access_token_ttl", registeredClient.getTokenSettings().accessTokenTimeToLive().toMillis());
tokenSettings.put("reuse_refresh_tokens", registeredClient.getTokenSettings().reuseRefreshTokens());
tokenSettings.put("refresh_token_ttl", registeredClient.getTokenSettings().refreshTokenTimeToLive().toMillis());
String tokenSettingsJson = writeMap(tokenSettings);
return Arrays.asList(
new SqlParameterValue(Types.VARCHAR, registeredClient.getId()),
new SqlParameterValue(Types.VARCHAR, registeredClient.getClientId()),
new SqlParameterValue(Types.TIMESTAMP, Timestamp.from(issuedAt)),
new SqlParameterValue(Types.VARCHAR, registeredClient.getClientSecret()),
new SqlParameterValue(Types.TIMESTAMP, clientSecretExpiresAt),
new SqlParameterValue(Types.VARCHAR, registeredClient.getClientName()),
new SqlParameterValue(Types.VARCHAR, StringUtils.collectionToCommaDelimitedString(clientAuthenticationMethodNames)),
new SqlParameterValue(Types.VARCHAR, StringUtils.collectionToCommaDelimitedString(authorizationGrantTypeNames)),
new SqlParameterValue(Types.VARCHAR, StringUtils.collectionToCommaDelimitedString(registeredClient.getRedirectUris())),
new SqlParameterValue(Types.VARCHAR, StringUtils.collectionToCommaDelimitedString(registeredClient.getScopes())),
new SqlParameterValue(Types.VARCHAR, clientSettingsJson),
new SqlParameterValue(Types.VARCHAR, tokenSettingsJson));
}
Instant issuedAt = registeredClient.getClientIdIssuedAt() != null ?
registeredClient.getClientIdIssuedAt() : Instant.now();
Timestamp clientSecretExpiresAt = registeredClient.getClientSecretExpiresAt() != null ?
Timestamp.from(registeredClient.getClientSecretExpiresAt()) : null;
Map<String, Object> clientSettings = new HashMap<>();
clientSettings.put("require_proof_key", registeredClient.getClientSettings().requireProofKey());
clientSettings.put("require_user_consent", registeredClient.getClientSettings().requireUserConsent());
String clientSettingsJson = this.objectMapper.writeValueAsString(clientSettings);
Map<String, Object> tokenSettings = new HashMap<>();
tokenSettings.put("access_token_ttl", registeredClient.getTokenSettings().accessTokenTimeToLive().toMillis());
tokenSettings.put("reuse_refresh_tokens", registeredClient.getTokenSettings().reuseRefreshTokens());
tokenSettings.put("refresh_token_ttl", registeredClient.getTokenSettings().refreshTokenTimeToLive().toMillis());
String tokenSettingsJson = this.objectMapper.writeValueAsString(tokenSettings);
return Arrays.asList(
new SqlParameterValue(Types.VARCHAR, registeredClient.getId()),
new SqlParameterValue(Types.VARCHAR, registeredClient.getClientId()),
new SqlParameterValue(Types.TIMESTAMP, Timestamp.from(issuedAt)),
new SqlParameterValue(Types.VARCHAR, registeredClient.getClientSecret()),
new SqlParameterValue(Types.TIMESTAMP, clientSecretExpiresAt),
new SqlParameterValue(Types.VARCHAR, registeredClient.getClientName()),
new SqlParameterValue(Types.VARCHAR, StringUtils.collectionToCommaDelimitedString(clientAuthenticationMethodNames)),
new SqlParameterValue(Types.VARCHAR, StringUtils.collectionToCommaDelimitedString(authorizationGrantTypeNames)),
new SqlParameterValue(Types.VARCHAR, StringUtils.collectionToCommaDelimitedString(registeredClient.getRedirectUris())),
new SqlParameterValue(Types.VARCHAR, StringUtils.collectionToCommaDelimitedString(registeredClient.getScopes())),
new SqlParameterValue(Types.VARCHAR, clientSettingsJson),
new SqlParameterValue(Types.VARCHAR, tokenSettingsJson));
} catch (JsonProcessingException e) {
throw new IllegalArgumentException(e.getMessage(), e);
public final void setObjectMapper(ObjectMapper objectMapper) {
Assert.notNull(objectMapper, "objectMapper cannot be null");
this.objectMapper = objectMapper;
}
protected final ObjectMapper getObjectMapper() {
return this.objectMapper;
}
private String writeMap(Map<String, Object> data) {
try {
return this.objectMapper.writeValueAsString(data);
} catch (Exception ex) {
throw new IllegalArgumentException(ex.getMessage(), ex);
}
}

12
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/client/JdbcRegisteredClientRepositoryTests.java

@ -20,7 +20,6 @@ import java.nio.charset.Charset; @@ -20,7 +20,6 @@ import java.nio.charset.Charset;
import java.time.Duration;
import java.time.Instant;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
@ -89,20 +88,11 @@ public class JdbcRegisteredClientRepositoryTests { @@ -89,20 +88,11 @@ public class JdbcRegisteredClientRepositoryTests {
public void whenJdbcOperationsNullThenThrow() {
// @formatter:off
assertThatIllegalArgumentException()
.isThrownBy(() -> new JdbcRegisteredClientRepository(null, new ObjectMapper()))
.isThrownBy(() -> new JdbcRegisteredClientRepository(null))
.withMessage("jdbcOperations cannot be null");
// @formatter:on
}
@Test
public void whenObjectMapperNullThenThrow() {
// @formatter:off
assertThatIllegalArgumentException()
.isThrownBy(() -> new JdbcRegisteredClientRepository(this.jdbc, null))
.withMessage("objectMapper cannot be null");
// @formatter:on
}
@Test
public void whenSetNullRegisteredClientRowMapperThenThrow() {
// @formatter:off

Loading…
Cancel
Save