From a53ea25de266d526c5ce4f00889634b5f213bf17 Mon Sep 17 00:00:00 2001 From: Mark Paluch Date: Tue, 4 Feb 2025 09:42:51 +0100 Subject: [PATCH] Polishing. Remove Field type. Refactor container to subclass MongoDBAtlasLocalContainer. Introduce wait/synchronization to avoid container crashes on create index + list search indexes. See #4706 Original pull request: #4882 --- .../core/MappingMongoJsonSchemaCreator.java | 6 +- .../aggregation/VectorSearchOperation.java | 85 +++++++----------- .../mongodb/core/convert/MongoConverters.java | 5 +- .../mongodb/core/index/IndexOperations.java | 6 +- .../core/index/ReactiveIndexOperations.java | 14 +++ .../core/index/SearchIndexDefinition.java | 8 +- .../mongodb/core/index/SearchIndexInfo.java | 18 +++- .../core/index/SearchIndexOperations.java | 12 --- .../data/mongodb/core/index/VectorIndex.java | 47 +++++----- .../data/mongodb/core/mapping/FieldType.java | 5 +- .../VectorSearchOperationUnitTests.java | 18 ++-- .../core/aggregation/VectorSearchTests.java | 34 ++++---- .../MappingMongoConverterUnitTests.java | 57 ++---------- .../index/VectorIndexIntegrationTests.java | 40 ++++++--- .../mongodb/test/util/AtlasContainer.java | 86 ++++--------------- .../mongodb/test/util/MongoTestTemplate.java | 17 ++-- 16 files changed, 189 insertions(+), 269 deletions(-) diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MappingMongoJsonSchemaCreator.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MappingMongoJsonSchemaCreator.java index a4c852ef1..86e01afc2 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MappingMongoJsonSchemaCreator.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MappingMongoJsonSchemaCreator.java @@ -185,7 +185,7 @@ class MappingMongoJsonSchemaCreator implements MongoJsonSchemaCreator { Class rawTargetType = computeTargetType(property); // target type before conversion Class targetType = converter.getTypeMapper().getWriteTargetTypeFor(rawTargetType); // conversion target type - if((rawTargetType.isPrimitive() || ClassUtils.isPrimitiveArray(rawTargetType)) && targetType == Object.class) { + if ((rawTargetType.isPrimitive() || ClassUtils.isPrimitiveArray(rawTargetType)) && targetType == Object.class) { targetType = rawTargetType; } @@ -338,8 +338,8 @@ class MappingMongoJsonSchemaCreator implements MongoJsonSchemaCreator { private String computePropertyFieldName(PersistentProperty property) { - return property instanceof MongoPersistentProperty mongoPersistentProperty ? - mongoPersistentProperty.getFieldName() : property.getName(); + return property instanceof MongoPersistentProperty mongoPersistentProperty ? mongoPersistentProperty.getFieldName() + : property.getName(); } private boolean isRequiredProperty(PersistentProperty property) { diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/VectorSearchOperation.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/VectorSearchOperation.java index a8a1cf892..bcc5fbd7b 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/VectorSearchOperation.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/VectorSearchOperation.java @@ -16,15 +16,14 @@ package org.springframework.data.mongodb.core.aggregation; import java.util.Arrays; -import java.util.LinkedHashSet; import java.util.List; -import java.util.Map; import java.util.Set; import java.util.function.Consumer; import java.util.stream.Collectors; import org.bson.BinaryVector; import org.bson.Document; + import org.springframework.data.domain.Limit; import org.springframework.data.domain.Vector; import org.springframework.data.mongodb.core.mapping.MongoVector; @@ -177,7 +176,7 @@ public class VectorSearchOperation implements AggregationOperation { * can't specify a number less than the number of documents to return (limit). This field is required if * {@link #searchType(SearchType)} is {@link SearchType#ANN} or {@link SearchType#DEFAULT}. * - * @param numCandidates + * @param numCandidates number of nearest neighbors to use during the search * @return a new {@link VectorSearchOperation} with {@code numCandidates} applied. */ @Contract("_ -> new") @@ -338,20 +337,25 @@ public class VectorSearchOperation implements AggregationOperation { ENN } - // A query path cannot only contain the name of the filed but may also hold additional information about the - // analyzer to use; - // "path": [ "names", "notes", { "value": "comments", "multi": "mySecondaryAnalyzer" } ] - // see: https://www.mongodb.com/docs/atlas/atlas-search/path-construction/#std-label-ref-path + /** + * Value object capturing query paths. + */ public static class QueryPaths { - Set> paths; + private final Set> paths; - public static QueryPaths of(QueryPath path) { + private QueryPaths(Set> paths) { + this.paths = paths; + } - QueryPaths queryPaths = new QueryPaths(); - queryPaths.paths = new LinkedHashSet<>(2); - queryPaths.paths.add(path); - return queryPaths; + /** + * Factory method to create {@link QueryPaths} from a single {@link QueryPath}. + * + * @param path + * @return a new {@link QueryPaths} instance. + */ + public static QueryPaths of(QueryPath path) { + return new QueryPaths(Set.of(path)); } Object getPathObject() { @@ -363,6 +367,12 @@ public class VectorSearchOperation implements AggregationOperation { } } + /** + * Interface describing a query path contract. Query paths might be simple field names, wildcard paths, or + * multi-paths. paths. + * + * @param + */ public interface QueryPath { T value(); @@ -370,14 +380,6 @@ public class VectorSearchOperation implements AggregationOperation { static QueryPath path(String field) { return new SimplePath(field); } - - static QueryPath> wildcard(String field) { - return new WildcardPath(field); - } - - static QueryPath> multi(String field, String analyzer) { - return new MultiPath(field, analyzer); - } } public static class SimplePath implements QueryPath { @@ -394,36 +396,9 @@ public class VectorSearchOperation implements AggregationOperation { } } - public static class WildcardPath implements QueryPath> { - - String name; - - public WildcardPath(String name) { - this.name = name; - } - - @Override - public Map value() { - return Map.of("wildcard", name); - } - } - - public static class MultiPath implements QueryPath> { - - String field; - String analyzer; - - public MultiPath(String field, String analyzer) { - this.field = field; - this.analyzer = analyzer; - } - - @Override - public Map value() { - return Map.of("value", field, "multi", analyzer); - } - } - + /** + * Fluent API to configure a path on the VectorSearchOperation builder. + */ public interface PathContributor { /** @@ -436,6 +411,9 @@ public class VectorSearchOperation implements AggregationOperation { VectorContributor path(String path); } + /** + * Fluent API to configure a vector on the VectorSearchOperation builder. + */ public interface VectorContributor { /** @@ -458,7 +436,7 @@ public class VectorSearchOperation implements AggregationOperation { * @return */ @Contract("_ -> this") - default LimitContributor vector(byte... vector) { + default LimitContributor vector(byte[] vector) { return vector(BinaryVector.int8Vector(vector)); } @@ -510,6 +488,9 @@ public class VectorSearchOperation implements AggregationOperation { LimitContributor vector(Vector vector); } + /** + * Fluent API to configure a limit on the VectorSearchOperation builder. + */ public interface LimitContributor { /** diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/convert/MongoConverters.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/convert/MongoConverters.java index 03216d096..9a658c44b 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/convert/MongoConverters.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/convert/MongoConverters.java @@ -15,7 +15,7 @@ */ package org.springframework.data.mongodb.core.convert; -import static org.springframework.data.convert.ConverterBuilder.reading; +import static org.springframework.data.convert.ConverterBuilder.*; import java.math.BigDecimal; import java.math.BigInteger; @@ -47,6 +47,7 @@ import org.bson.types.Binary; import org.bson.types.Code; import org.bson.types.Decimal128; import org.bson.types.ObjectId; + import org.springframework.core.convert.ConversionFailedException; import org.springframework.core.convert.TypeDescriptor; import org.springframework.core.convert.converter.ConditionalConverter; @@ -118,8 +119,6 @@ abstract class MongoConverters { converters.add(reading(BsonUndefined.class, Object.class, it -> null)); converters.add(reading(String.class, URI.class, URI::create).andWriting(URI::toString)); - converters.add(ByteArrayConverterFactory.INSTANCE); - return converters; } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/IndexOperations.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/IndexOperations.java index fe2e569a4..88e6d7a81 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/IndexOperations.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/IndexOperations.java @@ -33,14 +33,14 @@ public interface IndexOperations { * * @param indexDefinition must not be {@literal null}. * @return the index name. - * @deprecated in favor of {@link #createIndex(IndexDefinition)}. + * @deprecated since 4.5, in favor of {@link #createIndex(IndexDefinition)}. */ @Deprecated(since = "4.5", forRemoval = true) String ensureIndex(IndexDefinition indexDefinition); /** - * Create the index for the provided {@link IndexDefinition} exists for the collection indicated by the entity - * class. If not it will be created. + * Create the index for the provided {@link IndexDefinition} exists for the collection indicated by the entity class. + * If not it will be created. * * @param indexDefinition must not be {@literal null}. * @return the index name. diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/ReactiveIndexOperations.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/ReactiveIndexOperations.java index c0fc06569..15b110c08 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/ReactiveIndexOperations.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/ReactiveIndexOperations.java @@ -33,9 +33,23 @@ public interface ReactiveIndexOperations { * * @param indexDefinition must not be {@literal null}. * @return a {@link Mono} emitting the name of the index on completion. + * @deprecated since 4.5, in favor of {@link #createIndex(IndexDefinition)}. */ + @Deprecated(since = "4.5", forRemoval = true) Mono ensureIndex(IndexDefinition indexDefinition); + /** + * Create the index for the provided {@link IndexDefinition} exists for the collection indicated by the entity class. + * If not it will be created. + * + * @param indexDefinition must not be {@literal null}. + * @return the index name. + * @since 4.5 + */ + default Mono createIndex(IndexDefinition indexDefinition) { + return ensureIndex(indexDefinition); + } + /** * Alters the index with given {@literal name}. * diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndexDefinition.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndexDefinition.java index 2cb4eff0e..9d4315bea 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndexDefinition.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndexDefinition.java @@ -45,7 +45,7 @@ public interface SearchIndexDefinition { * Returns the index document for this index without any potential entity context resolving field name mappings. The * resulting document contains the index name, type and {@link #getDefinition(TypeInformation, MappingContext) * definition}. - * + * * @return never {@literal null}. */ default Document getRawIndexDocument() { @@ -74,12 +74,14 @@ public interface SearchIndexDefinition { /** * Returns the actual index definition for this index in the context of a potential entity to resolve field name - * mappings. + * mappings. Entity and context can be {@literal null} to create a generic index definition without applying field + * name mapping. * * @param entity can be {@literal null}. - * @param mappingContext + * @param mappingContext can be {@literal null}. * @return never {@literal null}. */ Document getDefinition(@Nullable TypeInformation entity, @Nullable MappingContext, MongoPersistentProperty> mappingContext); + } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndexInfo.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndexInfo.java index 01f4374f4..1a657ecf0 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndexInfo.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndexInfo.java @@ -27,8 +27,9 @@ import org.springframework.lang.Nullable; /** * Index information for a MongoDB Search Index. - * + * * @author Christoph Strobl + * @since 4.5 */ public class SearchIndexInfo { @@ -42,14 +43,27 @@ public class SearchIndexInfo { this.indexDefinition = Lazy.of(indexDefinition); } + /** + * Parse a BSON document describing an index into a {@link SearchIndexInfo}. + * + * @param source BSON document describing the index. + * @return a new {@link SearchIndexInfo} instance. + */ public static SearchIndexInfo parse(String source) { return of(Document.parse(source)); } + /** + * Create an index from its BSON {@link Document} representation into a {@link SearchIndexInfo}. + * + * @param indexDocument BSON document describing the index. + * @return a new {@link SearchIndexInfo} instance. + */ public static SearchIndexInfo of(Document indexDocument) { Object id = indexDocument.get("id"); - SearchIndexStatus status = SearchIndexStatus.valueOf(indexDocument.get("status", "DOES_NOT_EXIST")); + SearchIndexStatus status = SearchIndexStatus + .valueOf(indexDocument.get("status", SearchIndexStatus.DOES_NOT_EXIST.name())); return new SearchIndexInfo(id, status, () -> readIndexDefinition(indexDocument)); } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndexOperations.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndexOperations.java index d68b547a3..ee3f59cf9 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndexOperations.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndexOperations.java @@ -27,17 +27,6 @@ import org.springframework.dao.DataAccessException; */ public interface SearchIndexOperations { - /** - * Create the index for the given {@link SearchIndexDefinition} in the collection indicated by the entity class. - * - * @param indexDefinition must not be {@literal null}. - * @return the index name. - */ - // TODO: keep or just go with createIndex? - default String ensureIndex(SearchIndexDefinition indexDefinition) { - return createIndex(indexDefinition); - } - /** * Create the index for the given {@link SearchIndexDefinition} in the collection indicated by the entity class. * @@ -53,7 +42,6 @@ public interface SearchIndexOperations { * * @param indexDefinition the index definition. */ - // TODO: keep or remove since it does not work reliably? void updateIndex(SearchIndexDefinition indexDefinition); /** diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/VectorIndex.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/VectorIndex.java index 20cf2a8ff..b46dbf4d0 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/VectorIndex.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/VectorIndex.java @@ -1,27 +1,11 @@ /* - * Copyright 2024. the original author or authors. + * Copyright 2024-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* - * Copyright 2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -153,12 +137,16 @@ public class VectorIndex implements SearchIndexDefinition { return "VectorIndex{" + "name='" + name + '\'' + ", fields=" + fields + ", type='" + getType() + '\'' + '}'; } - // /** instead of index info */ + /** + * Parse the {@link Document} into a {@link VectorIndex}. + */ static VectorIndex of(Document document) { VectorIndex index = new VectorIndex(document.getString("name")); + String definitionKey = document.containsKey("latestDefinition") ? "latestDefinition" : "definition"; Document definition = document.get(definitionKey, Document.class); + for (Object entry : definition.get("fields", List.class)) { if (entry instanceof Document field) { if (field.get("type").equals("vector")) { @@ -195,7 +183,7 @@ public class VectorIndex implements SearchIndexDefinition { record VectorFilterField(String path, String type) implements SearchField { } - record VectorIndexField(String path, String type, int dimensions, String similarity, + record VectorIndexField(String path, String type, int dimensions, @Nullable String similarity, @Nullable String quantization) implements SearchField { } @@ -313,6 +301,9 @@ public class VectorIndex implements SearchIndexDefinition { } } + /** + * Similarity function used to calculate vector distance. + */ public enum SimilarityFunction { DOT_PRODUCT("dotProduct"), COSINE("cosine"), EUCLIDEAN("euclidean"); @@ -328,10 +319,22 @@ public class VectorIndex implements SearchIndexDefinition { } } - /** make it nullable */ + /** + * Vector quantization. Quantization reduce vector sizes while preserving performance. + */ public enum Quantization { - NONE("none"), SCALAR("scalar"), BINARY("binary"); + NONE("none"), + + /** + * Converting a float point into an integer. + */ + SCALAR("scalar"), + + /** + * Converting a float point into a single bit. + */ + BINARY("binary"); final String quantizationName; diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/mapping/FieldType.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/mapping/FieldType.java index 721807c26..7fc4199dd 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/mapping/FieldType.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/mapping/FieldType.java @@ -18,8 +18,6 @@ package org.springframework.data.mongodb.core.mapping; import java.util.Date; import java.util.regex.Pattern; -import org.bson.BinaryVector; -import org.bson.BsonBinary; import org.bson.types.BSONTimestamp; import org.bson.types.Binary; import org.bson.types.Code; @@ -57,8 +55,7 @@ public enum FieldType { INT32(15, Integer.class), // TIMESTAMP(16, BSONTimestamp.class), // INT64(17, Long.class), // - DECIMAL128(18, Decimal128.class), - VECTOR(5, BinaryVector.class); + DECIMAL128(18, Decimal128.class); private final int bsonType; private final Class javaClass; diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/VectorSearchOperationUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/VectorSearchOperationUnitTests.java index 69348290f..4ce045fe6 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/VectorSearchOperationUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/VectorSearchOperationUnitTests.java @@ -15,11 +15,13 @@ */ package org.springframework.data.mongodb.core.aggregation; +import static org.assertj.core.api.Assertions.*; + import java.util.List; -import org.assertj.core.api.Assertions; import org.bson.Document; import org.junit.jupiter.api.Test; + import org.springframework.data.annotation.Id; import org.springframework.data.mongodb.core.aggregation.VectorSearchOperation.SearchType; import org.springframework.data.mongodb.core.mapping.Field; @@ -27,6 +29,8 @@ import org.springframework.data.mongodb.core.query.Criteria; import org.springframework.data.mongodb.util.aggregation.TestAggregationContext; /** + * Unit tests for {@link VectorSearchOperation}. + * * @author Christoph Strobl */ class VectorSearchOperationUnitTests { @@ -40,7 +44,7 @@ class VectorSearchOperationUnitTests { void requiredArgs() { List stages = SEARCH_OPERATION.toPipelineStages(Aggregation.DEFAULT_CONTEXT); - Assertions.assertThat(stages).containsExactly(new Document("$vectorSearch", $VECTOR_SEARCH)); + assertThat(stages).containsExactly(new Document("$vectorSearch", $VECTOR_SEARCH)); } @Test // GH-4706 @@ -53,7 +57,7 @@ class VectorSearchOperationUnitTests { Document filter = new Document("$and", List.of(new Document("year", new Document("$gt", 1955)), new Document("year", new Document("$lt", 1975)))); - Assertions.assertThat(stages).containsExactly(new Document("$vectorSearch", + assertThat(stages).containsExactly(new Document("$vectorSearch", new Document($VECTOR_SEARCH).append("exact", true).append("filter", filter).append("numCandidates", 150))); } @@ -61,7 +65,7 @@ class VectorSearchOperationUnitTests { void withScore() { List stages = SEARCH_OPERATION.withSearchScore().toPipelineStages(Aggregation.DEFAULT_CONTEXT); - Assertions.assertThat(stages).containsExactly(new Document("$vectorSearch", $VECTOR_SEARCH), + assertThat(stages).containsExactly(new Document("$vectorSearch", $VECTOR_SEARCH), new Document("$addFields", new Document("score", new Document("$meta", "vectorSearchScore")))); } @@ -70,7 +74,7 @@ class VectorSearchOperationUnitTests { List stages = SEARCH_OPERATION.withFilterBySore(score -> score.gt(50)) .toPipelineStages(Aggregation.DEFAULT_CONTEXT); - Assertions.assertThat(stages).containsExactly(new Document("$vectorSearch", $VECTOR_SEARCH), + assertThat(stages).containsExactly(new Document("$vectorSearch", $VECTOR_SEARCH), new Document("$addFields", new Document("score", new Document("$meta", "vectorSearchScore"))), new Document("$match", new Document("score", new Document("$gt", 50)))); } @@ -80,7 +84,7 @@ class VectorSearchOperationUnitTests { List stages = SEARCH_OPERATION.withFilterBySore(score -> score.gt(50)).withSearchScore("s-c-o-r-e") .toPipelineStages(Aggregation.DEFAULT_CONTEXT); - Assertions.assertThat(stages).containsExactly(new Document("$vectorSearch", $VECTOR_SEARCH), + assertThat(stages).containsExactly(new Document("$vectorSearch", $VECTOR_SEARCH), new Document("$addFields", new Document("s-c-o-r-e", new Document("$meta", "vectorSearchScore"))), new Document("$match", new Document("s-c-o-r-e", new Document("$gt", 50)))); } @@ -95,7 +99,7 @@ class VectorSearchOperationUnitTests { Document filter = new Document("$and", List.of(new Document("year", new Document("$gt", 1955)), new Document("year", new Document("$lt", 1975)))); - Assertions.assertThat(stages) + assertThat(stages) .containsExactly(new Document("$vectorSearch", new Document($VECTOR_SEARCH).append("filter", filter))); } diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/VectorSearchTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/VectorSearchTests.java index 1dded6d22..18991c176 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/VectorSearchTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/VectorSearchTests.java @@ -15,7 +15,7 @@ */ package org.springframework.data.mongodb.core.aggregation; -import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.*; import java.util.stream.IntStream; import java.util.stream.Stream; @@ -28,15 +28,15 @@ import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; + import org.springframework.data.domain.Vector; import org.springframework.data.mongodb.core.aggregation.VectorSearchOperation.SearchType; import org.springframework.data.mongodb.core.index.VectorIndex; import org.springframework.data.mongodb.core.index.VectorIndex.SimilarityFunction; -import org.springframework.data.mongodb.core.mapping.Field; -import org.springframework.data.mongodb.core.mapping.FieldType; import org.springframework.data.mongodb.core.mapping.MongoVector; import org.springframework.data.mongodb.test.util.AtlasContainer; import org.springframework.data.mongodb.test.util.MongoTestTemplate; + import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; @@ -44,16 +44,20 @@ import com.mongodb.client.MongoClient; import com.mongodb.client.MongoClients; /** + * Integration tests using Vector Search and Vector Indexes through local MongoDB Atlas. + * * @author Christoph Strobl + * @author Mark Paluch */ @Testcontainers(disabledWithoutDocker = true) public class VectorSearchTests { - public static final String SCORE_FIELD = "vector-search-tests"; - static final String COLLECTION_NAME = "collection-1"; + private static final String SCORE_FIELD = "vector-search-tests"; + private static final @Container AtlasContainer atlasLocal = AtlasContainer.bestMatch(); + private static final String COLLECTION_NAME = "collection-1"; + static MongoClient client; static MongoTestTemplate template; - private static @Container AtlasContainer atlasLocal = AtlasContainer.bestMatch(); @BeforeAll static void beforeAll() throws InterruptedException { @@ -126,12 +130,12 @@ public class VectorSearchTests { return Stream.of(// Arguments.arguments(VectorSearchOperation.search("raw-index").path("rawFloat32vector") // - .vector(new float[] { 0.0001f, 1.12345f, 2.23456f, 3.34567f, 4.45678f }) // + .vector(0.0001f, 1.12345f, 2.23456f, 3.34567f, 4.45678f) // .limit(10)// .numCandidates(20) // .searchType(SearchType.ANN)), Arguments.arguments(VectorSearchOperation.search("raw-index").path("rawFloat64vector") // - .vector(new double[] { 1.0001d, 2.12345d, 3.23456d, 4.34567d, 5.45678d }) // + .vector(1.0001d, 2.12345d, 3.23456d, 4.34567d, 5.45678d) // .limit(10)// .numCandidates(20) // .searchType(SearchType.ANN)), @@ -160,8 +164,8 @@ public class VectorSearchTests { .addVector("float64vector", it -> it.similarity(SimilarityFunction.COSINE).dimensions(5)) .addFilter("justSomeArgument"); - template.searchIndexOps(WithVectorFields.class).ensureIndex(rawIndex); - template.searchIndexOps(WithVectorFields.class).ensureIndex(wrapperIndex); + template.searchIndexOps(WithVectorFields.class).createIndex(rawIndex); + template.searchIndexOps(WithVectorFields.class).createIndex(wrapperIndex); template.awaitIndexCreation(WithVectorFields.class, rawIndex.getName()); template.awaitIndexCreation(WithVectorFields.class, wrapperIndex.getName()); @@ -188,8 +192,7 @@ public class VectorSearchTests { Vector float32vector; Vector float64vector; - @Field(targetType = FieldType.VECTOR) // - byte[] rawInt8vector; + BinaryVector rawInt8vector; float[] rawFloat32vector; double[] rawFloat64vector; @@ -199,15 +202,16 @@ public class VectorSearchTests { WithVectorFields instance = new WithVectorFields(); instance.id = "id-%s".formatted(offset); - instance.rawInt8vector = new byte[5]; instance.rawFloat32vector = new float[5]; instance.rawFloat64vector = new double[5]; + byte[] int8 = new byte[5]; for (int i = 0; i < 5; i++) { int v = i + offset; - instance.rawInt8vector[i] = (byte) v; + int8[i] = (byte) v; } + instance.rawInt8vector = BinaryVector.int8Vector(int8); if (offset == 0) { instance.rawFloat32vector[0] = 0.0001f; @@ -227,7 +231,7 @@ public class VectorSearchTests { instance.justSomeArgument = offset; - instance.int8vector = MongoVector.of(BinaryVector.int8Vector(instance.rawInt8vector)); + instance.int8vector = MongoVector.of(instance.rawInt8vector); instance.float32vector = MongoVector.of(BinaryVector.floatVector(instance.rawFloat32vector)); instance.float64vector = Vector.of(instance.rawFloat64vector); diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/convert/MappingMongoConverterUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/convert/MappingMongoConverterUnitTests.java index 52f80ffbd..b5d1f72e1 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/convert/MappingMongoConverterUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/convert/MappingMongoConverterUnitTests.java @@ -15,23 +15,10 @@ */ package org.springframework.data.mongodb.core.convert; -import static java.time.ZoneId.systemDefault; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatExceptionOfType; -import static org.assertj.core.api.Assertions.assertThatNoException; -import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.assertj.core.api.Assertions.fail; -import static org.mockito.Mockito.any; -import static org.mockito.Mockito.doReturn; -import static org.mockito.Mockito.eq; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; -import static org.springframework.data.mongodb.core.DocumentTestUtils.assertTypeHint; -import static org.springframework.data.mongodb.core.DocumentTestUtils.getAsDocument; +import static java.time.ZoneId.*; +import static org.assertj.core.api.Assertions.*; +import static org.mockito.Mockito.*; +import static org.springframework.data.mongodb.core.DocumentTestUtils.*; import java.math.BigDecimal; import java.math.BigInteger; @@ -40,30 +27,12 @@ import java.nio.ByteBuffer; import java.time.LocalDate; import java.time.LocalDateTime; import java.time.temporal.ChronoUnit; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collection; -import java.util.Collections; -import java.util.Date; -import java.util.EnumMap; -import java.util.EnumSet; -import java.util.HashMap; -import java.util.LinkedHashMap; -import java.util.List; -import java.util.Locale; -import java.util.Map; -import java.util.Objects; -import java.util.Optional; -import java.util.Set; -import java.util.SortedMap; -import java.util.TreeMap; -import java.util.UUID; +import java.util.*; import java.util.function.Consumer; import java.util.function.Function; import java.util.stream.Stream; import org.assertj.core.data.Percentage; -import org.bson.BinaryVector; import org.bson.BsonDouble; import org.bson.BsonUndefined; import org.bson.types.Binary; @@ -81,6 +50,7 @@ import org.junit.jupiter.params.provider.ValueSource; import org.mockito.Mock; import org.mockito.Mockito; import org.mockito.junit.jupiter.MockitoExtension; + import org.springframework.aop.framework.ProxyFactory; import org.springframework.beans.ConversionNotSupportedException; import org.springframework.beans.factory.annotation.Autowired; @@ -3380,18 +3350,6 @@ class MappingMongoConverterUnitTests { assertThat(withVector.embeddings.toDoubleArray()).contains(1.1d, 2.2d, 3.3d); } - @Test // GH-4706 - void mapsByteArrayAsVectorWhenAnnotatedWithFieldTargetType() { - - WithExplicitTargetTypes source = new WithExplicitTargetTypes(); - source.asVector = new byte[] { 0, 1, 2 }; - - org.bson.Document target = new org.bson.Document(); - converter.write(source, target); - - assertThatNoException().isThrownBy(() -> target.get("asVector", BinaryVector.class)); - } - @Test // GH-4706 void writesByteArrayAsIsIfNoFieldInstructionsGiven() { @@ -4070,9 +4028,6 @@ class MappingMongoConverterUnitTests { @Field(targetType = FieldType.OBJECT_ID) // Date dateAsObjectId; - - @Field(targetType = FieldType.VECTOR) // - byte[] asVector; } static class WrapperAroundWithUnwrapped { diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/index/VectorIndexIntegrationTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/index/VectorIndexIntegrationTests.java index ad4adfa39..dcd447f81 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/index/VectorIndexIntegrationTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/index/VectorIndexIntegrationTests.java @@ -15,8 +15,8 @@ */ package org.springframework.data.mongodb.core.index; -import static org.assertj.core.api.Assertions.assertThatRuntimeException; -import static org.awaitility.Awaitility.await; +import static org.assertj.core.api.Assertions.*; +import static org.awaitility.Awaitility.*; import static org.springframework.data.mongodb.test.util.Assertions.assertThat; import java.util.List; @@ -27,6 +27,7 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; + import org.springframework.data.annotation.Id; import org.springframework.data.mongodb.core.index.VectorIndex.SimilarityFunction; import org.springframework.data.mongodb.core.mapping.Field; @@ -34,6 +35,7 @@ import org.springframework.data.mongodb.test.util.AtlasContainer; import org.springframework.data.mongodb.test.util.MongoTestTemplate; import org.springframework.data.mongodb.test.util.MongoTestUtils; import org.springframework.lang.Nullable; + import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; @@ -49,7 +51,7 @@ import com.mongodb.client.AggregateIterable; @Testcontainers(disabledWithoutDocker = true) class VectorIndexIntegrationTests { - private static @Container AtlasContainer atlasLocal = AtlasContainer.bestMatch(); + private static final @Container AtlasContainer atlasLocal = AtlasContainer.bestMatch(); MongoTestTemplate template = new MongoTestTemplate(cfg -> { cfg.configureDatabaseFactory(ctx -> { @@ -82,7 +84,7 @@ class VectorIndexIntegrationTests { VectorIndex idx = new VectorIndex("vector_index").addVector("plotEmbedding", builder -> builder.dimensions(1536).similarity(similarityFunction)); - indexOps.ensureIndex(idx); + indexOps.createIndex(idx); await().untilAsserted(() -> { Document raw = readRawIndexInfo(idx.getName()); @@ -101,7 +103,7 @@ class VectorIndexIntegrationTests { VectorIndex idx = new VectorIndex("vector_index").addVector("plotEmbedding", builder -> builder.dimensions(1536).similarity("cosine")); - indexOps.ensureIndex(idx); + indexOps.createIndex(idx); template.awaitIndexCreation(Movie.class, idx.getName()); @@ -111,7 +113,7 @@ class VectorIndexIntegrationTests { } @Test // GH-4706 - void statusChanges() { + void statusChanges() throws InterruptedException { String indexName = "vector_index"; assertThat(indexOps.status(indexName)).isEqualTo(SearchIndexStatus.DOES_NOT_EXIST); @@ -119,14 +121,17 @@ class VectorIndexIntegrationTests { VectorIndex idx = new VectorIndex(indexName).addVector("plotEmbedding", builder -> builder.dimensions(1536).similarity("cosine")); - indexOps.ensureIndex(idx); + indexOps.createIndex(idx); + + // without synchronization, the container might crash. + Thread.sleep(500); assertThat(indexOps.status(indexName)).isIn(SearchIndexStatus.PENDING, SearchIndexStatus.BUILDING, SearchIndexStatus.READY); } @Test // GH-4706 - void exists() { + void exists() throws InterruptedException { String indexName = "vector_index"; assertThat(indexOps.exists(indexName)).isFalse(); @@ -134,19 +139,25 @@ class VectorIndexIntegrationTests { VectorIndex idx = new VectorIndex(indexName).addVector("plotEmbedding", builder -> builder.dimensions(1536).similarity("cosine")); - indexOps.ensureIndex(idx); + indexOps.createIndex(idx); + + // without synchronization, the container might crash. + Thread.sleep(500); assertThat(indexOps.exists(indexName)).isTrue(); } @Test // GH-4706 - void updatesVectorIndex() { + void updatesVectorIndex() throws InterruptedException { String indexName = "vector_index"; VectorIndex idx = new VectorIndex(indexName).addVector("plotEmbedding", builder -> builder.dimensions(1536).similarity("cosine")); - indexOps.ensureIndex(idx); + indexOps.createIndex(idx); + + // without synchronization, the container might crash. + Thread.sleep(500); await().untilAsserted(() -> { Document raw = readRawIndexInfo(idx.getName()); @@ -166,13 +177,16 @@ class VectorIndexIntegrationTests { } @Test // GH-4706 - void createsVectorIndexWithFilters() { + void createsVectorIndexWithFilters() throws InterruptedException { VectorIndex idx = new VectorIndex("vector_index") .addVector("plotEmbedding", builder -> builder.dimensions(1536).cosine()).addFilter("description") .addFilter("year"); - indexOps.ensureIndex(idx); + indexOps.createIndex(idx); + + // without synchronization, the container might crash. + Thread.sleep(500); await().untilAsserted(() -> { Document raw = readRawIndexInfo(idx.getName()); diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/AtlasContainer.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/AtlasContainer.java index 406d1308b..c3a97a03b 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/AtlasContainer.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/AtlasContainer.java @@ -15,96 +15,44 @@ */ package org.springframework.data.mongodb.test.util; -import java.util.List; - -import org.bson.Document; import org.springframework.core.env.StandardEnvironment; -import org.springframework.data.util.Lazy; -import org.springframework.util.StringUtils; -import org.testcontainers.containers.GenericContainer; -import org.testcontainers.containers.wait.strategy.DockerHealthcheckWaitStrategy; -import org.testcontainers.containers.wait.strategy.WaitStrategy; -import org.testcontainers.utility.DockerImageName; -import com.mongodb.ConnectionString; -import com.mongodb.client.MongoClient; -import com.mongodb.client.MongoCollection; +import org.testcontainers.mongodb.MongoDBAtlasLocalContainer; +import org.testcontainers.utility.DockerImageName; /** + * Extension to MongoDBAtlasLocalContainer. + * * @author Christoph Strobl */ -public class AtlasContainer extends GenericContainer { +public class AtlasContainer extends MongoDBAtlasLocalContainer { private static final DockerImageName DEFAULT_IMAGE_NAME = DockerImageName.parse("mongodb/mongodb-atlas-local"); - private static final String DEFAULT_TAG = "latest"; - private static final String MONGODB_DATABASE_NAME_DEFAULT = "test"; - private static final String READY_DB = "__db_ready_check"; - private final Lazy client; + private static final String DEFAULT_TAG = "8.0.0"; + private static final String LATEST = "latest"; + + private AtlasContainer(String dockerImageName) { + super(DockerImageName.parse(dockerImageName)); + } + + private AtlasContainer(DockerImageName dockerImageName) { + super(dockerImageName); + } public static AtlasContainer bestMatch() { return tagged(new StandardEnvironment().getProperty("mongodb.atlas.version", DEFAULT_TAG)); } public static AtlasContainer latest() { - return tagged(DEFAULT_TAG); + return tagged(LATEST); } public static AtlasContainer version8() { - return tagged("8.0.0"); + return tagged(DEFAULT_TAG); } public static AtlasContainer tagged(String tag) { return new AtlasContainer(DEFAULT_IMAGE_NAME.withTag(tag)); } - public AtlasContainer(String dockerImageName) { - this(DockerImageName.parse(dockerImageName)); - } - - public AtlasContainer(DockerImageName dockerImageName) { - - super(dockerImageName); - dockerImageName.assertCompatibleWith(DEFAULT_IMAGE_NAME); - setExposedPorts(List.of(27017)); - client = Lazy.of(() -> MongoTestUtils.client(new ConnectionString(getConnectionString()))); - } - - public String getConnectionString() { - return getConnectionString(MONGODB_DATABASE_NAME_DEFAULT); - } - - /** - * Gets a connection string url. - * - * @return a connection url pointing to a mongodb instance - */ - public String getConnectionString(String database) { - return String.format("mongodb://%s:%d/%s?directConnection=true", getHost(), getMappedPort(27017), - StringUtils.hasText(database) ? database : MONGODB_DATABASE_NAME_DEFAULT); - } - - @Override - public boolean isHealthy() { - - MongoClient mongoClient = client.get(); - MongoCollection ready = MongoTestUtils.createOrReplaceCollection(READY_DB, "ready", mongoClient); - boolean isReady = false; - - try { - ready.aggregate(List.of(new Document("$listSearchIndexes", new Document()))).first(); - isReady = true; - } catch (Exception e) { - // ok so the search service is not ready yet - sigh - } - if (isReady) { - mongoClient.getDatabase(READY_DB).drop(); - mongoClient.close(); - } - return isReady; - } - - @Override - protected WaitStrategy getWaitStrategy() { - return new DockerHealthcheckWaitStrategy(); - } } diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoTestTemplate.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoTestTemplate.java index 1b72e6034..40948a0e2 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoTestTemplate.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoTestTemplate.java @@ -63,14 +63,11 @@ public class MongoTestTemplate extends MongoTemplate { public MongoTestTemplate(Consumer cfg) { - this(new Supplier() { - @Override - public MongoTestTemplateConfiguration get() { + this(() -> { - MongoTestTemplateConfiguration config = new MongoTestTemplateConfiguration(); - cfg.accept(config); - return config; - } + MongoTestTemplateConfiguration config = new MongoTestTemplateConfiguration(); + cfg.accept(config); + return config; }); } @@ -115,7 +112,7 @@ public class MongoTestTemplate extends MongoTemplate { } public void flush(Class... entities) { - flush(Arrays.asList(entities).stream().map(this::getCollectionName).collect(Collectors.toList())); + flush(Arrays.stream(entities).map(this::getCollectionName).collect(Collectors.toList())); } public void flush(String... collections) { @@ -124,7 +121,7 @@ public class MongoTestTemplate extends MongoTemplate { public void flush(Object... objects) { - flush(Arrays.asList(objects).stream().map(it -> { + flush(Arrays.stream(objects).map(it -> { if (it instanceof String) { return (String) it; @@ -167,7 +164,7 @@ public class MongoTestTemplate extends MongoTemplate { Awaitility.await().atMost(timeout).pollInterval(Duration.ofMillis(200)).until(() -> { - ArrayList execute = this.execute(collectionName, + List execute = this.execute(collectionName, coll -> coll .aggregate(List.of(Document.parse("{'$listSearchIndexes': { 'name' : '%s'}}".formatted(indexName)))) .into(new ArrayList<>()));