diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/CollectionOptions.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/CollectionOptions.java index fbc899d41..bde4720b4 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/CollectionOptions.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/CollectionOptions.java @@ -15,19 +15,26 @@ */ package org.springframework.data.mongodb.core; +import java.nio.charset.StandardCharsets; import java.time.Duration; import java.util.ArrayList; import java.util.List; import java.util.Optional; +import java.util.function.Consumer; import java.util.function.Function; +import java.util.stream.Collectors; -import org.bson.conversions.Bson; +import org.bson.BsonBinary; +import org.bson.BsonBinarySubType; +import org.bson.Document; import org.springframework.data.mongodb.core.mapping.Field; import org.springframework.data.mongodb.core.query.Collation; +import org.springframework.data.mongodb.core.schema.IdentifiableJsonSchemaProperty; import org.springframework.data.mongodb.core.schema.IdentifiableJsonSchemaProperty.QueryableJsonSchemaProperty; import org.springframework.data.mongodb.core.schema.JsonSchemaProperty; import org.springframework.data.mongodb.core.schema.MongoJsonSchema; import org.springframework.data.mongodb.core.schema.QueryCharacteristics; +import org.springframework.data.mongodb.core.schema.QueryCharacteristics.QueryCharacteristic; import org.springframework.data.mongodb.core.timeseries.Granularity; import org.springframework.data.mongodb.core.timeseries.GranularityDefinition; import org.springframework.data.mongodb.core.validation.Validator; @@ -38,6 +45,7 @@ import org.springframework.util.ObjectUtils; import com.mongodb.client.model.ValidationAction; import com.mongodb.client.model.ValidationLevel; +import org.springframework.util.StringUtils; /** * Provides a simple wrapper to encapsulate the variety of settings you can use when creating a collection. @@ -58,11 +66,12 @@ public class CollectionOptions { private ValidationOptions validationOptions; private @Nullable TimeSeriesOptions timeSeriesOptions; private @Nullable CollectionChangeStreamOptions changeStreamOptions; - private @Nullable Bson encryptedFields; +// private @Nullable Bson encryptedFields; + private @Nullable EncryptedCollectionOptions encryptedCollectionOptions; private CollectionOptions(@Nullable Long size, @Nullable Long maxDocuments, @Nullable Boolean capped, @Nullable Collation collation, ValidationOptions validationOptions, @Nullable TimeSeriesOptions timeSeriesOptions, - @Nullable CollectionChangeStreamOptions changeStreamOptions, @Nullable Bson encryptedFields) { + @Nullable CollectionChangeStreamOptions changeStreamOptions, @Nullable EncryptedCollectionOptions encryptedCollectionOptions) { this.maxDocuments = maxDocuments; this.size = size; @@ -71,7 +80,7 @@ public class CollectionOptions { this.validationOptions = validationOptions; this.timeSeriesOptions = timeSeriesOptions; this.changeStreamOptions = changeStreamOptions; - this.encryptedFields = encryptedFields; + this.encryptedCollectionOptions = encryptedCollectionOptions; } /** @@ -145,7 +154,7 @@ public class CollectionOptions { */ public CollectionOptions capped() { return new CollectionOptions(size, maxDocuments, true, collation, validationOptions, timeSeriesOptions, - changeStreamOptions, encryptedFields); + changeStreamOptions, encryptedCollectionOptions); } /** @@ -157,7 +166,7 @@ public class CollectionOptions { */ public CollectionOptions maxDocuments(long maxDocuments) { return new CollectionOptions(size, maxDocuments, capped, collation, validationOptions, timeSeriesOptions, - changeStreamOptions, encryptedFields); + changeStreamOptions, encryptedCollectionOptions); } /** @@ -169,7 +178,7 @@ public class CollectionOptions { */ public CollectionOptions size(long size) { return new CollectionOptions(size, maxDocuments, capped, collation, validationOptions, timeSeriesOptions, - changeStreamOptions, encryptedFields); + changeStreamOptions, encryptedCollectionOptions); } /** @@ -181,7 +190,7 @@ public class CollectionOptions { */ public CollectionOptions collation(@Nullable Collation collation) { return new CollectionOptions(size, maxDocuments, capped, collation, validationOptions, timeSeriesOptions, - changeStreamOptions, encryptedFields); + changeStreamOptions, encryptedCollectionOptions); } /** @@ -302,7 +311,7 @@ public class CollectionOptions { Assert.notNull(validationOptions, "ValidationOptions must not be null"); return new CollectionOptions(size, maxDocuments, capped, collation, validationOptions, timeSeriesOptions, - changeStreamOptions, encryptedFields); + changeStreamOptions, encryptedCollectionOptions); } /** @@ -316,7 +325,7 @@ public class CollectionOptions { Assert.notNull(timeSeriesOptions, "TimeSeriesOptions must not be null"); return new CollectionOptions(size, maxDocuments, capped, collation, validationOptions, timeSeriesOptions, - changeStreamOptions, encryptedFields); + changeStreamOptions, encryptedCollectionOptions); } /** @@ -330,19 +339,25 @@ public class CollectionOptions { Assert.notNull(changeStreamOptions, "ChangeStreamOptions must not be null"); return new CollectionOptions(size, maxDocuments, capped, collation, validationOptions, timeSeriesOptions, - changeStreamOptions, encryptedFields); + changeStreamOptions, encryptedCollectionOptions); } /** * Create new {@link CollectionOptions} with the given {@code encryptedFields}. * - * @param encryptedFields can be null + * @param encryptedCollectionOptions can be null * @return new instance of {@link CollectionOptions}. * @since 4.5.0 */ - public CollectionOptions encryptedFields(@Nullable Bson encryptedFields) { - return new CollectionOptions(size, maxDocuments, capped, collation, validationOptions, timeSeriesOptions, - changeStreamOptions, encryptedFields); + public static CollectionOptions encrypted(@Nullable EncryptedCollectionOptions encryptedCollectionOptions) { + return new CollectionOptions(null, null, null, null, null, null, + null, encryptedCollectionOptions); + } + + public static CollectionOptions encrypted(Consumer options) { + EncryptedCollectionOptions theOptions = new EncryptedCollectionOptions(); + options.accept(theOptions); + return encrypted(theOptions); } /** @@ -419,15 +434,15 @@ public class CollectionOptions { * @return {@link Optional#empty()} if not specified. * @since 4.5.0 */ - public Optional getEncryptedFields() { - return Optional.ofNullable(encryptedFields); + public Optional getEncryptedFields() { + return Optional.ofNullable(encryptedCollectionOptions); } @Override public String toString() { return "CollectionOptions{" + "maxDocuments=" + maxDocuments + ", size=" + size + ", capped=" + capped + ", collation=" + collation + ", validationOptions=" + validationOptions + ", timeSeriesOptions=" - + timeSeriesOptions + ", changeStreamOptions=" + changeStreamOptions + ", encryptedFields=" + encryptedFields + + timeSeriesOptions + ", changeStreamOptions=" + changeStreamOptions + ", encryptedCollectionOptions=" + encryptedCollectionOptions + ", disableValidation=" + disableValidation() + ", strictValidation=" + strictValidation() + ", moderateValidation=" + moderateValidation() + ", warnOnValidationError=" + warnOnValidationError() + ", failOnValidationError=" + failOnValidationError() + '}'; @@ -465,7 +480,7 @@ public class CollectionOptions { if (!ObjectUtils.nullSafeEquals(changeStreamOptions, that.changeStreamOptions)) { return false; } - return ObjectUtils.nullSafeEquals(encryptedFields, that.encryptedFields); + return ObjectUtils.nullSafeEquals(encryptedCollectionOptions, that.encryptedCollectionOptions); } @Override @@ -477,7 +492,7 @@ public class CollectionOptions { result = 31 * result + ObjectUtils.nullSafeHashCode(validationOptions); result = 31 * result + ObjectUtils.nullSafeHashCode(timeSeriesOptions); result = 31 * result + ObjectUtils.nullSafeHashCode(changeStreamOptions); - result = 31 * result + ObjectUtils.nullSafeHashCode(encryptedFields); + result = 31 * result + ObjectUtils.nullSafeHashCode(encryptedCollectionOptions); return result; } @@ -615,12 +630,33 @@ public class CollectionOptions { private List queryableProperties = new ArrayList<>(); - public EncryptedCollectionOptions queryable(JsonSchemaProperty schemaObject, QueryCharacteristics characteristics) { + public EncryptedCollectionOptions queryable(JsonSchemaProperty schemaObject, QueryCharacteristic... characteristics) { - queryableProperties.add(JsonSchemaProperty.queryable(schemaObject, characteristics)); + QueryCharacteristics characteristics1 = new QueryCharacteristics(List.of(characteristics)); + queryableProperties.add(JsonSchemaProperty.queryable(schemaObject, characteristics1)); return this; } + + public Document toDocument() { + + + List fields = new ArrayList<>(queryableProperties.size()); + for(QueryableJsonSchemaProperty property : queryableProperties) { + Document field = new Document("path", property.getIdentifier()); + if(!property.getTypes().isEmpty()) { + field.append("bsonType", property.getTypes().iterator().next().toBsonType().value()); + } + if(property.getTargetProperty() instanceof IdentifiableJsonSchemaProperty.EncryptedJsonSchemaProperty encrypted) { + if(StringUtils.hasText(encrypted.getKeyId())) { + new BsonBinary(BsonBinarySubType.UUID_STANDARD, encrypted.getKeyId().getBytes(StandardCharsets.UTF_8)); + } + } + field.append("queries", property.getCharacteristics().getCharacteristics().stream().map(QueryCharacteristic::toDocument).collect(Collectors.toList())); + } + + return new Document("fields", fields); + } } /** diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/EntityOperations.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/EntityOperations.java index b7a2380ce..7977b62c3 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/EntityOperations.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/EntityOperations.java @@ -39,6 +39,7 @@ import org.springframework.data.mapping.PersistentPropertyPath; import org.springframework.data.mapping.PropertyPath; import org.springframework.data.mapping.context.MappingContext; import org.springframework.data.mapping.model.ConvertingPropertyAccessor; +import org.springframework.data.mongodb.core.CollectionOptions.EncryptedCollectionOptions; import org.springframework.data.mongodb.core.CollectionOptions.TimeSeriesOptions; import org.springframework.data.mongodb.core.convert.MongoConverter; import org.springframework.data.mongodb.core.convert.MongoJsonSchemaMapper; @@ -379,7 +380,11 @@ class EntityOperations { collectionOptions.getChangeStreamOptions().ifPresent(it -> result .changeStreamPreAndPostImagesOptions(new ChangeStreamPreAndPostImagesOptions(it.getPreAndPostImages()))); - collectionOptions.getEncryptedFields().ifPresent(result::encryptedFields); + if(collectionOptions.getEncryptedFields().isPresent()) { + EncryptedCollectionOptions encryptedCollectionOptions = collectionOptions.getEncryptedFields().get(); + + result.encryptedFields(encryptedCollectionOptions.toDocument()); + } return result; } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/schema/IdentifiableJsonSchemaProperty.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/schema/IdentifiableJsonSchemaProperty.java index aa66873e6..574553cad 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/schema/IdentifiableJsonSchemaProperty.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/schema/IdentifiableJsonSchemaProperty.java @@ -113,6 +113,18 @@ public class IdentifiableJsonSchemaProperty implemen public Set getTypes() { return targetProperty.getTypes(); } + + boolean isEncrypted() { + return targetProperty instanceof EncryptedJsonSchemaProperty; + } + + public JsonSchemaProperty getTargetProperty() { + return targetProperty; + } + + public QueryCharacteristics getCharacteristics() { + return characteristics; + } } /** @@ -1206,5 +1218,9 @@ public class IdentifiableJsonSchemaProperty implemen return null; } + + public String getKeyId() { + return keyId; + } } } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/schema/QueryCharacteristics.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/schema/QueryCharacteristics.java index e71792de2..f8940ef82 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/schema/QueryCharacteristics.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/schema/QueryCharacteristics.java @@ -69,7 +69,7 @@ public class QueryCharacteristics { this.characteristics.add(characteristic); } - List getCharacteristics() { + public List getCharacteristics() { return characteristics; } diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/encryption/RangeEncryptionTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/encryption/RangeEncryptionTests.java index 284cb9228..d271cf7ae 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/encryption/RangeEncryptionTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/encryption/RangeEncryptionTests.java @@ -19,10 +19,12 @@ import static java.util.Arrays.*; import static org.assertj.core.api.Assertions.*; import static org.springframework.data.mongodb.core.EncryptionAlgorithms.*; import static org.springframework.data.mongodb.core.query.Criteria.*; +import static org.springframework.data.mongodb.core.schema.JsonSchemaProperty.encrypted; +import static org.springframework.data.mongodb.core.schema.JsonSchemaProperty.int32; +import static org.springframework.data.mongodb.core.schema.JsonSchemaProperty.int64; +import static org.springframework.data.mongodb.core.schema.QueryCharacteristics.range; -import java.nio.charset.StandardCharsets; import java.security.SecureRandom; -import java.util.List; import java.util.Map; import java.util.Objects; import java.util.concurrent.atomic.AtomicReference; @@ -43,7 +45,6 @@ import com.mongodb.client.model.CreateEncryptedCollectionParams; import com.mongodb.client.model.Filters; import com.mongodb.client.model.IndexOptions; import com.mongodb.client.model.Indexes; -import com.mongodb.client.model.vault.DataKeyOptions; import com.mongodb.client.vault.ClientEncryption; import com.mongodb.client.vault.ClientEncryptions; @@ -65,10 +66,14 @@ import org.springframework.context.ApplicationContext; import org.springframework.context.annotation.Bean; import org.springframework.data.convert.PropertyValueConverterFactory; import org.springframework.data.mongodb.config.AbstractMongoClientConfiguration; +import org.springframework.data.mongodb.core.CollectionOptions; +import org.springframework.data.mongodb.core.CollectionOptions.EncryptedCollectionOptions; import org.springframework.data.mongodb.core.MongoTemplate; import org.springframework.data.mongodb.core.convert.MongoCustomConversions.MongoConverterConfigurationAdapter; import org.springframework.data.mongodb.core.convert.encryption.MongoEncryptionConverter; import org.springframework.data.mongodb.core.mapping.ExplicitEncrypted; +import org.springframework.data.mongodb.core.schema.JsonSchemaProperty; +import org.springframework.data.mongodb.core.schema.QueryCharacteristics; import org.springframework.data.mongodb.test.util.EnableIfMongoServerVersion; import org.springframework.data.mongodb.test.util.EnableIfReplicaSetAvailable; import org.springframework.data.mongodb.test.util.MongoClientExtension; @@ -203,38 +208,49 @@ class RangeEncryptionTests { database.getCollection("test").drop(); ClientEncryption clientEncryption = mongoClientEncryption.getClientEncryption(); -// BsonDocument encryptedFields = new BsonDocument().append("fields", -// new BsonArray(asList( -// new BsonDocument("keyId", BsonNull.VALUE).append("path", new BsonString("encryptedInt")) -// .append("bsonType", new BsonString("int")) -// .append("queries", -// new BsonDocument("queryType", new BsonString("range")).append("contention", new BsonInt64(0L)) -// .append("trimFactor", new BsonInt32(1)).append("sparsity", new BsonInt64(1)) -// .append("min", new BsonInt32(0)).append("max", new BsonInt32(200))), -// new BsonDocument("keyId", BsonNull.VALUE).append("path", new BsonString("encryptedLong")) -// .append("bsonType", new BsonString("long")).append("queries", -// new BsonDocument("queryType", new BsonString("range")).append("contention", new BsonInt64(0L)) -// .append("trimFactor", new BsonInt32(1)).append("sparsity", new BsonInt64(1)) -// .append("min", new BsonInt64(1000)).append("max", new BsonInt64(9999)))))); - - BsonBinary dataKey1 = clientEncryption.createDataKey(LOCAL_KMS_PROVIDER, new DataKeyOptions().keyAltNames(List.of("dek-1"))); - BsonBinary dataKey2 = clientEncryption.createDataKey(LOCAL_KMS_PROVIDER, new DataKeyOptions().keyAltNames(List.of("dek-2"))); - BsonDocument encryptedFields = new BsonDocument().append("fields", - new BsonArray(asList( - new BsonDocument("keyId", dataKey1).append("path", new BsonString("encryptedInt")) - .append("bsonType", new BsonString("int")) - .append("queries", - new BsonDocument("queryType", new BsonString("range")).append("contention", new BsonInt64(0L)) - .append("trimFactor", new BsonInt32(1)).append("sparsity", new BsonInt64(1)) - .append("min", new BsonInt32(0)).append("max", new BsonInt32(200))), - new BsonDocument("keyId", dataKey2).append("path", new BsonString("encryptedLong")) - .append("bsonType", new BsonString("long")).append("queries", - new BsonDocument("queryType", new BsonString("range")).append("contention", new BsonInt64(0L)) - .append("trimFactor", new BsonInt32(1)).append("sparsity", new BsonInt64(1)) - .append("min", new BsonInt64(1000)).append("max", new BsonInt64(9999)))))); - - + new BsonArray(asList( + new BsonDocument("keyId", BsonNull.VALUE).append("path", new BsonString("encryptedInt")) + .append("bsonType", new BsonString("int")) + .append("queries", + new BsonDocument("queryType", new BsonString("range")).append("contention", new BsonInt64(0L)) + .append("trimFactor", new BsonInt32(1)).append("sparsity", new BsonInt64(1)) + .append("min", new BsonInt32(0)).append("max", new BsonInt32(200))), + new BsonDocument("keyId", BsonNull.VALUE).append("path", new BsonString("encryptedLong")) + .append("bsonType", new BsonString("long")).append("queries", + new BsonDocument("queryType", new BsonString("range")).append("contention", new BsonInt64(0L)) + .append("trimFactor", new BsonInt32(1)).append("sparsity", new BsonInt64(1)) + .append("min", new BsonInt64(1000)).append("max", new BsonInt64(9999)))))); + +// BsonBinary dataKey1 = clientEncryption.createDataKey(LOCAL_KMS_PROVIDER, new DataKeyOptions().keyAltNames(List.of("dek-1"))); +// BsonBinary dataKey2 = clientEncryption.createDataKey(LOCAL_KMS_PROVIDER, new DataKeyOptions().keyAltNames(List.of("dek-2"))); +// +// BsonDocument encryptedFields = new BsonDocument().append("fields", +// new BsonArray(asList( +// new BsonDocument("keyId", dataKey1).append("path", new BsonString("encryptedInt")) +// .append("bsonType", new BsonString("int")) +// .append("queries", +// new BsonDocument("queryType", new BsonString("range")).append("contention", new BsonInt64(0L)) +// .append("trimFactor", new BsonInt32(1)).append("sparsity", new BsonInt64(1)) +// .append("min", new BsonInt32(0)).append("max", new BsonInt32(200))), +// new BsonDocument("keyId", dataKey2).append("path", new BsonString("encryptedLong")) +// .append("bsonType", new BsonString("long")).append("queries", +// new BsonDocument("queryType", new BsonString("range")).append("contention", new BsonInt64(0L)) +// .append("trimFactor", new BsonInt32(1)).append("sparsity", new BsonInt64(1)) +// .append("min", new BsonInt64(1000)).append("max", new BsonInt64(9999)))))); +// + + CollectionOptions.encrypted(options -> + options + .queryable( + encrypted(int32("encryptedInt")), + range().contention(0).trimFactor(1).sparsity(1).min(0).max(200) + ).queryable( + encrypted(int64("encryptedLong")), + range().contention(0).trimFactor(1).sparsity(1).min(1000L).max(9999L) + ) + + ); BsonDocument local = clientEncryption.createEncryptedCollection(database, "test", new CreateCollectionOptions().encryptedFields(encryptedFields), @@ -321,9 +337,9 @@ class RangeEncryptionTests { String name; @ExplicitEncrypted(algorithm = RANGE, contentionFactor = 0L, - rangeOptions = "{\"min\": 0, \"max\": 200, \"trimFactor\": 1, \"sparsity\": 1}", keyAltName = "dek-1") Integer encryptedInt; + rangeOptions = "{\"min\": 0, \"max\": 200, \"trimFactor\": 1, \"sparsity\": 1}") Integer encryptedInt; @ExplicitEncrypted(algorithm = RANGE, contentionFactor = 0L, - rangeOptions = "{\"min\": {\"$numberLong\": \"1000\"}, \"max\": {\"$numberLong\": \"9999\"}, \"trimFactor\": 1, \"sparsity\": 1}", keyAltName = "dek-2") Long encryptedLong; + rangeOptions = "{\"min\": {\"$numberLong\": \"1000\"}, \"max\": {\"$numberLong\": \"9999\"}, \"trimFactor\": 1, \"sparsity\": 1}") Long encryptedLong; public String getId() { return this.id;