Browse Source

Polishing.

Remove Field type. Refactor container to subclass MongoDBAtlasLocalContainer. Introduce wait/synchronization to avoid container crashes on create index + list search indexes.

See #4706
Original pull request: #4882
pull/4890/head
Mark Paluch 11 months ago
parent
commit
a53ea25de2
No known key found for this signature in database
GPG Key ID: 55BC6374BAA9D973
  1. 6
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MappingMongoJsonSchemaCreator.java
  2. 85
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/VectorSearchOperation.java
  3. 5
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/convert/MongoConverters.java
  4. 6
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/IndexOperations.java
  5. 14
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/ReactiveIndexOperations.java
  6. 8
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndexDefinition.java
  7. 18
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndexInfo.java
  8. 12
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndexOperations.java
  9. 47
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/VectorIndex.java
  10. 5
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/mapping/FieldType.java
  11. 18
      spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/VectorSearchOperationUnitTests.java
  12. 34
      spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/VectorSearchTests.java
  13. 57
      spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/convert/MappingMongoConverterUnitTests.java
  14. 40
      spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/index/VectorIndexIntegrationTests.java
  15. 86
      spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/AtlasContainer.java
  16. 17
      spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoTestTemplate.java

6
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MappingMongoJsonSchemaCreator.java

