diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/CollectionOptions.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/CollectionOptions.java index afe3abdb0..8a3cfef78 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/CollectionOptions.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/CollectionOptions.java @@ -46,7 +46,6 @@ import org.springframework.data.util.Optionals; import org.springframework.lang.Nullable; import org.springframework.util.Assert; import org.springframework.util.ObjectUtils; -import org.springframework.util.StringUtils; import com.mongodb.client.model.ValidationAction; import com.mongodb.client.model.ValidationLevel; @@ -354,7 +353,8 @@ public class CollectionOptions { * @since 4.5.0 */ public static CollectionOptions encrypted(@Nullable EncryptedCollectionOptions encryptedCollectionOptions) { - return new CollectionOptions(null, null, null, null, null, null, null, encryptedCollectionOptions); + return new CollectionOptions(null, null, null, null, ValidationOptions.NONE, null, null, + encryptedCollectionOptions); } public static CollectionOptions encrypted(MongoJsonSchema schema) { @@ -665,9 +665,13 @@ public class CollectionOptions { } if (property .getTargetProperty() instanceof IdentifiableJsonSchemaProperty.EncryptedJsonSchemaProperty encrypted) { - if (StringUtils.hasText(encrypted.getKeyId())) { - field.append("keyId", - new BsonBinary(BsonBinarySubType.UUID_STANDARD, encrypted.getKeyId().getBytes(StandardCharsets.UTF_8))); + if (encrypted.getKeyId() != null) { + if (encrypted.getKeyId() instanceof String stringKey) { + field.append("keyId", + new BsonBinary(BsonBinarySubType.UUID_STANDARD, stringKey.getBytes(StandardCharsets.UTF_8))); + } else { + field.append("keyId", encrypted.getKeyId()); + } } } field.append("queries", property.getCharacteristics().getCharacteristics().stream() @@ -692,8 +696,7 @@ public class CollectionOptions { for (Entry entry : paths.entrySet()) { Document field = new Document("path", entry.getKey()); - field.append("keyId", - entry.getValue().containsValue("keyId") ? entry.getValue().get("keyId") : BsonNull.VALUE); + field.append("keyId", entry.getValue().getOrDefault("keyId", BsonNull.VALUE)); if (entry.getValue().containsKey("bsonType")) { field.append("bsonType", entry.getValue().get("bsonType")); } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/schema/IdentifiableJsonSchemaProperty.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/schema/IdentifiableJsonSchemaProperty.java index 1e278b118..b77e97d48 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/schema/IdentifiableJsonSchemaProperty.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/schema/IdentifiableJsonSchemaProperty.java @@ -1179,7 +1179,11 @@ public class IdentifiableJsonSchemaProperty implemen if (!ObjectUtils.isEmpty(keyId)) { enc.append("keyId", keyId); } else if (!ObjectUtils.isEmpty(keyIds)) { - enc.append("keyId", keyIds); + if(keyIds.size() == 1) { + enc.append("keyId", keyIds.iterator().next()); + } else { + enc.append("keyId", keyIds); + } } Type type = extractPropertyType(propertySpecification); @@ -1221,8 +1225,14 @@ public class IdentifiableJsonSchemaProperty implemen return null; } - public String getKeyId() { - return keyId; + public Object getKeyId() { + if(keyId != null) { + return keyId; + } + if(keyIds != null && keyIds.size() == 1) { + return keyIds.iterator().next(); + } + return null; } } } diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/DefaultIndexOperationsIntegrationTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/DefaultIndexOperationsIntegrationTests.java index af4fac84b..78a6e6b49 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/DefaultIndexOperationsIntegrationTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/DefaultIndexOperationsIntegrationTests.java @@ -15,9 +15,9 @@ */ package org.springframework.data.mongodb.core; -import static org.assertj.core.api.Assertions.*; -import static org.springframework.data.mongodb.core.index.PartialIndexFilter.*; -import static org.springframework.data.mongodb.core.query.Criteria.*; +import static org.springframework.data.mongodb.core.index.PartialIndexFilter.of; +import static org.springframework.data.mongodb.core.query.Criteria.where; +import static org.springframework.data.mongodb.test.util.Assertions.assertThat; import org.bson.BsonDocument; import org.bson.Document; @@ -79,7 +79,7 @@ public class DefaultIndexOperationsIntegrationTests { IndexDefinition id = new Index().named("partial-with-criteria").on("k3y", Direction.ASC) .partial(of(where("q-t-y").gte(10))); - indexOps.ensureIndex(id); + indexOps.createIndex(id); IndexInfo info = findAndReturnIndexInfo(indexOps.getIndexInfo(), "partial-with-criteria"); assertThat(Document.parse(info.getPartialFilterExpression())) @@ -92,7 +92,7 @@ public class DefaultIndexOperationsIntegrationTests { IndexDefinition id = new Index().named("partial-with-mapped-criteria").on("k3y", Direction.ASC) .partial(of(where("quantity").gte(10))); - template.indexOps(DefaultIndexOperationsIntegrationTestsSample.class).ensureIndex(id); + template.indexOps(DefaultIndexOperationsIntegrationTestsSample.class).createIndex(id); IndexInfo info = findAndReturnIndexInfo(indexOps.getIndexInfo(), "partial-with-mapped-criteria"); assertThat(Document.parse(info.getPartialFilterExpression())) @@ -105,7 +105,7 @@ public class DefaultIndexOperationsIntegrationTests { IndexDefinition id = new Index().named("partial-with-dbo").on("k3y", Direction.ASC) .partial(of(new org.bson.Document("qty", new org.bson.Document("$gte", 10)))); - indexOps.ensureIndex(id); + indexOps.createIndex(id); IndexInfo info = findAndReturnIndexInfo(indexOps.getIndexInfo(), "partial-with-dbo"); assertThat(Document.parse(info.getPartialFilterExpression())) @@ -120,7 +120,7 @@ public class DefaultIndexOperationsIntegrationTests { indexOps = new DefaultIndexOperations(template, COLLECTION_NAME, MappingToSameCollection.class); - indexOps.ensureIndex(id); + indexOps.createIndex(id); IndexInfo info = findAndReturnIndexInfo(indexOps.getIndexInfo(), "partial-with-inheritance"); assertThat(Document.parse(info.getPartialFilterExpression())) @@ -150,7 +150,7 @@ public class DefaultIndexOperationsIntegrationTests { new DefaultIndexOperations(template, COLLECTION_NAME, MappingToSameCollection.class); - indexOps.ensureIndex(id); + indexOps.createIndex(id); Document expected = new Document("locale", "de_AT") // .append("caseLevel", false) // @@ -179,7 +179,7 @@ public class DefaultIndexOperationsIntegrationTests { IndexDefinition index = new Index().named("my-index").on("a", Direction.ASC); indexOps = new DefaultIndexOperations(template, COLLECTION_NAME, MappingToSameCollection.class); - indexOps.ensureIndex(index); + indexOps.createIndex(index); IndexInfo info = findAndReturnIndexInfo(indexOps.getIndexInfo(), "my-index"); assertThat(info.isHidden()).isFalse(); @@ -191,7 +191,7 @@ public class DefaultIndexOperationsIntegrationTests { IndexDefinition index = new Index().named("my-hidden-index").on("a", Direction.ASC).hidden(); indexOps = new DefaultIndexOperations(template, COLLECTION_NAME, MappingToSameCollection.class); - indexOps.ensureIndex(index); + indexOps.createIndex(index); IndexInfo info = findAndReturnIndexInfo(indexOps.getIndexInfo(), "my-hidden-index"); assertThat(info.isHidden()).isTrue(); diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/encryption/MongoQueryableEncryptionCollectionCreationTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/encryption/MongoQueryableEncryptionCollectionCreationTests.java new file mode 100644 index 000000000..ff5b1233b --- /dev/null +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/encryption/MongoQueryableEncryptionCollectionCreationTests.java @@ -0,0 +1,141 @@ +/* + * Copyright 2025. the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.mongodb.core.encryption; + +import static org.springframework.data.mongodb.core.schema.JsonSchemaProperty.encrypted; +import static org.springframework.data.mongodb.core.schema.JsonSchemaProperty.int32; +import static org.springframework.data.mongodb.core.schema.JsonSchemaProperty.int64; +import static org.springframework.data.mongodb.core.schema.JsonSchemaProperty.queryable; +import static org.springframework.data.mongodb.core.schema.QueryCharacteristics.range; +import static org.springframework.data.mongodb.test.util.Assertions.assertThat; + +import java.util.List; +import java.util.UUID; +import java.util.stream.Stream; + +import org.bson.BsonBinary; +import org.bson.Document; +import org.bson.UuidRepresentation; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Configuration; +import org.springframework.data.mongodb.config.AbstractMongoClientConfiguration; +import org.springframework.data.mongodb.core.CollectionOptions; +import org.springframework.data.mongodb.core.MongoTemplate; +import org.springframework.data.mongodb.core.schema.JsonSchemaProperty; +import org.springframework.data.mongodb.core.schema.MongoJsonSchema; +import org.springframework.data.mongodb.core.schema.QueryCharacteristics; +import org.springframework.data.mongodb.test.util.Client; +import org.springframework.data.mongodb.test.util.MongoClientExtension; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit.jupiter.SpringExtension; + +import com.mongodb.client.MongoClient; + +/** + * @author Christoph Strobl + */ +@ExtendWith({ MongoClientExtension.class, SpringExtension.class }) +@ContextConfiguration +public class MongoQueryableEncryptionCollectionCreationTests { + + public static final String COLLECTION_NAME = "enc-collection"; + static @Client MongoClient mongoClient; + + @Configuration + static class Config extends AbstractMongoClientConfiguration { + + @Override + public MongoClient mongoClient() { + return mongoClient; + } + + @Override + protected String getDatabaseName() { + return "encryption-schema-tests"; + } + + } + + @Autowired MongoTemplate template; + + @BeforeEach + void beforeEach() { + template.dropCollection(COLLECTION_NAME); + } + + @ParameterizedTest // GH-4185 + @MethodSource("collectionOptions") + public void createsCollectionWithEncryptedFieldsCorrectly(CollectionOptions collectionOptions) { + + template.createCollection(COLLECTION_NAME, collectionOptions); + + Document encryptedFields = readEncryptedFieldsFromDatabase(COLLECTION_NAME); + assertThat(encryptedFields).containsKey("fields"); + + List fields = encryptedFields.get("fields", List.of()); + assertThat(fields.get(0)).containsEntry("path", "encryptedInt") // + .containsEntry("bsonType", "int") // + .containsEntry("queries", List + .of(Document.parse("{'queryType': 'range', 'contention': { '$numberLong' : '1' }, 'min': 5, 'max': 100}"))); + + assertThat(fields.get(1)).containsEntry("path", "nested.encryptedLong") // + .containsEntry("bsonType", "long") // + .containsEntry("queries", List.of(Document.parse( + "{'queryType': 'range', 'contention': { '$numberLong' : '0' }, 'min': { '$numberLong' : '-1' }, 'max': { '$numberLong' : '1' }}"))); + } + + private static Stream collectionOptions() { + + BsonBinary key1 = new BsonBinary(UUID.randomUUID(), UuidRepresentation.STANDARD); + BsonBinary key2 = new BsonBinary(UUID.randomUUID(), UuidRepresentation.STANDARD); + + CollectionOptions manualOptions = CollectionOptions.encrypted(options -> options // + .queryable(encrypted(int32("encryptedInt")).keys(key1), range().min(5).max(100).contention(1)) // + .queryable(encrypted(JsonSchemaProperty.int64("nested.encryptedLong")).keys(key2), + range().min(-1L).max(1L).contention(0))); + + CollectionOptions schemaOptions = CollectionOptions.encrypted(MongoJsonSchema.builder() + .property(queryable(encrypted(int32("encryptedInt")).keys(key1), + new QueryCharacteristics(List.of(range().min(5).max(100).contention(1))))) + .property(queryable(encrypted(int64("nested.encryptedLong")).keys(key2), + new QueryCharacteristics(List.of(range().min(-1L).max(1L).contention(0))))) + .build()); + + return Stream.of(Arguments.of(manualOptions), Arguments.of(schemaOptions)); + } + + Document readEncryptedFieldsFromDatabase(String collectionName) { + + Document collectionInfo = template + .executeCommand(new Document("listCollections", 1).append("filter", new Document("name", collectionName))); + + if (collectionInfo.containsKey("cursor")) { + collectionInfo = (Document) collectionInfo.get("cursor", Document.class).get("firstBatch", List.class).iterator() + .next(); + } + + if (!collectionInfo.containsKey("options")) { + return new Document(); + } + + return collectionInfo.get("options", Document.class).get("encryptedFields", Document.class); + } +} diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/encryption/RangeEncryptionTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/encryption/RangeEncryptionTests.java index 78a61b448..5c8935a30 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/encryption/RangeEncryptionTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/encryption/RangeEncryptionTests.java @@ -15,7 +15,6 @@ */ package org.springframework.data.mongodb.core.encryption; -import static java.util.Arrays.asList; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.springframework.data.mongodb.core.query.Criteria.where; @@ -28,13 +27,9 @@ import java.util.concurrent.atomic.AtomicReference; import java.util.function.Supplier; import java.util.stream.Collectors; -import org.bson.BsonArray; import org.bson.BsonBinary; import org.bson.BsonDocument; import org.bson.BsonInt32; -import org.bson.BsonInt64; -import org.bson.BsonNull; -import org.bson.BsonString; import org.bson.BsonValue; import org.bson.Document; import org.junit.jupiter.api.AfterEach; @@ -47,6 +42,7 @@ import org.springframework.context.annotation.Bean; import org.springframework.data.convert.PropertyValueConverterFactory; import org.springframework.data.mongodb.config.AbstractMongoClientConfiguration; import org.springframework.data.mongodb.core.CollectionOptions; +import org.springframework.data.mongodb.core.CollectionOptions.EncryptedCollectionOptions; import org.springframework.data.mongodb.core.MongoJsonSchemaCreator; import org.springframework.data.mongodb.core.MongoTemplate; import org.springframework.data.mongodb.core.convert.MongoCustomConversions.MongoConverterConfigurationAdapter; @@ -120,11 +116,10 @@ class RangeEncryptionTests { Document result = template.execute(Person.class, col -> { BsonDocument filterSource = new BsonDocument("encryptedInt", new BsonDocument("$gte", new BsonInt32(100))); - BsonDocument filter = clientEncryption.getClientEncryption().encryptExpression( - new Document("$and", List.of(filterSource)), - encryptExpressionOptions); + BsonDocument filter = clientEncryption.getClientEncryption() + .encryptExpression(new Document("$and", List.of(filterSource)), encryptExpressionOptions); Document first = col.find(filter).first(); -// Document first = col.find(filterSource).first(); + // Document first = col.find(filterSource).first(); System.out.println("first.toJson(): " + first.toJson()); return first; }); @@ -240,29 +235,36 @@ class RangeEncryptionTests { ClientEncryption clientEncryption = mongoClientEncryption.getClientEncryption(); - BsonDocument encryptedFields = new BsonDocument().append("fields", - new BsonArray(asList( - new BsonDocument("keyId", BsonNull.VALUE).append("path", new BsonString("encryptedInt")) - .append("bsonType", new BsonString("int")) - .append("queries", - new BsonDocument("queryType", new BsonString("range")).append("contention", new BsonInt64(0L)) - .append("trimFactor", new BsonInt32(1)).append("sparsity", new BsonInt64(1)) - .append("min", new BsonInt32(0)).append("max", new BsonInt32(200))), - new BsonDocument("keyId", BsonNull.VALUE).append("path", new BsonString("encryptedLong")) - .append("bsonType", new BsonString("long")).append("queries", - new BsonDocument("queryType", new BsonString("range")).append("contention", new BsonInt64(0L)) - .append("trimFactor", new BsonInt32(1)).append("sparsity", new BsonInt64(1)) - .append("min", new BsonInt64(1000)).append("max", new BsonInt64(9999)))))); - - MongoJsonSchema personSchema = MongoJsonSchemaCreator.create(new MongoMappingContext()) - .filter(MongoJsonSchemaCreator.encryptedOnly()).createSchemaFor(Person.class); - - CollectionOptions options = CollectionOptions.encrypted(personSchema); + // BsonDocument encryptedFields = new BsonDocument().append("fields", + // new BsonArray(asList( + // new BsonDocument("keyId", BsonNull.VALUE).append("path", new BsonString("encryptedInt")) + // .append("bsonType", new BsonString("int")) + // .append("queries", + // new BsonDocument("queryType", new BsonString("range")).append("contention", new BsonInt64(0L)) + // .append("trimFactor", new BsonInt32(1)).append("sparsity", new BsonInt64(1)) + // .append("min", new BsonInt32(0)).append("max", new BsonInt32(200))), + // new BsonDocument("keyId", BsonNull.VALUE).append("path", new BsonString("encryptedLong")) + // .append("bsonType", new BsonString("long")).append("queries", + // new BsonDocument("queryType", new BsonString("range")).append("contention", new BsonInt64(0L)) + // .append("trimFactor", new BsonInt32(1)).append("sparsity", new BsonInt64(1)) + // .append("min", new BsonInt64(1000)).append("max", new BsonInt64(9999)))))); + + MongoJsonSchema personSchema = MongoJsonSchemaCreator.create(new MongoMappingContext()) // init schema creator + .filter(MongoJsonSchemaCreator.encryptedOnly()) // should be obvious + .createSchemaFor(Person.class); // create it for given type + + Document encryptedFields = CollectionOptions.encrypted(personSchema) // pass in the schema + .getEncryptedFields() // get the fields just because we need to use createEncryptedCollection which not + // part of the driver + .map(EncryptedCollectionOptions::toDocument) // now map them into the raw format + .orElseThrow(); + + CreateCollectionOptions createCollectionOptions = new CreateCollectionOptions() + .encryptedFields(encryptedFields); // that's it BsonDocument local = clientEncryption.createEncryptedCollection(database, "test", // new CreateCollectionOptions().encryptedFields(encryptedFields), - new CreateCollectionOptions().encryptedFields(options.getEncryptedFields().get().toDocument()), - new CreateEncryptedCollectionParams(LOCAL_KMS_PROVIDER)); + createCollectionOptions, new CreateEncryptedCollectionParams(LOCAL_KMS_PROVIDER)); return local.getArray("fields").stream().map(BsonValue::asDocument).collect( Collectors.toMap(field -> field.getString("path").getValue(), field -> field.getBinary("keyId"))); @@ -292,8 +294,7 @@ class RangeEncryptionTests { builder.autoEncryptionSettings(AutoEncryptionSettings.builder() // .kmsProviders(clientEncryptionSettings.getKmsProviders()) // .keyVaultNamespace(clientEncryptionSettings.getKeyVaultNamespace()) // - .bypassQueryAnalysis(true) - .build()); + .bypassQueryAnalysis(true).build()); } }