From 99fb4c8a5f0cf519b3400e7faa5a9ada6d8519b1 Mon Sep 17 00:00:00 2001 From: Steve Riesenberg Date: Wed, 30 Jun 2021 12:08:35 -0500 Subject: [PATCH] Add test to override schema for JdbcOAuth2AuthorizationService --- .../JdbcOAuth2AuthorizationServiceTests.java | 319 ++++++++++++++++++ ...om-oauth2-authorization-consent-schema.sql | 6 + 2 files changed, 325 insertions(+) create mode 100644 oauth2-authorization-server/src/test/resources/org/springframework/security/oauth2/server/authorization/custom-oauth2-authorization-consent-schema.sql diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/JdbcOAuth2AuthorizationServiceTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/JdbcOAuth2AuthorizationServiceTests.java index b2927834..202afb0f 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/JdbcOAuth2AuthorizationServiceTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/JdbcOAuth2AuthorizationServiceTests.java @@ -15,32 +15,48 @@ */ package org.springframework.security.oauth2.server.authorization; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Timestamp; +import java.sql.Types; import java.time.Instant; import java.time.temporal.ChronoUnit; +import java.util.ArrayList; +import java.util.Collections; import java.util.List; +import java.util.Map; +import java.util.Set; import java.util.function.Function; +import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; import org.junit.After; import org.junit.Before; import org.junit.Test; +import org.springframework.dao.DataRetrievalFailureException; +import org.springframework.jdbc.core.ArgumentPreparedStatementSetter; import org.springframework.jdbc.core.JdbcOperations; import org.springframework.jdbc.core.JdbcTemplate; +import org.springframework.jdbc.core.PreparedStatementSetter; import org.springframework.jdbc.core.RowMapper; import org.springframework.jdbc.core.SqlParameterValue; import org.springframework.jdbc.datasource.embedded.EmbeddedDatabase; import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder; import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseType; +import org.springframework.security.oauth2.core.AbstractOAuth2Token; import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2RefreshToken; import org.springframework.security.oauth2.core.OAuth2RefreshToken2; import org.springframework.security.oauth2.core.OAuth2TokenType; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.core.oidc.OidcIdToken; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; +import org.springframework.util.CollectionUtils; +import org.springframework.util.StringUtils; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -59,6 +75,7 @@ import static org.mockito.Mockito.when; */ public class JdbcOAuth2AuthorizationServiceTests { private static final String OAUTH2_AUTHORIZATION_SCHEMA_SQL_RESOURCE = "org/springframework/security/oauth2/server/authorization/oauth2-authorization-schema.sql"; + private static final String CUSTOM_OAUTH2_AUTHORIZATION_SCHEMA_SQL_RESOURCE = "org/springframework/security/oauth2/server/authorization/custom-oauth2-authorization-schema.sql"; private static final OAuth2TokenType AUTHORIZATION_CODE_TOKEN_TYPE = new OAuth2TokenType(OAuth2ParameterNames.CODE); private static final OAuth2TokenType STATE_TOKEN_TYPE = new OAuth2TokenType(OAuth2ParameterNames.STATE); private static final String ID = "id"; @@ -374,6 +391,30 @@ public class JdbcOAuth2AuthorizationServiceTests { assertThat(result).isNull(); } + @Test + public void tableDefinitionWhenCustomThenAbleToOverride() { + when(this.registeredClientRepository.findById(eq(REGISTERED_CLIENT.getId()))) + .thenReturn(REGISTERED_CLIENT); + + EmbeddedDatabase db = createDb(CUSTOM_OAUTH2_AUTHORIZATION_SCHEMA_SQL_RESOURCE); + OAuth2AuthorizationService authorizationService = + new CustomJdbcOAuth2AuthorizationService(new JdbcTemplate(db), this.registeredClientRepository); + String state = "state"; + OAuth2Authorization originalAuthorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT) + .id(ID) + .principalName(PRINCIPAL_NAME) + .authorizationGrantType(AUTHORIZATION_GRANT_TYPE) + .attribute(OAuth2ParameterNames.STATE, state) + .token(AUTHORIZATION_CODE) + .build(); + authorizationService.save(originalAuthorization); + OAuth2Authorization foundAuthorization1 = authorizationService.findById(originalAuthorization.getId()); + assertThat(foundAuthorization1).isEqualTo(originalAuthorization); + OAuth2Authorization foundAuthorization2 = authorizationService.findByToken(state, STATE_TOKEN_TYPE); + assertThat(foundAuthorization2).isEqualTo(originalAuthorization); + db.shutdown(); + } + private static EmbeddedDatabase createDb() { return createDb(OAUTH2_AUTHORIZATION_SCHEMA_SQL_RESOURCE); } @@ -388,4 +429,282 @@ public class JdbcOAuth2AuthorizationServiceTests { .build(); // @formatter:on } + + private static final class CustomJdbcOAuth2AuthorizationService extends JdbcOAuth2AuthorizationService { + + // @formatter:off + private static final String COLUMN_NAMES = "id, " + + "registeredClientId, " + + "principalName, " + + "authorizationGrantType, " + + "attributes, " + + "state, " + + "authorizationCodeValue, " + + "authorizationCodeIssuedAt, " + + "authorizationCodeExpiresAt," + + "authorizationCodeMetadata," + + "accessTokenValue," + + "accessTokenIssuedAt," + + "accessTokenExpiresAt," + + "accessTokenMetadata," + + "accessTokenType," + + "accessTokenScopes," + + "oidcIdTokenValue," + + "oidcIdTokenIssuedAt," + + "oidcIdTokenExpiresAt," + + "oidcIdTokenMetadata," + + "refreshTokenValue," + + "refreshTokenIssuedAt," + + "refreshTokenExpiresAt," + + "refreshTokenMetadata"; + // @formatter:on + + private static final String TABLE_NAME = "oauth2Authorization"; + + private static final String PK_FILTER = "id = ?"; + private static final String UNKNOWN_TOKEN_TYPE_FILTER = "state = ? OR authorizationCodeValue = ? OR " + + "accessTokenValue = ? OR " + + "refreshTokenValue = ?"; + + // @formatter:off + private static final String LOAD_AUTHORIZATION_SQL = "SELECT " + COLUMN_NAMES + + " FROM " + TABLE_NAME + + " WHERE "; + // @formatter:on + + // @formatter:off + private static final String SAVE_AUTHORIZATION_SQL = "INSERT INTO " + TABLE_NAME + + " (" + COLUMN_NAMES + ") VALUES (?, ?, ?, ?, ?, ?, ?, ?,?, ?, ?, ?, ?, ?, ?, ?,?, ?, ?, ?, ?, ?, ?, ?)"; + // @formatter:on + + private static final String REMOVE_AUTHORIZATION_SQL = "DELETE FROM " + TABLE_NAME + " WHERE " + PK_FILTER; + + CustomJdbcOAuth2AuthorizationService(JdbcOperations jdbcOperations, + RegisteredClientRepository registeredClientRepository) { + super(jdbcOperations, registeredClientRepository); + setAuthorizationRowMapper(new CustomOAuth2AuthorizationRowMapper(registeredClientRepository)); + setAuthorizationParametersMapper(new CustomOAuth2AuthorizationParametersMapper()); + } + + @Override + public void save(OAuth2Authorization authorization) { + List parameters = getAuthorizationParametersMapper().apply(authorization); + PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters.toArray()); + getJdbcOperations().update(SAVE_AUTHORIZATION_SQL, pss); + } + + @Override + public void remove(OAuth2Authorization authorization) { + SqlParameterValue[] parameters = new SqlParameterValue[] { + new SqlParameterValue(Types.VARCHAR, authorization.getId()) + }; + PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters); + getJdbcOperations().update(REMOVE_AUTHORIZATION_SQL, pss); + } + + @Override + public OAuth2Authorization findById(String id) { + return findBy(PK_FILTER, id); + } + + @Override + public OAuth2Authorization findByToken(String token, OAuth2TokenType tokenType) { + return findBy(UNKNOWN_TOKEN_TYPE_FILTER, token, token, token, token); + } + + private OAuth2Authorization findBy(String filter, Object... args) { + List result = getJdbcOperations() + .query(LOAD_AUTHORIZATION_SQL + filter, getAuthorizationRowMapper(), args); + return !result.isEmpty() ? result.get(0) : null; + } + + private static final class CustomOAuth2AuthorizationRowMapper extends JdbcOAuth2AuthorizationService.OAuth2AuthorizationRowMapper { + + CustomOAuth2AuthorizationRowMapper(RegisteredClientRepository registeredClientRepository) { + super(registeredClientRepository); + } + + @Override + @SuppressWarnings("unchecked") + public OAuth2Authorization mapRow(ResultSet rs, int rowNum) throws SQLException { + String registeredClientId = rs.getString("registeredClientId"); + RegisteredClient registeredClient = getRegisteredClientRepository().findById(registeredClientId); + if (registeredClient == null) { + throw new DataRetrievalFailureException( + "The RegisteredClient with id '" + registeredClientId + "' was not found in the RegisteredClientRepository."); + } + + OAuth2Authorization.Builder builder = OAuth2Authorization.withRegisteredClient(registeredClient); + String id = rs.getString("id"); + String principalName = rs.getString("principalName"); + String authorizationGrantType = rs.getString("authorizationGrantType"); + Map attributes = parseMap(rs.getString("attributes")); + + builder.id(id) + .principalName(principalName) + .authorizationGrantType(new AuthorizationGrantType(authorizationGrantType)) + .attributes((attrs) -> attrs.putAll(attributes)); + + String state = rs.getString("state"); + if (StringUtils.hasText(state)) { + builder.attribute(OAuth2ParameterNames.STATE, state); + } + + String tokenValue = rs.getString("authorizationCodeValue"); + Instant tokenIssuedAt; + Instant tokenExpiresAt; + if (tokenValue != null) { + tokenIssuedAt = rs.getTimestamp("authorizationCodeIssuedAt").toInstant(); + tokenExpiresAt = rs.getTimestamp("authorizationCodeExpiresAt").toInstant(); + Map authorizationCodeMetadata = parseMap(rs.getString("authorizationCodeMetadata")); + + OAuth2AuthorizationCode authorizationCode = new OAuth2AuthorizationCode( + tokenValue, tokenIssuedAt, tokenExpiresAt); + builder.token(authorizationCode, (metadata) -> metadata.putAll(authorizationCodeMetadata)); + } + + tokenValue = rs.getString("accessTokenValue"); + if (tokenValue != null) { + tokenIssuedAt = rs.getTimestamp("accessTokenIssuedAt").toInstant(); + tokenExpiresAt = rs.getTimestamp("accessTokenExpiresAt").toInstant(); + Map accessTokenMetadata = parseMap(rs.getString("accessTokenMetadata")); + OAuth2AccessToken.TokenType tokenType = null; + if (OAuth2AccessToken.TokenType.BEARER.getValue().equalsIgnoreCase(rs.getString("accessTokenType"))) { + tokenType = OAuth2AccessToken.TokenType.BEARER; + } + + Set scopes = Collections.emptySet(); + String accessTokenScopes = rs.getString("accessTokenScopes"); + if (accessTokenScopes != null) { + scopes = StringUtils.commaDelimitedListToSet(accessTokenScopes); + } + OAuth2AccessToken accessToken = new OAuth2AccessToken(tokenType, tokenValue, tokenIssuedAt, tokenExpiresAt, scopes); + builder.token(accessToken, (metadata) -> metadata.putAll(accessTokenMetadata)); + } + + tokenValue = rs.getString("oidcIdTokenValue"); + if (tokenValue != null) { + tokenIssuedAt = rs.getTimestamp("oidcIdTokenIssuedAt").toInstant(); + tokenExpiresAt = rs.getTimestamp("oidcIdTokenExpiresAt").toInstant(); + Map oidcTokenMetadata = parseMap(rs.getString("oidcIdTokenMetadata")); + + OidcIdToken oidcToken = new OidcIdToken( + tokenValue, tokenIssuedAt, tokenExpiresAt, (Map) oidcTokenMetadata.get(OAuth2Authorization.Token.CLAIMS_METADATA_NAME)); + builder.token(oidcToken, (metadata) -> metadata.putAll(oidcTokenMetadata)); + } + + tokenValue = rs.getString("refreshTokenValue"); + if (tokenValue != null) { + tokenIssuedAt = rs.getTimestamp("refreshTokenIssuedAt").toInstant(); + tokenExpiresAt = null; + Timestamp refreshTokenExpiresAt = rs.getTimestamp("refreshTokenExpiresAt"); + if (refreshTokenExpiresAt != null) { + tokenExpiresAt = refreshTokenExpiresAt.toInstant(); + } + Map refreshTokenMetadata = parseMap(rs.getString("refreshTokenMetadata")); + + OAuth2RefreshToken refreshToken = new OAuth2RefreshToken2( + tokenValue, tokenIssuedAt, tokenExpiresAt); + builder.token(refreshToken, (metadata) -> metadata.putAll(refreshTokenMetadata)); + } + + return builder.build(); + } + + private Map parseMap(String data) { + try { + return getObjectMapper().readValue(data, new TypeReference>() {}); + } catch (Exception ex) { + throw new IllegalArgumentException(ex.getMessage(), ex); + } + } + + } + + private static final class CustomOAuth2AuthorizationParametersMapper extends JdbcOAuth2AuthorizationService.OAuth2AuthorizationParametersMapper { + + @Override + public List apply(OAuth2Authorization authorization) { + List parameters = new ArrayList<>(); + parameters.add(new SqlParameterValue(Types.VARCHAR, authorization.getId())); + parameters.add(new SqlParameterValue(Types.VARCHAR, authorization.getRegisteredClientId())); + parameters.add(new SqlParameterValue(Types.VARCHAR, authorization.getPrincipalName())); + parameters.add(new SqlParameterValue(Types.VARCHAR, authorization.getAuthorizationGrantType().getValue())); + + String attributes = writeMap(authorization.getAttributes()); + parameters.add(new SqlParameterValue(Types.VARCHAR, attributes)); + + String state = null; + String authorizationState = authorization.getAttribute(OAuth2ParameterNames.STATE); + if (StringUtils.hasText(authorizationState)) { + state = authorizationState; + } + parameters.add(new SqlParameterValue(Types.VARCHAR, state)); + + OAuth2Authorization.Token authorizationCode = + authorization.getToken(OAuth2AuthorizationCode.class); + List authorizationCodeSqlParameters = toSqlParameterList(authorizationCode); + parameters.addAll(authorizationCodeSqlParameters); + + OAuth2Authorization.Token accessToken = + authorization.getToken(OAuth2AccessToken.class); + List accessTokenSqlParameters = toSqlParameterList(accessToken); + parameters.addAll(accessTokenSqlParameters); + String accessTokenType = null; + String accessTokenScopes = null; + if (accessToken != null) { + accessTokenType = accessToken.getToken().getTokenType().getValue(); + if (!CollectionUtils.isEmpty(accessToken.getToken().getScopes())) { + accessTokenScopes = StringUtils.collectionToDelimitedString(accessToken.getToken().getScopes(), ","); + } + } + parameters.add(new SqlParameterValue(Types.VARCHAR, accessTokenType)); + parameters.add(new SqlParameterValue(Types.VARCHAR, accessTokenScopes)); + + OAuth2Authorization.Token oidcIdToken = authorization.getToken(OidcIdToken.class); + List oidcIdTokenSqlParameters = toSqlParameterList(oidcIdToken); + parameters.addAll(oidcIdTokenSqlParameters); + + OAuth2Authorization.Token refreshToken = authorization.getRefreshToken(); + List refreshTokenSqlParameters = toSqlParameterList(refreshToken); + parameters.addAll(refreshTokenSqlParameters); + return parameters; + } + + private List toSqlParameterList(OAuth2Authorization.Token token) { + List parameters = new ArrayList<>(); + String tokenValue = null; + Timestamp tokenIssuedAt = null; + Timestamp tokenExpiresAt = null; + String metadata = null; + if (token != null) { + tokenValue = token.getToken().getTokenValue(); + if (token.getToken().getIssuedAt() != null) { + tokenIssuedAt = Timestamp.from(token.getToken().getIssuedAt()); + } + + if (token.getToken().getExpiresAt() != null) { + tokenExpiresAt = Timestamp.from(token.getToken().getExpiresAt()); + } + metadata = writeMap(token.getMetadata()); + } + parameters.add(new SqlParameterValue(Types.VARCHAR, tokenValue)); + parameters.add(new SqlParameterValue(Types.TIMESTAMP, tokenIssuedAt)); + parameters.add(new SqlParameterValue(Types.TIMESTAMP, tokenExpiresAt)); + parameters.add(new SqlParameterValue(Types.VARCHAR, metadata)); + return parameters; + } + + private String writeMap(Map data) { + try { + return getObjectMapper().writeValueAsString(data); + } catch (Exception ex) { + throw new IllegalArgumentException(ex.getMessage(), ex); + } + } + + } + + } + } diff --git a/oauth2-authorization-server/src/test/resources/org/springframework/security/oauth2/server/authorization/custom-oauth2-authorization-consent-schema.sql b/oauth2-authorization-server/src/test/resources/org/springframework/security/oauth2/server/authorization/custom-oauth2-authorization-consent-schema.sql new file mode 100644 index 00000000..3020828a --- /dev/null +++ b/oauth2-authorization-server/src/test/resources/org/springframework/security/oauth2/server/authorization/custom-oauth2-authorization-consent-schema.sql @@ -0,0 +1,6 @@ +CREATE TABLE oauth2_authorization_consent ( + registered_client_id varchar(100) NOT NULL, + principal_name varchar(200) NOT NULL, + authorities varchar(1000) NOT NULL, + PRIMARY KEY (registered_client_id, principal_name) +);