@ -16,6 +16,7 @@
@@ -16,6 +16,7 @@
package org.springframework.security.oauth2.server.authorization ;
import java.nio.charset.StandardCharsets ;
import java.sql.DatabaseMetaData ;
import java.sql.PreparedStatement ;
import java.sql.ResultSet ;
import java.sql.SQLException ;
@ -35,6 +36,7 @@ import com.fasterxml.jackson.databind.ObjectMapper;
@@ -35,6 +36,7 @@ import com.fasterxml.jackson.databind.ObjectMapper;
import org.springframework.dao.DataRetrievalFailureException ;
import org.springframework.jdbc.core.ArgumentPreparedStatementSetter ;
import org.springframework.jdbc.core.ConnectionCallback ;
import org.springframework.jdbc.core.JdbcOperations ;
import org.springframework.jdbc.core.PreparedStatementSetter ;
import org.springframework.jdbc.core.RowMapper ;
@ -141,6 +143,7 @@ public class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationServic
@@ -141,6 +143,7 @@ public class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationServic
private final JdbcOperations jdbcOperations ;
private final LobHandler lobHandler ;
private static int tokenColumnType ;
private RowMapper < OAuth2Authorization > authorizationRowMapper ;
private Function < OAuth2Authorization , List < SqlParameterValue > > authorizationParametersMapper ;
@ -169,12 +172,15 @@ public class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationServic
@@ -169,12 +172,15 @@ public class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationServic
Assert . notNull ( lobHandler , "lobHandler cannot be null" ) ;
this . jdbcOperations = jdbcOperations ;
this . lobHandler = lobHandler ;
tokenColumnType = getColumnDataType ( jdbcOperations , "access_token_value" ) ;
OAuth2AuthorizationRowMapper authorizationRowMapper = new OAuth2AuthorizationRowMapper ( registeredClientRepository ) ;
authorizationRowMapper . setLobHandler ( lobHandler ) ;
this . authorizationRowMapper = authorizationRowMapper ;
this . authorizationParametersMapper = new OAuth2AuthorizationParametersMapper ( ) ;
OAuth2AuthorizationParametersMapper authorizationParametersMapper = new OAuth2AuthorizationParametersMapper ( ) ;
this . authorizationParametersMapper = authorizationParametersMapper ;
}
@Override
public void save ( OAuth2Authorization authorization ) {
Assert . notNull ( authorization , "authorization cannot be null" ) ;
@ -232,26 +238,33 @@ public class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationServic
@@ -232,26 +238,33 @@ public class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationServic
List < SqlParameterValue > parameters = new ArrayList < > ( ) ;
if ( tokenType = = null ) {
parameters . add ( new SqlParameterValue ( Types . VARCHAR , token ) ) ;
parameters . add ( new SqlParameterValue ( Types . BLOB , token . getBytes ( StandardCharsets . UTF_8 ) ) ) ;
parameters . add ( new SqlParameterValue ( Types . BLOB , token . getBytes ( StandardCharsets . UTF_8 ) ) ) ;
parameters . add ( new SqlParameterValue ( Types . BLOB , token . getBytes ( StandardCharsets . UTF_8 ) ) ) ;
parameters . add ( mapTokenToSqlParameter ( token ) ) ;
parameters . add ( mapTokenToSqlParameter ( token ) ) ;
parameters . add ( mapTokenToSqlParameter ( token ) ) ;
return findBy ( UNKNOWN_TOKEN_TYPE_FILTER , parameters ) ;
} else if ( OAuth2ParameterNames . STATE . equals ( tokenType . getValue ( ) ) ) {
parameters . add ( new SqlParameterValue ( Types . VARCHAR , token ) ) ;
return findBy ( STATE_FILTER , parameters ) ;
} else if ( OAuth2ParameterNames . CODE . equals ( tokenType . getValue ( ) ) ) {
parameters . add ( new SqlParameterValue ( Types . BLOB , token . getBytes ( StandardCharsets . UTF_8 ) ) ) ;
parameters . add ( mapTokenToSqlParameter ( token ) ) ;
return findBy ( AUTHORIZATION_CODE_FILTER , parameters ) ;
} else if ( OAuth2TokenType . ACCESS_TOKEN . equals ( tokenType ) ) {
parameters . add ( new SqlParameterValue ( Types . BLOB , token . getBytes ( StandardCharsets . UTF_8 ) ) ) ;
parameters . add ( mapTokenToSqlParameter ( token ) ) ;
return findBy ( ACCESS_TOKEN_FILTER , parameters ) ;
} else if ( OAuth2TokenType . REFRESH_TOKEN . equals ( tokenType ) ) {
parameters . add ( new SqlParameterValue ( Types . BLOB , token . getBytes ( StandardCharsets . UTF_8 ) ) ) ;
parameters . add ( mapTokenToSqlParameter ( token ) ) ;
return findBy ( REFRESH_TOKEN_FILTER , parameters ) ;
}
return null ;
}
private SqlParameterValue mapTokenToSqlParameter ( String token ) {
if ( Types . BLOB = = tokenColumnType ) {
return new SqlParameterValue ( Types . BLOB , token . getBytes ( StandardCharsets . UTF_8 ) ) ;
}
return new SqlParameterValue ( tokenColumnType , token ) ;
}
private OAuth2Authorization findBy ( String filter , List < SqlParameterValue > parameters ) {
try ( LobCreator lobCreator = getLobHandler ( ) . getLobCreator ( ) ) {
PreparedStatementSetter pss = new LobCreatorArgumentPreparedStatementSetter ( lobCreator ,
@ -349,25 +362,22 @@ public class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationServic
@@ -349,25 +362,22 @@ public class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationServic
builder . attribute ( OAuth2ParameterNames . STATE , state ) ;
}
String tokenValue ;
Instant tokenIssuedAt ;
Instant tokenExpiresAt ;
byte [ ] authorizationCodeValue = this . lobHandler . getBlobAsBytes ( rs , "authorization_code_value" ) ;
String authorizationCodeValue = getTokenValue ( rs , "authorization_code_value" ) ;
if ( authorizationCodeValue ! = null ) {
tokenValue = new String ( authorizationCodeValue , StandardCharsets . UTF_8 ) ;
if ( StringUtils . hasText ( authorizationCodeValue ) ) {
tokenIssuedAt = rs . getTimestamp ( "authorization_code_issued_at" ) . toInstant ( ) ;
tokenExpiresAt = rs . getTimestamp ( "authorization_code_expires_at" ) . toInstant ( ) ;
Map < String , Object > authorizationCodeMetadata = parseMap ( rs . getString ( "authorization_code_metadata" ) ) ;
OAuth2AuthorizationCode authorizationCode = new OAuth2AuthorizationCode (
token Value, tokenIssuedAt , tokenExpiresAt ) ;
authorizationCode Value, tokenIssuedAt , tokenExpiresAt ) ;
builder . token ( authorizationCode , ( metadata ) - > metadata . putAll ( authorizationCodeMetadata ) ) ;
}
byte [ ] accessTokenValue = this . lobHandler . getBlobAsBytes ( rs , "access_token_value" ) ;
if ( accessTokenValue ! = null ) {
tokenValue = new String ( accessTokenValue , StandardCharsets . UTF_8 ) ;
String accessTokenValue = getTokenValue ( rs , "access_token_value" ) ;
if ( StringUtils . hasText ( accessTokenValue ) ) {
tokenIssuedAt = rs . getTimestamp ( "access_token_issued_at" ) . toInstant ( ) ;
tokenExpiresAt = rs . getTimestamp ( "access_token_expires_at" ) . toInstant ( ) ;
Map < String , Object > accessTokenMetadata = parseMap ( rs . getString ( "access_token_metadata" ) ) ;
@ -381,25 +391,23 @@ public class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationServic
@@ -381,25 +391,23 @@ public class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationServic
if ( accessTokenScopes ! = null ) {
scopes = StringUtils . commaDelimitedListToSet ( accessTokenScopes ) ;
}
OAuth2AccessToken accessToken = new OAuth2AccessToken ( tokenType , t okenValue, tokenIssuedAt , tokenExpiresAt , scopes ) ;
OAuth2AccessToken accessToken = new OAuth2AccessToken ( tokenType , accessT okenValue, tokenIssuedAt , tokenExpiresAt , scopes ) ;
builder . token ( accessToken , ( metadata ) - > metadata . putAll ( accessTokenMetadata ) ) ;
}
byte [ ] oidcIdTokenValue = this . lobHandler . getBlobAsBytes ( rs , "oidc_id_token_value" ) ;
if ( oidcIdTokenValue ! = null ) {
tokenValue = new String ( oidcIdTokenValue , StandardCharsets . UTF_8 ) ;
String oidcIdTokenValue = getTokenValue ( rs , "oidc_id_token_value" ) ;
if ( StringUtils . hasText ( oidcIdTokenValue ) ) {
tokenIssuedAt = rs . getTimestamp ( "oidc_id_token_issued_at" ) . toInstant ( ) ;
tokenExpiresAt = rs . getTimestamp ( "oidc_id_token_expires_at" ) . toInstant ( ) ;
Map < String , Object > oidcTokenMetadata = parseMap ( rs . getString ( "oidc_id_token_metadata" ) ) ;
OidcIdToken oidcToken = new OidcIdToken (
t okenValue, tokenIssuedAt , tokenExpiresAt , ( Map < String , Object > ) oidcTokenMetadata . get ( OAuth2Authorization . Token . CLAIMS_METADATA_NAME ) ) ;
oidcIdT okenValue, tokenIssuedAt , tokenExpiresAt , ( Map < String , Object > ) oidcTokenMetadata . get ( OAuth2Authorization . Token . CLAIMS_METADATA_NAME ) ) ;
builder . token ( oidcToken , ( metadata ) - > metadata . putAll ( oidcTokenMetadata ) ) ;
}
byte [ ] refreshTokenValue = this . lobHandler . getBlobAsBytes ( rs , "refresh_token_value" ) ;
if ( refreshTokenValue ! = null ) {
tokenValue = new String ( refreshTokenValue , StandardCharsets . UTF_8 ) ;
String refreshTokenValue = getTokenValue ( rs , "refresh_token_value" ) ;
if ( StringUtils . hasText ( refreshTokenValue ) ) {
tokenIssuedAt = rs . getTimestamp ( "refresh_token_issued_at" ) . toInstant ( ) ;
tokenExpiresAt = null ;
Timestamp refreshTokenExpiresAt = rs . getTimestamp ( "refresh_token_expires_at" ) ;
@ -409,12 +417,29 @@ public class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationServic
@@ -409,12 +417,29 @@ public class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationServic
Map < String , Object > refreshTokenMetadata = parseMap ( rs . getString ( "refresh_token_metadata" ) ) ;
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken (
t okenValue, tokenIssuedAt , tokenExpiresAt ) ;
refreshT okenValue, tokenIssuedAt , tokenExpiresAt ) ;
builder . token ( refreshToken , ( metadata ) - > metadata . putAll ( refreshTokenMetadata ) ) ;
}
return builder . build ( ) ;
}
private String getTokenValue ( ResultSet rs , String tokenColumn ) throws SQLException {
String tokenValue = null ;
if ( Types . CLOB = = tokenColumnType ) {
tokenValue = this . lobHandler . getClobAsString ( rs , tokenColumn ) ;
}
if ( Types . VARCHAR = = tokenColumnType ) {
tokenValue = rs . getString ( tokenColumn ) ;
}
if ( Types . BLOB = = tokenColumnType ) {
byte [ ] tokenValueByte = this . lobHandler . getBlobAsBytes ( rs , tokenColumn ) ;
if ( tokenValueByte ! = null ) {
tokenValue = new String ( tokenValueByte , StandardCharsets . UTF_8 ) ;
}
}
return tokenValue ;
}
public final void setLobHandler ( LobHandler lobHandler ) {
Assert . notNull ( lobHandler , "lobHandler cannot be null" ) ;
this . lobHandler = lobHandler ;
@ -520,12 +545,12 @@ public class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationServic
@@ -520,12 +545,12 @@ public class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationServic
private < T extends AbstractOAuth2Token > List < SqlParameterValue > toSqlParameterList ( OAuth2Authorization . Token < T > token ) {
List < SqlParameterValue > parameters = new ArrayList < > ( ) ;
byte [ ] tokenValue = null ;
String tokenValue = null ;
Timestamp tokenIssuedAt = null ;
Timestamp tokenExpiresAt = null ;
String metadata = null ;
if ( token ! = null ) {
tokenValue = token . getToken ( ) . getTokenValue ( ) . getBytes ( StandardCharsets . UTF_8 ) ;
tokenValue = token . getToken ( ) . getTokenValue ( ) ;
if ( token . getToken ( ) . getIssuedAt ( ) ! = null ) {
tokenIssuedAt = Timestamp . from ( token . getToken ( ) . getIssuedAt ( ) ) ;
}
@ -534,7 +559,13 @@ public class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationServic
@@ -534,7 +559,13 @@ public class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationServic
}
metadata = writeMap ( token . getMetadata ( ) ) ;
}
parameters . add ( new SqlParameterValue ( Types . BLOB , tokenValue ) ) ;
if ( Types . BLOB = = tokenColumnType & & StringUtils . hasText ( tokenValue ) ) {
byte [ ] tokenValueAsBytes = tokenValue . getBytes ( StandardCharsets . UTF_8 ) ;
parameters . add ( new SqlParameterValue ( tokenColumnType , tokenValueAsBytes ) ) ;
} else {
parameters . add ( new SqlParameterValue ( tokenColumnType , tokenValue ) ) ;
}
parameters . add ( new SqlParameterValue ( Types . TIMESTAMP , tokenIssuedAt ) ) ;
parameters . add ( new SqlParameterValue ( Types . TIMESTAMP , tokenExpiresAt ) ) ;
parameters . add ( new SqlParameterValue ( Types . VARCHAR , metadata ) ) ;
@ -551,6 +582,23 @@ public class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationServic
@@ -551,6 +582,23 @@ public class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationServic
}
private static int getColumnDataType ( JdbcOperations jdbcOperations , String columnName ) {
return jdbcOperations . execute ( ( ConnectionCallback < Integer > ) con - > {
DatabaseMetaData databaseMetaData = con . getMetaData ( ) ;
ResultSet rs = databaseMetaData . getColumns ( null , null , TABLE_NAME , columnName ) ;
if ( rs . next ( ) ) {
return rs . getInt ( "DATA_TYPE" ) ;
}
// NOTE: When using HSQL: When a database object is created with one of the CREATE statements if the name is enclosed in double quotes, the exact name is used as the case-normal form.
// But if it is not enclosed in double quotes, the name is converted to uppercase and this uppercase version is stored in the database as the case-normal form
rs = databaseMetaData . getColumns ( null , null , TABLE_NAME . toUpperCase ( ) , columnName . toUpperCase ( ) ) ;
if ( rs . next ( ) ) {
return rs . getInt ( "DATA_TYPE" ) ;
}
return Types . NULL ;
} ) ;
}
private static final class LobCreatorArgumentPreparedStatementSetter extends ArgumentPreparedStatementSetter {
private final LobCreator lobCreator ;
@ -572,6 +620,15 @@ public class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationServic
@@ -572,6 +620,15 @@ public class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationServic
this . lobCreator . setBlobAsBytes ( ps , parameterPosition , valueBytes ) ;
return ;
}
if ( paramValue . getSqlType ( ) = = Types . CLOB ) {
if ( paramValue . getValue ( ) ! = null ) {
Assert . isInstanceOf ( String . class , paramValue . getValue ( ) ,
"Value of clob parameter must be String" ) ;
}
String valueString = ( String ) paramValue . getValue ( ) ;
this . lobCreator . setClobAsString ( ps , parameterPosition , valueString ) ;
return ;
}
}
super . doSetValue ( ps , parameterPosition , argValue ) ;
}