@ -185,7 +185,7 @@ class MappingMongoJsonSchemaCreator implements MongoJsonSchemaCreator { @@ -185,7 +185,7 @@ class MappingMongoJsonSchemaCreator implements MongoJsonSchemaCreator {
Class<?> rawTargetType = computeTargetType(property); // target type before conversion
Class<?> targetType = converter.getTypeMapper().getWriteTargetTypeFor(rawTargetType); // conversion target type
if((rawTargetType.isPrimitive() || ClassUtils.isPrimitiveArray(rawTargetType)) && targetType == Object.class) {
if ((rawTargetType.isPrimitive() || ClassUtils.isPrimitiveArray(rawTargetType)) && targetType == Object.class) {
targetType = rawTargetType;
}
@ -338,8 +338,8 @@ class MappingMongoJsonSchemaCreator implements MongoJsonSchemaCreator { @@ -338,8 +338,8 @@ class MappingMongoJsonSchemaCreator implements MongoJsonSchemaCreator {
private String computePropertyFieldName(PersistentProperty<?> property) {
return property instanceof MongoPersistentProperty mongoPersistentProperty ?
mongoPersistentProperty.getFieldName() : property.getName();
return property instanceof MongoPersistentProperty mongoPersistentProperty ? mongoPersistentProperty.getFieldName()
: property.getName();
}
private boolean isRequiredProperty(PersistentProperty<?> property) {

85
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/VectorSearchOperation.java

@ -16,15 +16,14 @@ @@ -16,15 +16,14 @@
package org.springframework.data.mongodb.core.aggregation;
import java.util.Arrays;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import org.bson.BinaryVector;
import org.bson.Document;
import org.springframework.data.domain.Limit;
import org.springframework.data.domain.Vector;
import org.springframework.data.mongodb.core.mapping.MongoVector;
@ -177,7 +176,7 @@ public class VectorSearchOperation implements AggregationOperation { @@ -177,7 +176,7 @@ public class VectorSearchOperation implements AggregationOperation {
* can't specify a number less than the number of documents to return (limit). This field is required if
* {@link #searchType(SearchType)} is {@link SearchType#ANN} or {@link SearchType#DEFAULT}.
*
* @param numCandidates
* @param numCandidates number of nearest neighbors to use during the search
* @return a new {@link VectorSearchOperation} with {@code numCandidates} applied.
*/
@Contract("_ -> new")
@ -338,20 +337,25 @@ public class VectorSearchOperation implements AggregationOperation { @@ -338,20 +337,25 @@ public class VectorSearchOperation implements AggregationOperation {
ENN
}
// A query path cannot only contain the name of the filed but may also hold additional information about the
// analyzer to use;
// "path": [ "names", "notes", { "value": "comments", "multi": "mySecondaryAnalyzer" } ]
// see: https://www.mongodb.com/docs/atlas/atlas-search/path-construction/#std-label-ref-path
/**
* Value object capturing query paths.
*/
public static class QueryPaths {
Set<QueryPath<?>> paths;
private final Set<QueryPath<?>> paths;
public static QueryPaths of(QueryPath<String> path) {
private QueryPaths(Set<QueryPath<?>> paths) {
this.paths = paths;
}
QueryPaths queryPaths = new QueryPaths();
queryPaths.paths = new LinkedHashSet<>(2);
queryPaths.paths.add(path);
return queryPaths;
/**
* Factory method to create {@link QueryPaths} from a single {@link QueryPath}.
*
* @param path
* @return a new {@link QueryPaths} instance.
*/
public static QueryPaths of(QueryPath<String> path) {
return new QueryPaths(Set.of(path));
}
Object getPathObject() {
@ -363,6 +367,12 @@ public class VectorSearchOperation implements AggregationOperation { @@ -363,6 +367,12 @@ public class VectorSearchOperation implements AggregationOperation {
}
}
/**
* Interface describing a query path contract. Query paths might be simple field names, wildcard paths, or
* multi-paths.
*
* @param <T>
*/
public interface QueryPath<T> {
T value();
@ -370,14 +380,6 @@ public class VectorSearchOperation implements AggregationOperation { @@ -370,14 +380,6 @@ public class VectorSearchOperation implements AggregationOperation {
static QueryPath<String> path(String field) {
return new SimplePath(field);
}
static QueryPath<Map<String, Object>> wildcard(String field) {
return new WildcardPath(field);
}
static QueryPath<Map<String, Object>> multi(String field, String analyzer) {
return new MultiPath(field, analyzer);
}
}
public static class SimplePath implements QueryPath<String> {
@ -394,36 +396,9 @@ public class VectorSearchOperation implements AggregationOperation { @@ -394,36 +396,9 @@ public class VectorSearchOperation implements AggregationOperation {
}
}
public static class WildcardPath implements QueryPath<Map<String, Object>> {
String name;
public WildcardPath(String name) {
this.name = name;
}
@Override
public Map<String, Object> value() {
return Map.of("wildcard", name);
}
}
public static class MultiPath implements QueryPath<Map<String, Object>> {
String field;
String analyzer;
public MultiPath(String field, String analyzer) {
this.field = field;
this.analyzer = analyzer;
}
@Override
public Map<String, Object> value() {
return Map.of("value", field, "multi", analyzer);
}
}
/**
* Fluent API to configure a path on the VectorSearchOperation builder.
*/
public interface PathContributor {
/**
@ -436,6 +411,9 @@ public class VectorSearchOperation implements AggregationOperation { @@ -436,6 +411,9 @@ public class VectorSearchOperation implements AggregationOperation {
VectorContributor path(String path);
}
/**
* Fluent API to configure a vector on the VectorSearchOperation builder.
*/
public interface VectorContributor {
/**
@ -458,7 +436,7 @@ public class VectorSearchOperation implements AggregationOperation { @@ -458,7 +436,7 @@ public class VectorSearchOperation implements AggregationOperation {
* @return
*/
@Contract("_ -> this")
default LimitContributor vector(byte... vector) {
default LimitContributor vector(byte[] vector) {
return vector(BinaryVector.int8Vector(vector));
}
@ -510,6 +488,9 @@ public class VectorSearchOperation implements AggregationOperation { @@ -510,6 +488,9 @@ public class VectorSearchOperation implements AggregationOperation {
LimitContributor vector(Vector vector);
}
/**
* Fluent API to configure a limit on the VectorSearchOperation builder.
*/
public interface LimitContributor {
/**

5
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/convert/MongoConverters.java

@ -15,7 +15,7 @@ @@ -15,7 +15,7 @@
*/
package org.springframework.data.mongodb.core.convert;
import static org.springframework.data.convert.ConverterBuilder.reading;
import static org.springframework.data.convert.ConverterBuilder.*;
import java.math.BigDecimal;
import java.math.BigInteger;
@ -47,6 +47,7 @@ import org.bson.types.Binary; @@ -47,6 +47,7 @@ import org.bson.types.Binary;
import org.bson.types.Code;
import org.bson.types.Decimal128;
import org.bson.types.ObjectId;
import org.springframework.core.convert.ConversionFailedException;
import org.springframework.core.convert.TypeDescriptor;
import org.springframework.core.convert.converter.ConditionalConverter;
@ -118,8 +119,6 @@ abstract class MongoConverters { @@ -118,8 +119,6 @@ abstract class MongoConverters {
converters.add(reading(BsonUndefined.class, Object.class, it -> null));
converters.add(reading(String.class, URI.class, URI::create).andWriting(URI::toString));
converters.add(ByteArrayConverterFactory.INSTANCE);
return converters;
}

6
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/IndexOperations.java

@ -33,14 +33,14 @@ public interface IndexOperations { @@ -33,14 +33,14 @@ public interface IndexOperations {
*
* @param indexDefinition must not be {@literal null}.
* @return the index name.
* @deprecated in favor of {@link #createIndex(IndexDefinition)}.
* @deprecated since 4.5, in favor of {@link #createIndex(IndexDefinition)}.
*/
@Deprecated(since = "4.5", forRemoval = true)
String ensureIndex(IndexDefinition indexDefinition);
/**
* Create the index for the provided {@link IndexDefinition} exists for the collection indicated by the entity
* class. If not it will be created.
* Ensure that an index for the provided {@link IndexDefinition} exists for the collection indicated by the entity
* class. If not, it will be created.
*
* @param indexDefinition must not be {@literal null}.
* @return the index name.

14
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/ReactiveIndexOperations.java

@ -33,9 +33,23 @@ public interface ReactiveIndexOperations { @@ -33,9 +33,23 @@ public interface ReactiveIndexOperations {
*
* @param indexDefinition must not be {@literal null}.
* @return a {@link Mono} emitting the name of the index on completion.
* @deprecated since 4.5, in favor of {@link #createIndex(IndexDefinition)}.
*/
@Deprecated(since = "4.5", forRemoval = true)
Mono<String> ensureIndex(IndexDefinition indexDefinition);
/**
* Ensure that an index for the provided {@link IndexDefinition} exists for the collection indicated by the entity
* class. If not, it will be created.
*
* @param indexDefinition must not be {@literal null}.
* @return a {@link Mono} emitting the name of the index on completion.
* @since 4.5
*/
default Mono<String> createIndex(IndexDefinition indexDefinition) {
return ensureIndex(indexDefinition);
}
/**
* Alters the index with given {@literal name}.
*

8
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndexDefinition.java

@ -45,7 +45,7 @@ public interface SearchIndexDefinition { @@ -45,7 +45,7 @@ public interface SearchIndexDefinition {
* Returns the index document for this index without any potential entity context resolving field name mappings. The
* resulting document contains the index name, type and {@link #getDefinition(TypeInformation, MappingContext)
* definition}.
*
*
* @return never {@literal null}.
*/
default Document getRawIndexDocument() {
@ -74,12 +74,14 @@ public interface SearchIndexDefinition { @@ -74,12 +74,14 @@ public interface SearchIndexDefinition {
/**
* Returns the actual index definition for this index in the context of a potential entity to resolve field name
* mappings.
* mappings. Entity and context can be {@literal null} to create a generic index definition without applying field
* name mapping.
*
* @param entity can be {@literal null}.
* @param mappingContext
* @param mappingContext can be {@literal null}.
* @return never {@literal null}.
*/
Document getDefinition(@Nullable TypeInformation<?> entity,
@Nullable MappingContext<? extends MongoPersistentEntity<?>, MongoPersistentProperty> mappingContext);
}

18
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndexInfo.java

@ -27,8 +27,9 @@ import org.springframework.lang.Nullable; @@ -27,8 +27,9 @@ import org.springframework.lang.Nullable;
/**
* Index information for a MongoDB Search Index.
*
*
* @author Christoph Strobl
* @since 4.5
*/
public class SearchIndexInfo {
@ -42,14 +43,27 @@ public class SearchIndexInfo { @@ -42,14 +43,27 @@ public class SearchIndexInfo {
this.indexDefinition = Lazy.of(indexDefinition);
}
/**
* Parse a BSON document describing an index into a {@link SearchIndexInfo}.
*
* @param source JSON string representation of the BSON document describing the index.
* @return a new {@link SearchIndexInfo} instance.
*/
public static SearchIndexInfo parse(String source) {
return of(Document.parse(source));
}
/**
* Create a {@link SearchIndexInfo} from the BSON {@link Document} representation of an index.
*
* @param indexDocument BSON document describing the index.
* @return a new {@link SearchIndexInfo} instance.
*/
public static SearchIndexInfo of(Document indexDocument) {
Object id = indexDocument.get("id");
SearchIndexStatus status = SearchIndexStatus.valueOf(indexDocument.get("status", "DOES_NOT_EXIST"));
SearchIndexStatus status = SearchIndexStatus
.valueOf(indexDocument.get("status", SearchIndexStatus.DOES_NOT_EXIST.name()));
return new SearchIndexInfo(id, status, () -> readIndexDefinition(indexDocument));
}

12
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndexOperations.java

@ -27,17 +27,6 @@ import org.springframework.dao.DataAccessException; @@ -27,17 +27,6 @@ import org.springframework.dao.DataAccessException;
*/
public interface SearchIndexOperations {
/**
* Create the index for the given {@link SearchIndexDefinition} in the collection indicated by the entity class.
*
* @param indexDefinition must not be {@literal null}.
* @return the index name.
*/
// TODO: keep or just go with createIndex?
default String ensureIndex(SearchIndexDefinition indexDefinition) {
return createIndex(indexDefinition);
}
/**
* Create the index for the given {@link SearchIndexDefinition} in the collection indicated by the entity class.
*
@ -53,7 +42,6 @@ public interface SearchIndexOperations { @@ -53,7 +42,6 @@ public interface SearchIndexOperations {
*
* @param indexDefinition the index definition.
*/
// TODO: keep or remove since it does not work reliably?
void updateIndex(SearchIndexDefinition indexDefinition);
/**

47
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/VectorIndex.java

@ -1,27 +1,11 @@ @@ -1,27 +1,11 @@
/*
* Copyright 2024. the original author or authors.
* Copyright 2024-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* Copyright 2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
@ -153,12 +137,16 @@ public class VectorIndex implements SearchIndexDefinition { @@ -153,12 +137,16 @@ public class VectorIndex implements SearchIndexDefinition {
return "VectorIndex{" + "name='" + name + '\'' + ", fields=" + fields + ", type='" + getType() + '\'' + '}';
}
// /** instead of index info */
/**
* Parse the {@link Document} into a {@link VectorIndex}.
*/
static VectorIndex of(Document document) {
VectorIndex index = new VectorIndex(document.getString("name"));
String definitionKey = document.containsKey("latestDefinition") ? "latestDefinition" : "definition";
Document definition = document.get(definitionKey, Document.class);
for (Object entry : definition.get("fields", List.class)) {
if (entry instanceof Document field) {
if (field.get("type").equals("vector")) {
@ -195,7 +183,7 @@ public class VectorIndex implements SearchIndexDefinition { @@ -195,7 +183,7 @@ public class VectorIndex implements SearchIndexDefinition {
record VectorFilterField(String path, String type) implements SearchField {
}
record VectorIndexField(String path, String type, int dimensions, String similarity,
record VectorIndexField(String path, String type, int dimensions, @Nullable String similarity,
@Nullable String quantization) implements SearchField {
}
@ -313,6 +301,9 @@ public class VectorIndex implements SearchIndexDefinition { @@ -313,6 +301,9 @@ public class VectorIndex implements SearchIndexDefinition {
}
}
/**
* Similarity function used to calculate vector distance.
*/
public enum SimilarityFunction {
DOT_PRODUCT("dotProduct"), COSINE("cosine"), EUCLIDEAN("euclidean");
@ -328,10 +319,22 @@ public class VectorIndex implements SearchIndexDefinition { @@ -328,10 +319,22 @@ public class VectorIndex implements SearchIndexDefinition {
}
}
/** make it nullable */
/**
* Vector quantization. Quantization reduces vector sizes while preserving performance.
*/
public enum Quantization {
NONE("none"), SCALAR("scalar"), BINARY("binary");
NONE("none"),
/**
* Converts a floating-point value into an integer.
*/
SCALAR("scalar"),
/**
* Converts a floating-point value into a single bit.
*/
BINARY("binary");
final String quantizationName;

5
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/mapping/FieldType.java

@ -18,8 +18,6 @@ package org.springframework.data.mongodb.core.mapping; @@ -18,8 +18,6 @@ package org.springframework.data.mongodb.core.mapping;
import java.util.Date;
import java.util.regex.Pattern;
import org.bson.BinaryVector;
import org.bson.BsonBinary;
import org.bson.types.BSONTimestamp;
import org.bson.types.Binary;
import org.bson.types.Code;
@ -57,8 +55,7 @@ public enum FieldType { @@ -57,8 +55,7 @@ public enum FieldType {
INT32(15, Integer.class), //
TIMESTAMP(16, BSONTimestamp.class), //
INT64(17, Long.class), //
DECIMAL128(18, Decimal128.class),
VECTOR(5, BinaryVector.class);
DECIMAL128(18, Decimal128.class);
private final int bsonType;
private final Class<?> javaClass;

18
spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/VectorSearchOperationUnitTests.java

@ -15,11 +15,13 @@ @@ -15,11 +15,13 @@
*/
package org.springframework.data.mongodb.core.aggregation;
import static org.assertj.core.api.Assertions.*;
import java.util.List;
import org.assertj.core.api.Assertions;
import org.bson.Document;
import org.junit.jupiter.api.Test;
import org.springframework.data.annotation.Id;
import org.springframework.data.mongodb.core.aggregation.VectorSearchOperation.SearchType;
import org.springframework.data.mongodb.core.mapping.Field;
@ -27,6 +29,8 @@ import org.springframework.data.mongodb.core.query.Criteria; @@ -27,6 +29,8 @@ import org.springframework.data.mongodb.core.query.Criteria;
import org.springframework.data.mongodb.util.aggregation.TestAggregationContext;
/**
* Unit tests for {@link VectorSearchOperation}.
*
* @author Christoph Strobl
*/
class VectorSearchOperationUnitTests {
@ -40,7 +44,7 @@ class VectorSearchOperationUnitTests { @@ -40,7 +44,7 @@ class VectorSearchOperationUnitTests {
void requiredArgs() {
List<Document> stages = SEARCH_OPERATION.toPipelineStages(Aggregation.DEFAULT_CONTEXT);
Assertions.assertThat(stages).containsExactly(new Document("$vectorSearch", $VECTOR_SEARCH));
assertThat(stages).containsExactly(new Document("$vectorSearch", $VECTOR_SEARCH));
}
@Test // GH-4706
@ -53,7 +57,7 @@ class VectorSearchOperationUnitTests { @@ -53,7 +57,7 @@ class VectorSearchOperationUnitTests {
Document filter = new Document("$and",
List.of(new Document("year", new Document("$gt", 1955)), new Document("year", new Document("$lt", 1975))));
Assertions.assertThat(stages).containsExactly(new Document("$vectorSearch",
assertThat(stages).containsExactly(new Document("$vectorSearch",
new Document($VECTOR_SEARCH).append("exact", true).append("filter", filter).append("numCandidates", 150)));
}
@ -61,7 +65,7 @@ class VectorSearchOperationUnitTests { @@ -61,7 +65,7 @@ class VectorSearchOperationUnitTests {
void withScore() {
List<Document> stages = SEARCH_OPERATION.withSearchScore().toPipelineStages(Aggregation.DEFAULT_CONTEXT);
Assertions.assertThat(stages).containsExactly(new Document("$vectorSearch", $VECTOR_SEARCH),
assertThat(stages).containsExactly(new Document("$vectorSearch", $VECTOR_SEARCH),
new Document("$addFields", new Document("score", new Document("$meta", "vectorSearchScore"))));
}
@ -70,7 +74,7 @@ class VectorSearchOperationUnitTests { @@ -70,7 +74,7 @@ class VectorSearchOperationUnitTests {
List<Document> stages = SEARCH_OPERATION.withFilterBySore(score -> score.gt(50))
.toPipelineStages(Aggregation.DEFAULT_CONTEXT);
Assertions.assertThat(stages).containsExactly(new Document("$vectorSearch", $VECTOR_SEARCH),
assertThat(stages).containsExactly(new Document("$vectorSearch", $VECTOR_SEARCH),
new Document("$addFields", new Document("score", new Document("$meta", "vectorSearchScore"))),
new Document("$match", new Document("score", new Document("$gt", 50))));
}
@ -80,7 +84,7 @@ class VectorSearchOperationUnitTests { @@ -80,7 +84,7 @@ class VectorSearchOperationUnitTests {
List<Document> stages = SEARCH_OPERATION.withFilterBySore(score -> score.gt(50)).withSearchScore("s-c-o-r-e")
.toPipelineStages(Aggregation.DEFAULT_CONTEXT);
Assertions.assertThat(stages).containsExactly(new Document("$vectorSearch", $VECTOR_SEARCH),
assertThat(stages).containsExactly(new Document("$vectorSearch", $VECTOR_SEARCH),
new Document("$addFields", new Document("s-c-o-r-e", new Document("$meta", "vectorSearchScore"))),
new Document("$match", new Document("s-c-o-r-e", new Document("$gt", 50))));
}
@ -95,7 +99,7 @@ class VectorSearchOperationUnitTests { @@ -95,7 +99,7 @@ class VectorSearchOperationUnitTests {
Document filter = new Document("$and",
List.of(new Document("year", new Document("$gt", 1955)), new Document("year", new Document("$lt", 1975))));
Assertions.assertThat(stages)
assertThat(stages)
.containsExactly(new Document("$vectorSearch", new Document($VECTOR_SEARCH).append("filter", filter)));
}

34
spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/VectorSearchTests.java

@ -15,7 +15,7 @@ @@ -15,7 +15,7 @@
*/
package org.springframework.data.mongodb.core.aggregation;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.*;
import java.util.stream.IntStream;
import java.util.stream.Stream;
@ -28,15 +28,15 @@ import org.junit.jupiter.api.BeforeAll; @@ -28,15 +28,15 @@ import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.springframework.data.domain.Vector;
import org.springframework.data.mongodb.core.aggregation.VectorSearchOperation.SearchType;
import org.springframework.data.mongodb.core.index.VectorIndex;
import org.springframework.data.mongodb.core.index.VectorIndex.SimilarityFunction;
import org.springframework.data.mongodb.core.mapping.Field;
import org.springframework.data.mongodb.core.mapping.FieldType;
import org.springframework.data.mongodb.core.mapping.MongoVector;
import org.springframework.data.mongodb.test.util.AtlasContainer;
import org.springframework.data.mongodb.test.util.MongoTestTemplate;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;
@ -44,16 +44,20 @@ import com.mongodb.client.MongoClient; @@ -44,16 +44,20 @@ import com.mongodb.client.MongoClient;
import com.mongodb.client.MongoClients;
/**
* Integration tests using Vector Search and Vector Indexes through local MongoDB Atlas.
*
* @author Christoph Strobl
* @author Mark Paluch
*/
@Testcontainers(disabledWithoutDocker = true)
public class VectorSearchTests {
public static final String SCORE_FIELD = "vector-search-tests";
static final String COLLECTION_NAME = "collection-1";
private static final String SCORE_FIELD = "vector-search-tests";
private static final @Container AtlasContainer atlasLocal = AtlasContainer.bestMatch();
private static final String COLLECTION_NAME = "collection-1";
static MongoClient client;
static MongoTestTemplate template;
private static @Container AtlasContainer atlasLocal = AtlasContainer.bestMatch();
@BeforeAll
static void beforeAll() throws InterruptedException {
@ -126,12 +130,12 @@ public class VectorSearchTests { @@ -126,12 +130,12 @@ public class VectorSearchTests {
return Stream.of(//
Arguments.arguments(VectorSearchOperation.search("raw-index").path("rawFloat32vector") //
.vector(new float[] { 0.0001f, 1.12345f, 2.23456f, 3.34567f, 4.45678f }) //
.vector(0.0001f, 1.12345f, 2.23456f, 3.34567f, 4.45678f) //
.limit(10)//
.numCandidates(20) //
.searchType(SearchType.ANN)),
Arguments.arguments(VectorSearchOperation.search("raw-index").path("rawFloat64vector") //
.vector(new double[] { 1.0001d, 2.12345d, 3.23456d, 4.34567d, 5.45678d }) //
.vector(1.0001d, 2.12345d, 3.23456d, 4.34567d, 5.45678d) //
.limit(10)//
.numCandidates(20) //
.searchType(SearchType.ANN)),
@ -160,8 +164,8 @@ public class VectorSearchTests { @@ -160,8 +164,8 @@ public class VectorSearchTests {
.addVector("float64vector", it -> it.similarity(SimilarityFunction.COSINE).dimensions(5))
.addFilter("justSomeArgument");
template.searchIndexOps(WithVectorFields.class).ensureIndex(rawIndex);
template.searchIndexOps(WithVectorFields.class).ensureIndex(wrapperIndex);
template.searchIndexOps(WithVectorFields.class).createIndex(rawIndex);
template.searchIndexOps(WithVectorFields.class).createIndex(wrapperIndex);
template.awaitIndexCreation(WithVectorFields.class, rawIndex.getName());
template.awaitIndexCreation(WithVectorFields.class, wrapperIndex.getName());
@ -188,8 +192,7 @@ public class VectorSearchTests { @@ -188,8 +192,7 @@ public class VectorSearchTests {
Vector float32vector;
Vector float64vector;
@Field(targetType = FieldType.VECTOR) //
byte[] rawInt8vector;
BinaryVector rawInt8vector;
float[] rawFloat32vector;
double[] rawFloat64vector;
@ -199,15 +202,16 @@ public class VectorSearchTests { @@ -199,15 +202,16 @@ public class VectorSearchTests {
WithVectorFields instance = new WithVectorFields();
instance.id = "id-%s".formatted(offset);
instance.rawInt8vector = new byte[5];
instance.rawFloat32vector = new float[5];
instance.rawFloat64vector = new double[5];
byte[] int8 = new byte[5];
for (int i = 0; i < 5; i++) {
int v = i + offset;
instance.rawInt8vector[i] = (byte) v;
int8[i] = (byte) v;
}
instance.rawInt8vector = BinaryVector.int8Vector(int8);
if (offset == 0) {
instance.rawFloat32vector[0] = 0.0001f;
@ -227,7 +231,7 @@ public class VectorSearchTests { @@ -227,7 +231,7 @@ public class VectorSearchTests {
instance.justSomeArgument = offset;
instance.int8vector = MongoVector.of(BinaryVector.int8Vector(instance.rawInt8vector));
instance.int8vector = MongoVector.of(instance.rawInt8vector);
instance.float32vector = MongoVector.of(BinaryVector.floatVector(instance.rawFloat32vector));
instance.float64vector = Vector.of(instance.rawFloat64vector);

57
spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/convert/MappingMongoConverterUnitTests.java

@ -15,23 +15,10 @@ @@ -15,23 +15,10 @@
*/
package org.springframework.data.mongodb.core.convert;
import static java.time.ZoneId.systemDefault;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.assertThatNoException;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.assertj.core.api.Assertions.fail;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.springframework.data.mongodb.core.DocumentTestUtils.assertTypeHint;
import static org.springframework.data.mongodb.core.DocumentTestUtils.getAsDocument;
import static java.time.ZoneId.*;
import static org.assertj.core.api.Assertions.*;
import static org.mockito.Mockito.*;
import static org.springframework.data.mongodb.core.DocumentTestUtils.*;
import java.math.BigDecimal;
import java.math.BigInteger;
@ -40,30 +27,12 @@ import java.nio.ByteBuffer; @@ -40,30 +27,12 @@ import java.nio.ByteBuffer;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.time.temporal.ChronoUnit;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Date;
import java.util.EnumMap;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.SortedMap;
import java.util.TreeMap;
import java.util.UUID;
import java.util.*;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Stream;
import org.assertj.core.data.Percentage;
import org.bson.BinaryVector;
import org.bson.BsonDouble;
import org.bson.BsonUndefined;
import org.bson.types.Binary;
@ -81,6 +50,7 @@ import org.junit.jupiter.params.provider.ValueSource; @@ -81,6 +50,7 @@ import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.aop.framework.ProxyFactory;
import org.springframework.beans.ConversionNotSupportedException;
import org.springframework.beans.factory.annotation.Autowired;
@ -3380,18 +3350,6 @@ class MappingMongoConverterUnitTests { @@ -3380,18 +3350,6 @@ class MappingMongoConverterUnitTests {
assertThat(withVector.embeddings.toDoubleArray()).contains(1.1d, 2.2d, 3.3d);
}
@Test // GH-4706
void mapsByteArrayAsVectorWhenAnnotatedWithFieldTargetType() {
WithExplicitTargetTypes source = new WithExplicitTargetTypes();
source.asVector = new byte[] { 0, 1, 2 };
org.bson.Document target = new org.bson.Document();
converter.write(source, target);
assertThatNoException().isThrownBy(() -> target.get("asVector", BinaryVector.class));
}
@Test // GH-4706
void writesByteArrayAsIsIfNoFieldInstructionsGiven() {
@ -4070,9 +4028,6 @@ class MappingMongoConverterUnitTests { @@ -4070,9 +4028,6 @@ class MappingMongoConverterUnitTests {
@Field(targetType = FieldType.OBJECT_ID) //
Date dateAsObjectId;
@Field(targetType = FieldType.VECTOR) //
byte[] asVector;
}
static class WrapperAroundWithUnwrapped {

40
spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/index/VectorIndexIntegrationTests.java

@ -15,8 +15,8 @@ @@ -15,8 +15,8 @@
*/
package org.springframework.data.mongodb.core.index;
import static org.assertj.core.api.Assertions.assertThatRuntimeException;
import static org.awaitility.Awaitility.await;
import static org.assertj.core.api.Assertions.*;
import static org.awaitility.Awaitility.*;
import static org.springframework.data.mongodb.test.util.Assertions.assertThat;
import java.util.List;
@ -27,6 +27,7 @@ import org.junit.jupiter.api.BeforeEach; @@ -27,6 +27,7 @@ import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.springframework.data.annotation.Id;
import org.springframework.data.mongodb.core.index.VectorIndex.SimilarityFunction;
import org.springframework.data.mongodb.core.mapping.Field;
@ -34,6 +35,7 @@ import org.springframework.data.mongodb.test.util.AtlasContainer; @@ -34,6 +35,7 @@ import org.springframework.data.mongodb.test.util.AtlasContainer;
import org.springframework.data.mongodb.test.util.MongoTestTemplate;
import org.springframework.data.mongodb.test.util.MongoTestUtils;
import org.springframework.lang.Nullable;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;
@ -49,7 +51,7 @@ import com.mongodb.client.AggregateIterable; @@ -49,7 +51,7 @@ import com.mongodb.client.AggregateIterable;
@Testcontainers(disabledWithoutDocker = true)
class VectorIndexIntegrationTests {
private static @Container AtlasContainer atlasLocal = AtlasContainer.bestMatch();
private static final @Container AtlasContainer atlasLocal = AtlasContainer.bestMatch();
MongoTestTemplate template = new MongoTestTemplate(cfg -> {
cfg.configureDatabaseFactory(ctx -> {
@ -82,7 +84,7 @@ class VectorIndexIntegrationTests { @@ -82,7 +84,7 @@ class VectorIndexIntegrationTests {
VectorIndex idx = new VectorIndex("vector_index").addVector("plotEmbedding",
builder -> builder.dimensions(1536).similarity(similarityFunction));
indexOps.ensureIndex(idx);
indexOps.createIndex(idx);
await().untilAsserted(() -> {
Document raw = readRawIndexInfo(idx.getName());
@ -101,7 +103,7 @@ class VectorIndexIntegrationTests { @@ -101,7 +103,7 @@ class VectorIndexIntegrationTests {
VectorIndex idx = new VectorIndex("vector_index").addVector("plotEmbedding",
builder -> builder.dimensions(1536).similarity("cosine"));
indexOps.ensureIndex(idx);
indexOps.createIndex(idx);
template.awaitIndexCreation(Movie.class, idx.getName());
@ -111,7 +113,7 @@ class VectorIndexIntegrationTests { @@ -111,7 +113,7 @@ class VectorIndexIntegrationTests {
}
@Test // GH-4706
void statusChanges() {
void statusChanges() throws InterruptedException {
String indexName = "vector_index";
assertThat(indexOps.status(indexName)).isEqualTo(SearchIndexStatus.DOES_NOT_EXIST);
@ -119,14 +121,17 @@ class VectorIndexIntegrationTests { @@ -119,14 +121,17 @@ class VectorIndexIntegrationTests {
VectorIndex idx = new VectorIndex(indexName).addVector("plotEmbedding",
builder -> builder.dimensions(1536).similarity("cosine"));
indexOps.ensureIndex(idx);
indexOps.createIndex(idx);
// Wait briefly after creating the index; listing search indexes immediately afterwards might crash the container.
Thread.sleep(500);
assertThat(indexOps.status(indexName)).isIn(SearchIndexStatus.PENDING, SearchIndexStatus.BUILDING,
SearchIndexStatus.READY);
}
@Test // GH-4706
void exists() {
void exists() throws InterruptedException {
String indexName = "vector_index";
assertThat(indexOps.exists(indexName)).isFalse();
@ -134,19 +139,25 @@ class VectorIndexIntegrationTests { @@ -134,19 +139,25 @@ class VectorIndexIntegrationTests {
VectorIndex idx = new VectorIndex(indexName).addVector("plotEmbedding",
builder -> builder.dimensions(1536).similarity("cosine"));
indexOps.ensureIndex(idx);
indexOps.createIndex(idx);
// Wait briefly after creating the index; listing search indexes immediately afterwards might crash the container.
Thread.sleep(500);
assertThat(indexOps.exists(indexName)).isTrue();
}
@Test // GH-4706
void updatesVectorIndex() {
void updatesVectorIndex() throws InterruptedException {
String indexName = "vector_index";
VectorIndex idx = new VectorIndex(indexName).addVector("plotEmbedding",
builder -> builder.dimensions(1536).similarity("cosine"));
indexOps.ensureIndex(idx);
indexOps.createIndex(idx);
// Wait briefly after creating the index; listing search indexes immediately afterwards might crash the container.
Thread.sleep(500);
await().untilAsserted(() -> {
Document raw = readRawIndexInfo(idx.getName());
@ -166,13 +177,16 @@ class VectorIndexIntegrationTests { @@ -166,13 +177,16 @@ class VectorIndexIntegrationTests {
}
@Test // GH-4706
void createsVectorIndexWithFilters() {
void createsVectorIndexWithFilters() throws InterruptedException {
VectorIndex idx = new VectorIndex("vector_index")
.addVector("plotEmbedding", builder -> builder.dimensions(1536).cosine()).addFilter("description")
.addFilter("year");
indexOps.ensureIndex(idx);
indexOps.createIndex(idx);
// Wait briefly after creating the index; listing search indexes immediately afterwards might crash the container.
Thread.sleep(500);
await().untilAsserted(() -> {
Document raw = readRawIndexInfo(idx.getName());

86
spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/AtlasContainer.java

@ -15,96 +15,44 @@ @@ -15,96 +15,44 @@
*/
package org.springframework.data.mongodb.test.util;
import java.util.List;
import org.bson.Document;
import org.springframework.core.env.StandardEnvironment;
import org.springframework.data.util.Lazy;
import org.springframework.util.StringUtils;
import org.testcontainers.containers.GenericContainer;
import org.testcontainers.containers.wait.strategy.DockerHealthcheckWaitStrategy;
import org.testcontainers.containers.wait.strategy.WaitStrategy;
import org.testcontainers.utility.DockerImageName;
import com.mongodb.ConnectionString;
import com.mongodb.client.MongoClient;
import com.mongodb.client.MongoCollection;
import org.testcontainers.mongodb.MongoDBAtlasLocalContainer;
import org.testcontainers.utility.DockerImageName;
/**
* Extension to {@link MongoDBAtlasLocalContainer} providing factory methods that select the Atlas Local image tag.
*
* @author Christoph Strobl
*/
public class AtlasContainer extends GenericContainer<AtlasContainer> {
public class AtlasContainer extends MongoDBAtlasLocalContainer {
private static final DockerImageName DEFAULT_IMAGE_NAME = DockerImageName.parse("mongodb/mongodb-atlas-local");
private static final String DEFAULT_TAG = "latest";
private static final String MONGODB_DATABASE_NAME_DEFAULT = "test";
private static final String READY_DB = "__db_ready_check";
private final Lazy<MongoClient> client;
private static final String DEFAULT_TAG = "8.0.0";
private static final String LATEST = "latest";
private AtlasContainer(String dockerImageName) {
super(DockerImageName.parse(dockerImageName));
}
private AtlasContainer(DockerImageName dockerImageName) {
super(dockerImageName);
}
public static AtlasContainer bestMatch() {
return tagged(new StandardEnvironment().getProperty("mongodb.atlas.version", DEFAULT_TAG));
}
public static AtlasContainer latest() {
return tagged(DEFAULT_TAG);
return tagged(LATEST);
}
public static AtlasContainer version8() {
return tagged("8.0.0");
return tagged(DEFAULT_TAG);
}
public static AtlasContainer tagged(String tag) {
return new AtlasContainer(DEFAULT_IMAGE_NAME.withTag(tag));
}
public AtlasContainer(String dockerImageName) {
this(DockerImageName.parse(dockerImageName));
}
public AtlasContainer(DockerImageName dockerImageName) {
super(dockerImageName);
dockerImageName.assertCompatibleWith(DEFAULT_IMAGE_NAME);
setExposedPorts(List.of(27017));
client = Lazy.of(() -> MongoTestUtils.client(new ConnectionString(getConnectionString())));
}
public String getConnectionString() {
return getConnectionString(MONGODB_DATABASE_NAME_DEFAULT);
}
/**
* Gets a connection string url.
*
* @return a connection url pointing to a mongodb instance
*/
public String getConnectionString(String database) {
return String.format("mongodb://%s:%d/%s?directConnection=true", getHost(), getMappedPort(27017),
StringUtils.hasText(database) ? database : MONGODB_DATABASE_NAME_DEFAULT);
}
@Override
public boolean isHealthy() {
MongoClient mongoClient = client.get();
MongoCollection<Document> ready = MongoTestUtils.createOrReplaceCollection(READY_DB, "ready", mongoClient);
boolean isReady = false;
try {
ready.aggregate(List.of(new Document("$listSearchIndexes", new Document()))).first();
isReady = true;
} catch (Exception e) {
// ok so the search service is not ready yet - sigh
}
if (isReady) {
mongoClient.getDatabase(READY_DB).drop();
mongoClient.close();
}
return isReady;
}
@Override
protected WaitStrategy getWaitStrategy() {
return new DockerHealthcheckWaitStrategy();
}
}

17
spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoTestTemplate.java

@ -63,14 +63,11 @@ public class MongoTestTemplate extends MongoTemplate { @@ -63,14 +63,11 @@ public class MongoTestTemplate extends MongoTemplate {
public MongoTestTemplate(Consumer<MongoTestTemplateConfiguration> cfg) {
this(new Supplier<MongoTestTemplateConfiguration>() {
@Override
public MongoTestTemplateConfiguration get() {
this(() -> {
MongoTestTemplateConfiguration config = new MongoTestTemplateConfiguration();
cfg.accept(config);
return config;
}
MongoTestTemplateConfiguration config = new MongoTestTemplateConfiguration();
cfg.accept(config);
return config;
});
}
@ -115,7 +112,7 @@ public class MongoTestTemplate extends MongoTemplate { @@ -115,7 +112,7 @@ public class MongoTestTemplate extends MongoTemplate {
}
public void flush(Class<?>... entities) {
flush(Arrays.asList(entities).stream().map(this::getCollectionName).collect(Collectors.toList()));
flush(Arrays.stream(entities).map(this::getCollectionName).collect(Collectors.toList()));
}
public void flush(String... collections) {
@ -124,7 +121,7 @@ public class MongoTestTemplate extends MongoTemplate { @@ -124,7 +121,7 @@ public class MongoTestTemplate extends MongoTemplate {
public void flush(Object... objects) {
flush(Arrays.asList(objects).stream().map(it -> {
flush(Arrays.stream(objects).map(it -> {
if (it instanceof String) {
return (String) it;
@ -167,7 +164,7 @@ public class MongoTestTemplate extends MongoTemplate { @@ -167,7 +164,7 @@ public class MongoTestTemplate extends MongoTemplate {
Awaitility.await().atMost(timeout).pollInterval(Duration.ofMillis(200)).until(() -> {
ArrayList<Document> execute = this.execute(collectionName,
List<Document> execute = this.execute(collectionName,
coll -> coll
.aggregate(List.of(Document.parse("{'$listSearchIndexes': { 'name' : '%s'}}".formatted(indexName))))
.into(new ArrayList<>()));

Loading…
Cancel
Save