Browse Source

Fix NPE saving public client

Closes gh-326
pull/331/head
Steve Riesenberg 5 years ago committed by Steve Riesenberg
parent
commit
67e62a2f21
  1. 80
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/client/JdbcRegisteredClientRepository.java
  2. 2
      oauth2-authorization-server/src/main/resources/org/springframework/security/oauth2/server/authorization/client/oauth2-registered-client.sql
  3. 37
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/client/JdbcRegisteredClientRepositoryTests.java

80
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/client/JdbcRegisteredClientRepository.java

@ -15,8 +15,6 @@
*/ */
package org.springframework.security.oauth2.server.authorization.client; package org.springframework.security.oauth2.server.authorization.client;
import java.nio.charset.StandardCharsets;
import java.sql.PreparedStatement;
import java.sql.ResultSet; import java.sql.ResultSet;
import java.sql.SQLException; import java.sql.SQLException;
import java.sql.Timestamp; import java.sql.Timestamp;
@ -40,9 +38,6 @@ import org.springframework.jdbc.core.JdbcOperations;
import org.springframework.jdbc.core.PreparedStatementSetter; import org.springframework.jdbc.core.PreparedStatementSetter;
import org.springframework.jdbc.core.RowMapper; import org.springframework.jdbc.core.RowMapper;
import org.springframework.jdbc.core.SqlParameterValue; import org.springframework.jdbc.core.SqlParameterValue;
import org.springframework.jdbc.support.lob.DefaultLobHandler;
import org.springframework.jdbc.support.lob.LobCreator;
import org.springframework.jdbc.support.lob.LobHandler;
import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.server.authorization.config.ClientSettings; import org.springframework.security.oauth2.server.authorization.config.ClientSettings;
@ -87,8 +82,6 @@ public class JdbcRegisteredClientRepository implements RegisteredClientRepositor
private final JdbcOperations jdbcOperations; private final JdbcOperations jdbcOperations;
private final LobHandler lobHandler;
/** /**
* Constructs a {@code JdbcRegisteredClientRepository} using the provided parameters. * Constructs a {@code JdbcRegisteredClientRepository} using the provided parameters.
* *
@ -105,25 +98,10 @@ public class JdbcRegisteredClientRepository implements RegisteredClientRepositor
* @param objectMapper the object mapper * @param objectMapper the object mapper
*/ */
public JdbcRegisteredClientRepository(JdbcOperations jdbcOperations, ObjectMapper objectMapper) { public JdbcRegisteredClientRepository(JdbcOperations jdbcOperations, ObjectMapper objectMapper) {
this(jdbcOperations, new DefaultLobHandler(), objectMapper);
}
/**
* Constructs a {@code JdbcRegisteredClientRepository} using the provided parameters.
*
* @param jdbcOperations the JDBC operations
* @param lobHandler the handler for large binary fields and large text fields
* @param objectMapper the object mapper
*/
public JdbcRegisteredClientRepository(JdbcOperations jdbcOperations, LobHandler lobHandler, ObjectMapper objectMapper) {
Assert.notNull(jdbcOperations, "jdbcOperations cannot be null"); Assert.notNull(jdbcOperations, "jdbcOperations cannot be null");
Assert.notNull(lobHandler, "lobHandler cannot be null");
Assert.notNull(objectMapper, "objectMapper cannot be null"); Assert.notNull(objectMapper, "objectMapper cannot be null");
this.jdbcOperations = jdbcOperations; this.jdbcOperations = jdbcOperations;
this.lobHandler = lobHandler; this.registeredClientRowMapper = new DefaultRegisteredClientRowMapper(objectMapper);
DefaultRegisteredClientRowMapper registeredClientRowMapper = new DefaultRegisteredClientRowMapper(objectMapper);
registeredClientRowMapper.setLobHandler(lobHandler);
this.registeredClientRowMapper = registeredClientRowMapper;
this.registeredClientParametersMapper = new DefaultRegisteredClientParametersMapper(objectMapper); this.registeredClientParametersMapper = new DefaultRegisteredClientParametersMapper(objectMapper);
} }
@ -150,25 +128,19 @@ public class JdbcRegisteredClientRepository implements RegisteredClientRepositor
@Override @Override
public void save(RegisteredClient registeredClient) { public void save(RegisteredClient registeredClient) {
Assert.notNull(registeredClient, "registeredClient cannot be null"); Assert.notNull(registeredClient, "registeredClient cannot be null");
RegisteredClient foundClient = this.findBy("id = ? OR client_id = ? OR client_secret = ?", RegisteredClient foundClient = this.findBy("id = ? OR client_id = ?",
registeredClient.getId(), registeredClient.getClientId(), registeredClient.getId(), registeredClient.getClientId());
registeredClient.getClientSecret().getBytes(StandardCharsets.UTF_8));
if (null != foundClient) { if (null != foundClient) {
Assert.isTrue(!foundClient.getId().equals(registeredClient.getId()), Assert.isTrue(!foundClient.getId().equals(registeredClient.getId()),
"Registered client must be unique. Found duplicate identifier: " + registeredClient.getId()); "Registered client must be unique. Found duplicate identifier: " + registeredClient.getId());
Assert.isTrue(!foundClient.getClientId().equals(registeredClient.getClientId()), Assert.isTrue(!foundClient.getClientId().equals(registeredClient.getClientId()),
"Registered client must be unique. Found duplicate client identifier: " + registeredClient.getClientId()); "Registered client must be unique. Found duplicate client identifier: " + registeredClient.getClientId());
Assert.isTrue(!foundClient.getClientSecret().equals(registeredClient.getClientSecret()),
"Registered client must be unique. Found duplicate client secret for identifier: " + registeredClient.getId());
} }
List<SqlParameterValue> parameters = this.registeredClientParametersMapper.apply(registeredClient); List<SqlParameterValue> parameters = this.registeredClientParametersMapper.apply(registeredClient);
PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters.toArray());
try (LobCreator lobCreator = this.lobHandler.getLobCreator()) { this.jdbcOperations.update(INSERT_REGISTERED_CLIENT_SQL, pss);
PreparedStatementSetter pss = new LobCreatorArgumentPreparedStatementSetter(lobCreator, parameters.toArray());
jdbcOperations.update(INSERT_REGISTERED_CLIENT_SQL, pss);
}
} }
@Override @Override
@ -184,7 +156,7 @@ public class JdbcRegisteredClientRepository implements RegisteredClientRepositor
} }
private RegisteredClient findBy(String condStr, Object...args) { private RegisteredClient findBy(String condStr, Object...args) {
List<RegisteredClient> lst = jdbcOperations.query( List<RegisteredClient> lst = this.jdbcOperations.query(
LOAD_REGISTERED_CLIENT_SQL + condStr, LOAD_REGISTERED_CLIENT_SQL + condStr,
registeredClientRowMapper, args); registeredClientRowMapper, args);
return !lst.isEmpty() ? lst.get(0) : null; return !lst.isEmpty() ? lst.get(0) : null;
@ -194,8 +166,6 @@ public class JdbcRegisteredClientRepository implements RegisteredClientRepositor
private final ObjectMapper objectMapper; private final ObjectMapper objectMapper;
private LobHandler lobHandler = new DefaultLobHandler();
public DefaultRegisteredClientRowMapper(ObjectMapper objectMapper) { public DefaultRegisteredClientRowMapper(ObjectMapper objectMapper) {
this.objectMapper = objectMapper; this.objectMapper = objectMapper;
} }
@ -213,8 +183,7 @@ public class JdbcRegisteredClientRepository implements RegisteredClientRepositor
Set<String> redirectUris = parseList(rs.getString("redirect_uris")); Set<String> redirectUris = parseList(rs.getString("redirect_uris"));
Timestamp clientIssuedAt = rs.getTimestamp("client_id_issued_at"); Timestamp clientIssuedAt = rs.getTimestamp("client_id_issued_at");
Timestamp clientSecretExpiresAt = rs.getTimestamp("client_secret_expires_at"); Timestamp clientSecretExpiresAt = rs.getTimestamp("client_secret_expires_at");
byte[] clientSecretBytes = this.lobHandler.getBlobAsBytes(rs, "client_secret"); String clientSecret = rs.getString("client_secret");
String clientSecret = clientSecretBytes != null ? new String(clientSecretBytes, StandardCharsets.UTF_8) : null;
RegisteredClient.Builder builder = RegisteredClient RegisteredClient.Builder builder = RegisteredClient
.withId(rs.getString("id")) .withId(rs.getString("id"))
.clientId(rs.getString("client_id")) .clientId(rs.getString("client_id"))
@ -276,11 +245,6 @@ public class JdbcRegisteredClientRepository implements RegisteredClientRepositor
return rc; return rc;
} }
public final void setLobHandler(LobHandler lobHandler) {
Assert.notNull(lobHandler, "lobHandler cannot be null");
this.lobHandler = lobHandler;
}
} }
public static class DefaultRegisteredClientParametersMapper implements Function<RegisteredClient, List<SqlParameterValue>> { public static class DefaultRegisteredClientParametersMapper implements Function<RegisteredClient, List<SqlParameterValue>> {
@ -325,7 +289,7 @@ public class JdbcRegisteredClientRepository implements RegisteredClientRepositor
new SqlParameterValue(Types.VARCHAR, registeredClient.getId()), new SqlParameterValue(Types.VARCHAR, registeredClient.getId()),
new SqlParameterValue(Types.VARCHAR, registeredClient.getClientId()), new SqlParameterValue(Types.VARCHAR, registeredClient.getClientId()),
new SqlParameterValue(Types.TIMESTAMP, Timestamp.from(issuedAt)), new SqlParameterValue(Types.TIMESTAMP, Timestamp.from(issuedAt)),
new SqlParameterValue(Types.BLOB, registeredClient.getClientSecret().getBytes(StandardCharsets.UTF_8)), new SqlParameterValue(Types.VARCHAR, registeredClient.getClientSecret()),
new SqlParameterValue(Types.TIMESTAMP, clientSecretExpiresAt), new SqlParameterValue(Types.TIMESTAMP, clientSecretExpiresAt),
new SqlParameterValue(Types.VARCHAR, registeredClient.getClientName()), new SqlParameterValue(Types.VARCHAR, registeredClient.getClientName()),
new SqlParameterValue(Types.VARCHAR, StringUtils.collectionToCommaDelimitedString(clientAuthenticationMethodNames)), new SqlParameterValue(Types.VARCHAR, StringUtils.collectionToCommaDelimitedString(clientAuthenticationMethodNames)),
@ -341,34 +305,6 @@ public class JdbcRegisteredClientRepository implements RegisteredClientRepositor
} }
private static final class LobCreatorArgumentPreparedStatementSetter extends ArgumentPreparedStatementSetter {
protected final LobCreator lobCreator;
private LobCreatorArgumentPreparedStatementSetter(LobCreator lobCreator, Object[] args) {
super(args);
this.lobCreator = lobCreator;
}
@Override
protected void doSetValue(PreparedStatement ps, int parameterPosition, Object argValue) throws SQLException {
if (argValue instanceof SqlParameterValue) {
SqlParameterValue paramValue = (SqlParameterValue) argValue;
if (paramValue.getSqlType() == Types.BLOB) {
if (paramValue.getValue() != null) {
Assert.isInstanceOf(byte[].class, paramValue.getValue(),
"Value of blob parameter must be byte[]");
}
byte[] valueBytes = (byte[]) paramValue.getValue();
this.lobCreator.setBlobAsBytes(ps, parameterPosition, valueBytes);
return;
}
}
super.doSetValue(ps, parameterPosition, argValue);
}
}
static { static {
Map<String, AuthorizationGrantType> am = new HashMap<>(); Map<String, AuthorizationGrantType> am = new HashMap<>();
for (AuthorizationGrantType a : Arrays.asList( for (AuthorizationGrantType a : Arrays.asList(

2
oauth2-authorization-server/src/main/resources/org/springframework/security/oauth2/server/authorization/client/oauth2-registered-client.sql

@ -2,7 +2,7 @@ CREATE TABLE oauth2_registered_client (
id varchar(100) NOT NULL, id varchar(100) NOT NULL,
client_id varchar(100) NOT NULL, client_id varchar(100) NOT NULL,
client_id_issued_at timestamp DEFAULT CURRENT_TIMESTAMP NOT NULL, client_id_issued_at timestamp DEFAULT CURRENT_TIMESTAMP NOT NULL,
client_secret blob NOT NULL, client_secret varchar(200) DEFAULT NULL,
client_secret_expires_at timestamp DEFAULT NULL, client_secret_expires_at timestamp DEFAULT NULL,
client_name varchar(200), client_name varchar(200),
client_authentication_methods varchar(1000) NOT NULL, client_authentication_methods varchar(1000) NOT NULL,

37
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/client/JdbcRegisteredClientRepositoryTests.java

@ -103,15 +103,6 @@ public class JdbcRegisteredClientRepositoryTests {
// @formatter:on // @formatter:on
} }
@Test
public void whenLobHandlerNullThenThrow() {
// @formatter:off
assertThatIllegalArgumentException()
.isThrownBy(() -> new JdbcRegisteredClientRepository(this.jdbc, null, new ObjectMapper()))
.withMessage("lobHandler cannot be null");
// @formatter:on
}
@Test @Test
public void whenSetNullRegisteredClientRowMapperThenThrow() { public void whenSetNullRegisteredClientRowMapperThenThrow() {
// @formatter:off // @formatter:off
@ -198,12 +189,12 @@ public class JdbcRegisteredClientRepositoryTests {
} }
@Test @Test
public void saveWhenExistingClientSecretThenThrowIllegalArgumentException() { public void saveWhenExistingClientSecretThenSuccess() {
RegisteredClient registeredClient = createRegisteredClient( RegisteredClient registeredClient = createRegisteredClient(
"client-2", "client-id-2", this.registration.getClientSecret()); "client-2", "client-id-2", this.registration.getClientSecret());
assertThatIllegalArgumentException() this.clients.save(registeredClient);
.isThrownBy(() -> this.clients.save(registeredClient)) RegisteredClient savedClient = this.clients.findById(registeredClient.getId());
.withMessage("Registered client must be unique. Found duplicate client secret for identifier: " + registeredClient.getId()); assertRegisteredClientIsEqualTo(savedClient, registeredClient);
} }
@Test @Test
@ -222,6 +213,26 @@ public class JdbcRegisteredClientRepositoryTests {
assertRegisteredClientIsEqualTo(savedClient, registeredClient); assertRegisteredClientIsEqualTo(savedClient, registeredClient);
} }
@Test
public void saveWhenPublicClientSavedAndFindByClientIdThenFound() {
RegisteredClient registeredClient = TestRegisteredClients.registeredPublicClient().build();
this.clients.save(registeredClient);
RegisteredClient savedClient = this.clients.findByClientId(registeredClient.getClientId());
assertRegisteredClientIsEqualTo(savedClient, registeredClient);
}
@Test
public void saveWhenMultiplePublicClientsSavedAndFindByIdThenFound() {
RegisteredClient registeredClient1 = TestRegisteredClients.registeredPublicClient()
.id("1").clientId("a").build();
RegisteredClient registeredClient2 = TestRegisteredClients.registeredPublicClient()
.id("2").clientId("b").build();
this.clients.save(registeredClient1);
this.clients.save(registeredClient2);
RegisteredClient savedClient = this.clients.findByClientId(registeredClient2.getClientId());
assertRegisteredClientIsEqualTo(savedClient, registeredClient2);
}
@Test @Test
public void whenSaveRegistrationWithAllAttrsThenSaved() { public void whenSaveRegistrationWithAllAttrsThenSaved() {
Instant issuedAt = Instant.now(), expiresAt = issuedAt.plus(Duration.ofDays(30)); Instant issuedAt = Instant.now(), expiresAt = issuedAt.plus(Duration.ofDays(30));

Loading…
Cancel
Save