Browse Source

Polishing.

Remove Field type. Refactor container to subclass MongoDBAtlasLocalContainer. Introduce wait/synchronization to avoid container crashes on create index + list search indexes.

See #4706
Original pull request: #4882
pull/4890/head
Mark Paluch 11 months ago
parent
commit
a53ea25de2
No known key found for this signature in database
GPG Key ID: 55BC6374BAA9D973
  1. 4
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MappingMongoJsonSchemaCreator.java
  2. 85
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/VectorSearchOperation.java
  3. 5
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/convert/MongoConverters.java
  4. 6
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/IndexOperations.java
  5. 14
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/ReactiveIndexOperations.java
  6. 6
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndexDefinition.java
  7. 16
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndexInfo.java
  8. 12
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndexOperations.java
  9. 47
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/VectorIndex.java
  10. 5
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/mapping/FieldType.java
  11. 18
      spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/VectorSearchOperationUnitTests.java
  12. 34
      spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/VectorSearchTests.java
  13. 57
      spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/convert/MappingMongoConverterUnitTests.java
  14. 40
      spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/index/VectorIndexIntegrationTests.java
  15. 86
      spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/AtlasContainer.java
  16. 11
      spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoTestTemplate.java

4
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MappingMongoJsonSchemaCreator.java

@ -338,8 +338,8 @@ class MappingMongoJsonSchemaCreator implements MongoJsonSchemaCreator {
private String computePropertyFieldName(PersistentProperty<?> property) { private String computePropertyFieldName(PersistentProperty<?> property) {
return property instanceof MongoPersistentProperty mongoPersistentProperty ? return property instanceof MongoPersistentProperty mongoPersistentProperty ? mongoPersistentProperty.getFieldName()
mongoPersistentProperty.getFieldName() : property.getName(); : property.getName();
} }
private boolean isRequiredProperty(PersistentProperty<?> property) { private boolean isRequiredProperty(PersistentProperty<?> property) {

85
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/VectorSearchOperation.java

@ -16,15 +16,14 @@
package org.springframework.data.mongodb.core.aggregation; package org.springframework.data.mongodb.core.aggregation;
import java.util.Arrays; import java.util.Arrays;
import java.util.LinkedHashSet;
import java.util.List; import java.util.List;
import java.util.Map;
import java.util.Set; import java.util.Set;
import java.util.function.Consumer; import java.util.function.Consumer;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import org.bson.BinaryVector; import org.bson.BinaryVector;
import org.bson.Document; import org.bson.Document;
import org.springframework.data.domain.Limit; import org.springframework.data.domain.Limit;
import org.springframework.data.domain.Vector; import org.springframework.data.domain.Vector;
import org.springframework.data.mongodb.core.mapping.MongoVector; import org.springframework.data.mongodb.core.mapping.MongoVector;
@ -177,7 +176,7 @@ public class VectorSearchOperation implements AggregationOperation {
* can't specify a number less than the number of documents to return (limit). This field is required if * can't specify a number less than the number of documents to return (limit). This field is required if
* {@link #searchType(SearchType)} is {@link SearchType#ANN} or {@link SearchType#DEFAULT}. * {@link #searchType(SearchType)} is {@link SearchType#ANN} or {@link SearchType#DEFAULT}.
* *
* @param numCandidates * @param numCandidates number of nearest neighbors to use during the search
* @return a new {@link VectorSearchOperation} with {@code numCandidates} applied. * @return a new {@link VectorSearchOperation} with {@code numCandidates} applied.
*/ */
@Contract("_ -> new") @Contract("_ -> new")
@ -338,20 +337,25 @@ public class VectorSearchOperation implements AggregationOperation {
ENN ENN
} }
// A query path cannot only contain the name of the filed but may also hold additional information about the /**
// analyzer to use; * Value object capturing query paths.
// "path": [ "names", "notes", { "value": "comments", "multi": "mySecondaryAnalyzer" } ] */
// see: https://www.mongodb.com/docs/atlas/atlas-search/path-construction/#std-label-ref-path
public static class QueryPaths { public static class QueryPaths {
Set<QueryPath<?>> paths; private final Set<QueryPath<?>> paths;
public static QueryPaths of(QueryPath<String> path) { private QueryPaths(Set<QueryPath<?>> paths) {
this.paths = paths;
}
QueryPaths queryPaths = new QueryPaths(); /**
queryPaths.paths = new LinkedHashSet<>(2); * Factory method to create {@link QueryPaths} from a single {@link QueryPath}.
queryPaths.paths.add(path); *
return queryPaths; * @param path
* @return a new {@link QueryPaths} instance.
*/
public static QueryPaths of(QueryPath<String> path) {
return new QueryPaths(Set.of(path));
} }
Object getPathObject() { Object getPathObject() {
@ -363,6 +367,12 @@ public class VectorSearchOperation implements AggregationOperation {
} }
} }
/**
* Interface describing a query path contract. Query paths might be simple field names, wildcard paths, or
* multi-paths. paths.
*
* @param <T>
*/
public interface QueryPath<T> { public interface QueryPath<T> {
T value(); T value();
@ -370,14 +380,6 @@ public class VectorSearchOperation implements AggregationOperation {
static QueryPath<String> path(String field) { static QueryPath<String> path(String field) {
return new SimplePath(field); return new SimplePath(field);
} }
static QueryPath<Map<String, Object>> wildcard(String field) {
return new WildcardPath(field);
}
static QueryPath<Map<String, Object>> multi(String field, String analyzer) {
return new MultiPath(field, analyzer);
}
} }
public static class SimplePath implements QueryPath<String> { public static class SimplePath implements QueryPath<String> {
@ -394,36 +396,9 @@ public class VectorSearchOperation implements AggregationOperation {
} }
} }
public static class WildcardPath implements QueryPath<Map<String, Object>> { /**
* Fluent API to configure a path on the VectorSearchOperation builder.
String name; */
public WildcardPath(String name) {
this.name = name;
}
@Override
public Map<String, Object> value() {
return Map.of("wildcard", name);
}
}
public static class MultiPath implements QueryPath<Map<String, Object>> {
String field;
String analyzer;
public MultiPath(String field, String analyzer) {
this.field = field;
this.analyzer = analyzer;
}
@Override
public Map<String, Object> value() {
return Map.of("value", field, "multi", analyzer);
}
}
public interface PathContributor { public interface PathContributor {
/** /**
@ -436,6 +411,9 @@ public class VectorSearchOperation implements AggregationOperation {
VectorContributor path(String path); VectorContributor path(String path);
} }
/**
* Fluent API to configure a vector on the VectorSearchOperation builder.
*/
public interface VectorContributor { public interface VectorContributor {
/** /**
@ -458,7 +436,7 @@ public class VectorSearchOperation implements AggregationOperation {
* @return * @return
*/ */
@Contract("_ -> this") @Contract("_ -> this")
default LimitContributor vector(byte... vector) { default LimitContributor vector(byte[] vector) {
return vector(BinaryVector.int8Vector(vector)); return vector(BinaryVector.int8Vector(vector));
} }
@ -510,6 +488,9 @@ public class VectorSearchOperation implements AggregationOperation {
LimitContributor vector(Vector vector); LimitContributor vector(Vector vector);
} }
/**
* Fluent API to configure a limit on the VectorSearchOperation builder.
*/
public interface LimitContributor { public interface LimitContributor {
/** /**

5
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/convert/MongoConverters.java

@ -15,7 +15,7 @@
*/ */
package org.springframework.data.mongodb.core.convert; package org.springframework.data.mongodb.core.convert;
import static org.springframework.data.convert.ConverterBuilder.reading; import static org.springframework.data.convert.ConverterBuilder.*;
import java.math.BigDecimal; import java.math.BigDecimal;
import java.math.BigInteger; import java.math.BigInteger;
@ -47,6 +47,7 @@ import org.bson.types.Binary;
import org.bson.types.Code; import org.bson.types.Code;
import org.bson.types.Decimal128; import org.bson.types.Decimal128;
import org.bson.types.ObjectId; import org.bson.types.ObjectId;
import org.springframework.core.convert.ConversionFailedException; import org.springframework.core.convert.ConversionFailedException;
import org.springframework.core.convert.TypeDescriptor; import org.springframework.core.convert.TypeDescriptor;
import org.springframework.core.convert.converter.ConditionalConverter; import org.springframework.core.convert.converter.ConditionalConverter;
@ -118,8 +119,6 @@ abstract class MongoConverters {
converters.add(reading(BsonUndefined.class, Object.class, it -> null)); converters.add(reading(BsonUndefined.class, Object.class, it -> null));
converters.add(reading(String.class, URI.class, URI::create).andWriting(URI::toString)); converters.add(reading(String.class, URI.class, URI::create).andWriting(URI::toString));
converters.add(ByteArrayConverterFactory.INSTANCE);
return converters; return converters;
} }

6
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/IndexOperations.java

@ -33,14 +33,14 @@ public interface IndexOperations {
* *
* @param indexDefinition must not be {@literal null}. * @param indexDefinition must not be {@literal null}.
* @return the index name. * @return the index name.
* @deprecated in favor of {@link #createIndex(IndexDefinition)}. * @deprecated since 4.5, in favor of {@link #createIndex(IndexDefinition)}.
*/ */
@Deprecated(since = "4.5", forRemoval = true) @Deprecated(since = "4.5", forRemoval = true)
String ensureIndex(IndexDefinition indexDefinition); String ensureIndex(IndexDefinition indexDefinition);
/** /**
* Create the index for the provided {@link IndexDefinition} exists for the collection indicated by the entity * Create the index for the provided {@link IndexDefinition} exists for the collection indicated by the entity class.
* class. If not it will be created. * If not it will be created.
* *
* @param indexDefinition must not be {@literal null}. * @param indexDefinition must not be {@literal null}.
* @return the index name. * @return the index name.

14
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/ReactiveIndexOperations.java

@ -33,9 +33,23 @@ public interface ReactiveIndexOperations {
* *
* @param indexDefinition must not be {@literal null}. * @param indexDefinition must not be {@literal null}.
* @return a {@link Mono} emitting the name of the index on completion. * @return a {@link Mono} emitting the name of the index on completion.
* @deprecated since 4.5, in favor of {@link #createIndex(IndexDefinition)}.
*/ */
@Deprecated(since = "4.5", forRemoval = true)
Mono<String> ensureIndex(IndexDefinition indexDefinition); Mono<String> ensureIndex(IndexDefinition indexDefinition);
/**
* Create the index for the provided {@link IndexDefinition} exists for the collection indicated by the entity class.
* If not it will be created.
*
* @param indexDefinition must not be {@literal null}.
* @return the index name.
* @since 4.5
*/
default Mono<String> createIndex(IndexDefinition indexDefinition) {
return ensureIndex(indexDefinition);
}
/** /**
* Alters the index with given {@literal name}. * Alters the index with given {@literal name}.
* *

6
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndexDefinition.java

@ -74,12 +74,14 @@ public interface SearchIndexDefinition {
/** /**
* Returns the actual index definition for this index in the context of a potential entity to resolve field name * Returns the actual index definition for this index in the context of a potential entity to resolve field name
* mappings. * mappings. Entity and context can be {@literal null} to create a generic index definition without applying field
* name mapping.
* *
* @param entity can be {@literal null}. * @param entity can be {@literal null}.
* @param mappingContext * @param mappingContext can be {@literal null}.
* @return never {@literal null}. * @return never {@literal null}.
*/ */
Document getDefinition(@Nullable TypeInformation<?> entity, Document getDefinition(@Nullable TypeInformation<?> entity,
@Nullable MappingContext<? extends MongoPersistentEntity<?>, MongoPersistentProperty> mappingContext); @Nullable MappingContext<? extends MongoPersistentEntity<?>, MongoPersistentProperty> mappingContext);
} }

16
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndexInfo.java

@ -29,6 +29,7 @@ import org.springframework.lang.Nullable;
* Index information for a MongoDB Search Index. * Index information for a MongoDB Search Index.
* *
* @author Christoph Strobl * @author Christoph Strobl
* @since 4.5
*/ */
public class SearchIndexInfo { public class SearchIndexInfo {
@ -42,14 +43,27 @@ public class SearchIndexInfo {
this.indexDefinition = Lazy.of(indexDefinition); this.indexDefinition = Lazy.of(indexDefinition);
} }
/**
* Parse a BSON document describing an index into a {@link SearchIndexInfo}.
*
* @param source BSON document describing the index.
* @return a new {@link SearchIndexInfo} instance.
*/
public static SearchIndexInfo parse(String source) { public static SearchIndexInfo parse(String source) {
return of(Document.parse(source)); return of(Document.parse(source));
} }
/**
* Create an index from its BSON {@link Document} representation into a {@link SearchIndexInfo}.
*
* @param indexDocument BSON document describing the index.
* @return a new {@link SearchIndexInfo} instance.
*/
public static SearchIndexInfo of(Document indexDocument) { public static SearchIndexInfo of(Document indexDocument) {
Object id = indexDocument.get("id"); Object id = indexDocument.get("id");
SearchIndexStatus status = SearchIndexStatus.valueOf(indexDocument.get("status", "DOES_NOT_EXIST")); SearchIndexStatus status = SearchIndexStatus
.valueOf(indexDocument.get("status", SearchIndexStatus.DOES_NOT_EXIST.name()));
return new SearchIndexInfo(id, status, () -> readIndexDefinition(indexDocument)); return new SearchIndexInfo(id, status, () -> readIndexDefinition(indexDocument));
} }

12
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndexOperations.java

@ -27,17 +27,6 @@ import org.springframework.dao.DataAccessException;
*/ */
public interface SearchIndexOperations { public interface SearchIndexOperations {
/**
* Create the index for the given {@link SearchIndexDefinition} in the collection indicated by the entity class.
*
* @param indexDefinition must not be {@literal null}.
* @return the index name.
*/
// TODO: keep or just go with createIndex?
default String ensureIndex(SearchIndexDefinition indexDefinition) {
return createIndex(indexDefinition);
}
/** /**
* Create the index for the given {@link SearchIndexDefinition} in the collection indicated by the entity class. * Create the index for the given {@link SearchIndexDefinition} in the collection indicated by the entity class.
* *
@ -53,7 +42,6 @@ public interface SearchIndexOperations {
* *
* @param indexDefinition the index definition. * @param indexDefinition the index definition.
*/ */
// TODO: keep or remove since it does not work reliably?
void updateIndex(SearchIndexDefinition indexDefinition); void updateIndex(SearchIndexDefinition indexDefinition);
/** /**

47
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/VectorIndex.java

@ -1,27 +1,11 @@
/* /*
* Copyright 2024. the original author or authors. * Copyright 2024-2025 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* Copyright 2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, software * Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, * distributed under the License is distributed on an "AS IS" BASIS,
@ -153,12 +137,16 @@ public class VectorIndex implements SearchIndexDefinition {
return "VectorIndex{" + "name='" + name + '\'' + ", fields=" + fields + ", type='" + getType() + '\'' + '}'; return "VectorIndex{" + "name='" + name + '\'' + ", fields=" + fields + ", type='" + getType() + '\'' + '}';
} }
// /** instead of index info */ /**
* Parse the {@link Document} into a {@link VectorIndex}.
*/
static VectorIndex of(Document document) { static VectorIndex of(Document document) {
VectorIndex index = new VectorIndex(document.getString("name")); VectorIndex index = new VectorIndex(document.getString("name"));
String definitionKey = document.containsKey("latestDefinition") ? "latestDefinition" : "definition"; String definitionKey = document.containsKey("latestDefinition") ? "latestDefinition" : "definition";
Document definition = document.get(definitionKey, Document.class); Document definition = document.get(definitionKey, Document.class);
for (Object entry : definition.get("fields", List.class)) { for (Object entry : definition.get("fields", List.class)) {
if (entry instanceof Document field) { if (entry instanceof Document field) {
if (field.get("type").equals("vector")) { if (field.get("type").equals("vector")) {
@ -195,7 +183,7 @@ public class VectorIndex implements SearchIndexDefinition {
record VectorFilterField(String path, String type) implements SearchField { record VectorFilterField(String path, String type) implements SearchField {
} }
record VectorIndexField(String path, String type, int dimensions, String similarity, record VectorIndexField(String path, String type, int dimensions, @Nullable String similarity,
@Nullable String quantization) implements SearchField { @Nullable String quantization) implements SearchField {
} }
@ -313,6 +301,9 @@ public class VectorIndex implements SearchIndexDefinition {
} }
} }
/**
* Similarity function used to calculate vector distance.
*/
public enum SimilarityFunction { public enum SimilarityFunction {
DOT_PRODUCT("dotProduct"), COSINE("cosine"), EUCLIDEAN("euclidean"); DOT_PRODUCT("dotProduct"), COSINE("cosine"), EUCLIDEAN("euclidean");
@ -328,10 +319,22 @@ public class VectorIndex implements SearchIndexDefinition {
} }
} }
/** make it nullable */ /**
* Vector quantization. Quantization reduce vector sizes while preserving performance.
*/
public enum Quantization { public enum Quantization {
NONE("none"), SCALAR("scalar"), BINARY("binary"); NONE("none"),
/**
* Converting a float point into an integer.
*/
SCALAR("scalar"),
/**
* Converting a float point into a single bit.
*/
BINARY("binary");
final String quantizationName; final String quantizationName;

5
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/mapping/FieldType.java

@ -18,8 +18,6 @@ package org.springframework.data.mongodb.core.mapping;
import java.util.Date; import java.util.Date;
import java.util.regex.Pattern; import java.util.regex.Pattern;
import org.bson.BinaryVector;
import org.bson.BsonBinary;
import org.bson.types.BSONTimestamp; import org.bson.types.BSONTimestamp;
import org.bson.types.Binary; import org.bson.types.Binary;
import org.bson.types.Code; import org.bson.types.Code;
@ -57,8 +55,7 @@ public enum FieldType {
INT32(15, Integer.class), // INT32(15, Integer.class), //
TIMESTAMP(16, BSONTimestamp.class), // TIMESTAMP(16, BSONTimestamp.class), //
INT64(17, Long.class), // INT64(17, Long.class), //
DECIMAL128(18, Decimal128.class), DECIMAL128(18, Decimal128.class);
VECTOR(5, BinaryVector.class);
private final int bsonType; private final int bsonType;
private final Class<?> javaClass; private final Class<?> javaClass;

18
spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/VectorSearchOperationUnitTests.java

@ -15,11 +15,13 @@
*/ */
package org.springframework.data.mongodb.core.aggregation; package org.springframework.data.mongodb.core.aggregation;
import static org.assertj.core.api.Assertions.*;
import java.util.List; import java.util.List;
import org.assertj.core.api.Assertions;
import org.bson.Document; import org.bson.Document;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.springframework.data.annotation.Id; import org.springframework.data.annotation.Id;
import org.springframework.data.mongodb.core.aggregation.VectorSearchOperation.SearchType; import org.springframework.data.mongodb.core.aggregation.VectorSearchOperation.SearchType;
import org.springframework.data.mongodb.core.mapping.Field; import org.springframework.data.mongodb.core.mapping.Field;
@ -27,6 +29,8 @@ import org.springframework.data.mongodb.core.query.Criteria;
import org.springframework.data.mongodb.util.aggregation.TestAggregationContext; import org.springframework.data.mongodb.util.aggregation.TestAggregationContext;
/** /**
* Unit tests for {@link VectorSearchOperation}.
*
* @author Christoph Strobl * @author Christoph Strobl
*/ */
class VectorSearchOperationUnitTests { class VectorSearchOperationUnitTests {
@ -40,7 +44,7 @@ class VectorSearchOperationUnitTests {
void requiredArgs() { void requiredArgs() {
List<Document> stages = SEARCH_OPERATION.toPipelineStages(Aggregation.DEFAULT_CONTEXT); List<Document> stages = SEARCH_OPERATION.toPipelineStages(Aggregation.DEFAULT_CONTEXT);
Assertions.assertThat(stages).containsExactly(new Document("$vectorSearch", $VECTOR_SEARCH)); assertThat(stages).containsExactly(new Document("$vectorSearch", $VECTOR_SEARCH));
} }
@Test // GH-4706 @Test // GH-4706
@ -53,7 +57,7 @@ class VectorSearchOperationUnitTests {
Document filter = new Document("$and", Document filter = new Document("$and",
List.of(new Document("year", new Document("$gt", 1955)), new Document("year", new Document("$lt", 1975)))); List.of(new Document("year", new Document("$gt", 1955)), new Document("year", new Document("$lt", 1975))));
Assertions.assertThat(stages).containsExactly(new Document("$vectorSearch", assertThat(stages).containsExactly(new Document("$vectorSearch",
new Document($VECTOR_SEARCH).append("exact", true).append("filter", filter).append("numCandidates", 150))); new Document($VECTOR_SEARCH).append("exact", true).append("filter", filter).append("numCandidates", 150)));
} }
@ -61,7 +65,7 @@ class VectorSearchOperationUnitTests {
void withScore() { void withScore() {
List<Document> stages = SEARCH_OPERATION.withSearchScore().toPipelineStages(Aggregation.DEFAULT_CONTEXT); List<Document> stages = SEARCH_OPERATION.withSearchScore().toPipelineStages(Aggregation.DEFAULT_CONTEXT);
Assertions.assertThat(stages).containsExactly(new Document("$vectorSearch", $VECTOR_SEARCH), assertThat(stages).containsExactly(new Document("$vectorSearch", $VECTOR_SEARCH),
new Document("$addFields", new Document("score", new Document("$meta", "vectorSearchScore")))); new Document("$addFields", new Document("score", new Document("$meta", "vectorSearchScore"))));
} }
@ -70,7 +74,7 @@ class VectorSearchOperationUnitTests {
List<Document> stages = SEARCH_OPERATION.withFilterBySore(score -> score.gt(50)) List<Document> stages = SEARCH_OPERATION.withFilterBySore(score -> score.gt(50))
.toPipelineStages(Aggregation.DEFAULT_CONTEXT); .toPipelineStages(Aggregation.DEFAULT_CONTEXT);
Assertions.assertThat(stages).containsExactly(new Document("$vectorSearch", $VECTOR_SEARCH), assertThat(stages).containsExactly(new Document("$vectorSearch", $VECTOR_SEARCH),
new Document("$addFields", new Document("score", new Document("$meta", "vectorSearchScore"))), new Document("$addFields", new Document("score", new Document("$meta", "vectorSearchScore"))),
new Document("$match", new Document("score", new Document("$gt", 50)))); new Document("$match", new Document("score", new Document("$gt", 50))));
} }
@ -80,7 +84,7 @@ class VectorSearchOperationUnitTests {
List<Document> stages = SEARCH_OPERATION.withFilterBySore(score -> score.gt(50)).withSearchScore("s-c-o-r-e") List<Document> stages = SEARCH_OPERATION.withFilterBySore(score -> score.gt(50)).withSearchScore("s-c-o-r-e")
.toPipelineStages(Aggregation.DEFAULT_CONTEXT); .toPipelineStages(Aggregation.DEFAULT_CONTEXT);
Assertions.assertThat(stages).containsExactly(new Document("$vectorSearch", $VECTOR_SEARCH), assertThat(stages).containsExactly(new Document("$vectorSearch", $VECTOR_SEARCH),
new Document("$addFields", new Document("s-c-o-r-e", new Document("$meta", "vectorSearchScore"))), new Document("$addFields", new Document("s-c-o-r-e", new Document("$meta", "vectorSearchScore"))),
new Document("$match", new Document("s-c-o-r-e", new Document("$gt", 50)))); new Document("$match", new Document("s-c-o-r-e", new Document("$gt", 50))));
} }
@ -95,7 +99,7 @@ class VectorSearchOperationUnitTests {
Document filter = new Document("$and", Document filter = new Document("$and",
List.of(new Document("year", new Document("$gt", 1955)), new Document("year", new Document("$lt", 1975)))); List.of(new Document("year", new Document("$gt", 1955)), new Document("year", new Document("$lt", 1975))));
Assertions.assertThat(stages) assertThat(stages)
.containsExactly(new Document("$vectorSearch", new Document($VECTOR_SEARCH).append("filter", filter))); .containsExactly(new Document("$vectorSearch", new Document($VECTOR_SEARCH).append("filter", filter)));
} }

34
spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/VectorSearchTests.java

@ -15,7 +15,7 @@
*/ */
package org.springframework.data.mongodb.core.aggregation; package org.springframework.data.mongodb.core.aggregation;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.*;
import java.util.stream.IntStream; import java.util.stream.IntStream;
import java.util.stream.Stream; import java.util.stream.Stream;
@ -28,15 +28,15 @@ import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.MethodSource;
import org.springframework.data.domain.Vector; import org.springframework.data.domain.Vector;
import org.springframework.data.mongodb.core.aggregation.VectorSearchOperation.SearchType; import org.springframework.data.mongodb.core.aggregation.VectorSearchOperation.SearchType;
import org.springframework.data.mongodb.core.index.VectorIndex; import org.springframework.data.mongodb.core.index.VectorIndex;
import org.springframework.data.mongodb.core.index.VectorIndex.SimilarityFunction; import org.springframework.data.mongodb.core.index.VectorIndex.SimilarityFunction;
import org.springframework.data.mongodb.core.mapping.Field;
import org.springframework.data.mongodb.core.mapping.FieldType;
import org.springframework.data.mongodb.core.mapping.MongoVector; import org.springframework.data.mongodb.core.mapping.MongoVector;
import org.springframework.data.mongodb.test.util.AtlasContainer; import org.springframework.data.mongodb.test.util.AtlasContainer;
import org.springframework.data.mongodb.test.util.MongoTestTemplate; import org.springframework.data.mongodb.test.util.MongoTestTemplate;
import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers; import org.testcontainers.junit.jupiter.Testcontainers;
@ -44,16 +44,20 @@ import com.mongodb.client.MongoClient;
import com.mongodb.client.MongoClients; import com.mongodb.client.MongoClients;
/** /**
* Integration tests using Vector Search and Vector Indexes through local MongoDB Atlas.
*
* @author Christoph Strobl * @author Christoph Strobl
* @author Mark Paluch
*/ */
@Testcontainers(disabledWithoutDocker = true) @Testcontainers(disabledWithoutDocker = true)
public class VectorSearchTests { public class VectorSearchTests {
public static final String SCORE_FIELD = "vector-search-tests"; private static final String SCORE_FIELD = "vector-search-tests";
static final String COLLECTION_NAME = "collection-1"; private static final @Container AtlasContainer atlasLocal = AtlasContainer.bestMatch();
private static final String COLLECTION_NAME = "collection-1";
static MongoClient client; static MongoClient client;
static MongoTestTemplate template; static MongoTestTemplate template;
private static @Container AtlasContainer atlasLocal = AtlasContainer.bestMatch();
@BeforeAll @BeforeAll
static void beforeAll() throws InterruptedException { static void beforeAll() throws InterruptedException {
@ -126,12 +130,12 @@ public class VectorSearchTests {
return Stream.of(// return Stream.of(//
Arguments.arguments(VectorSearchOperation.search("raw-index").path("rawFloat32vector") // Arguments.arguments(VectorSearchOperation.search("raw-index").path("rawFloat32vector") //
.vector(new float[] { 0.0001f, 1.12345f, 2.23456f, 3.34567f, 4.45678f }) // .vector(0.0001f, 1.12345f, 2.23456f, 3.34567f, 4.45678f) //
.limit(10)// .limit(10)//
.numCandidates(20) // .numCandidates(20) //
.searchType(SearchType.ANN)), .searchType(SearchType.ANN)),
Arguments.arguments(VectorSearchOperation.search("raw-index").path("rawFloat64vector") // Arguments.arguments(VectorSearchOperation.search("raw-index").path("rawFloat64vector") //
.vector(new double[] { 1.0001d, 2.12345d, 3.23456d, 4.34567d, 5.45678d }) // .vector(1.0001d, 2.12345d, 3.23456d, 4.34567d, 5.45678d) //
.limit(10)// .limit(10)//
.numCandidates(20) // .numCandidates(20) //
.searchType(SearchType.ANN)), .searchType(SearchType.ANN)),
@ -160,8 +164,8 @@ public class VectorSearchTests {
.addVector("float64vector", it -> it.similarity(SimilarityFunction.COSINE).dimensions(5)) .addVector("float64vector", it -> it.similarity(SimilarityFunction.COSINE).dimensions(5))
.addFilter("justSomeArgument"); .addFilter("justSomeArgument");
template.searchIndexOps(WithVectorFields.class).ensureIndex(rawIndex); template.searchIndexOps(WithVectorFields.class).createIndex(rawIndex);
template.searchIndexOps(WithVectorFields.class).ensureIndex(wrapperIndex); template.searchIndexOps(WithVectorFields.class).createIndex(wrapperIndex);
template.awaitIndexCreation(WithVectorFields.class, rawIndex.getName()); template.awaitIndexCreation(WithVectorFields.class, rawIndex.getName());
template.awaitIndexCreation(WithVectorFields.class, wrapperIndex.getName()); template.awaitIndexCreation(WithVectorFields.class, wrapperIndex.getName());
@ -188,8 +192,7 @@ public class VectorSearchTests {
Vector float32vector; Vector float32vector;
Vector float64vector; Vector float64vector;
@Field(targetType = FieldType.VECTOR) // BinaryVector rawInt8vector;
byte[] rawInt8vector;
float[] rawFloat32vector; float[] rawFloat32vector;
double[] rawFloat64vector; double[] rawFloat64vector;
@ -199,15 +202,16 @@ public class VectorSearchTests {
WithVectorFields instance = new WithVectorFields(); WithVectorFields instance = new WithVectorFields();
instance.id = "id-%s".formatted(offset); instance.id = "id-%s".formatted(offset);
instance.rawInt8vector = new byte[5];
instance.rawFloat32vector = new float[5]; instance.rawFloat32vector = new float[5];
instance.rawFloat64vector = new double[5]; instance.rawFloat64vector = new double[5];
byte[] int8 = new byte[5];
for (int i = 0; i < 5; i++) { for (int i = 0; i < 5; i++) {
int v = i + offset; int v = i + offset;
instance.rawInt8vector[i] = (byte) v; int8[i] = (byte) v;
} }
instance.rawInt8vector = BinaryVector.int8Vector(int8);
if (offset == 0) { if (offset == 0) {
instance.rawFloat32vector[0] = 0.0001f; instance.rawFloat32vector[0] = 0.0001f;
@ -227,7 +231,7 @@ public class VectorSearchTests {
instance.justSomeArgument = offset; instance.justSomeArgument = offset;
instance.int8vector = MongoVector.of(BinaryVector.int8Vector(instance.rawInt8vector)); instance.int8vector = MongoVector.of(instance.rawInt8vector);
instance.float32vector = MongoVector.of(BinaryVector.floatVector(instance.rawFloat32vector)); instance.float32vector = MongoVector.of(BinaryVector.floatVector(instance.rawFloat32vector));
instance.float64vector = Vector.of(instance.rawFloat64vector); instance.float64vector = Vector.of(instance.rawFloat64vector);

57
spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/convert/MappingMongoConverterUnitTests.java

@ -15,23 +15,10 @@
*/ */
package org.springframework.data.mongodb.core.convert; package org.springframework.data.mongodb.core.convert;
import static java.time.ZoneId.systemDefault; import static java.time.ZoneId.*;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.*;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.mockito.Mockito.*;
import static org.assertj.core.api.Assertions.assertThatNoException; import static org.springframework.data.mongodb.core.DocumentTestUtils.*;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.assertj.core.api.Assertions.fail;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.springframework.data.mongodb.core.DocumentTestUtils.assertTypeHint;
import static org.springframework.data.mongodb.core.DocumentTestUtils.getAsDocument;
import java.math.BigDecimal; import java.math.BigDecimal;
import java.math.BigInteger; import java.math.BigInteger;
@ -40,30 +27,12 @@ import java.nio.ByteBuffer;
import java.time.LocalDate; import java.time.LocalDate;
import java.time.LocalDateTime; import java.time.LocalDateTime;
import java.time.temporal.ChronoUnit; import java.time.temporal.ChronoUnit;
import java.util.ArrayList; import java.util.*;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Date;
import java.util.EnumMap;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.SortedMap;
import java.util.TreeMap;
import java.util.UUID;
import java.util.function.Consumer; import java.util.function.Consumer;
import java.util.function.Function; import java.util.function.Function;
import java.util.stream.Stream; import java.util.stream.Stream;
import org.assertj.core.data.Percentage; import org.assertj.core.data.Percentage;
import org.bson.BinaryVector;
import org.bson.BsonDouble; import org.bson.BsonDouble;
import org.bson.BsonUndefined; import org.bson.BsonUndefined;
import org.bson.types.Binary; import org.bson.types.Binary;
@ -81,6 +50,7 @@ import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.Mock; import org.mockito.Mock;
import org.mockito.Mockito; import org.mockito.Mockito;
import org.mockito.junit.jupiter.MockitoExtension; import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.aop.framework.ProxyFactory; import org.springframework.aop.framework.ProxyFactory;
import org.springframework.beans.ConversionNotSupportedException; import org.springframework.beans.ConversionNotSupportedException;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
@ -3380,18 +3350,6 @@ class MappingMongoConverterUnitTests {
assertThat(withVector.embeddings.toDoubleArray()).contains(1.1d, 2.2d, 3.3d); assertThat(withVector.embeddings.toDoubleArray()).contains(1.1d, 2.2d, 3.3d);
} }
@Test // GH-4706
void mapsByteArrayAsVectorWhenAnnotatedWithFieldTargetType() {
WithExplicitTargetTypes source = new WithExplicitTargetTypes();
source.asVector = new byte[] { 0, 1, 2 };
org.bson.Document target = new org.bson.Document();
converter.write(source, target);
assertThatNoException().isThrownBy(() -> target.get("asVector", BinaryVector.class));
}
@Test // GH-4706 @Test // GH-4706
void writesByteArrayAsIsIfNoFieldInstructionsGiven() { void writesByteArrayAsIsIfNoFieldInstructionsGiven() {
@ -4070,9 +4028,6 @@ class MappingMongoConverterUnitTests {
@Field(targetType = FieldType.OBJECT_ID) // @Field(targetType = FieldType.OBJECT_ID) //
Date dateAsObjectId; Date dateAsObjectId;
@Field(targetType = FieldType.VECTOR) //
byte[] asVector;
} }
static class WrapperAroundWithUnwrapped { static class WrapperAroundWithUnwrapped {

40
spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/index/VectorIndexIntegrationTests.java

@ -15,8 +15,8 @@
*/ */
package org.springframework.data.mongodb.core.index; package org.springframework.data.mongodb.core.index;
import static org.assertj.core.api.Assertions.assertThatRuntimeException; import static org.assertj.core.api.Assertions.*;
import static org.awaitility.Awaitility.await; import static org.awaitility.Awaitility.*;
import static org.springframework.data.mongodb.test.util.Assertions.assertThat; import static org.springframework.data.mongodb.test.util.Assertions.assertThat;
import java.util.List; import java.util.List;
@ -27,6 +27,7 @@ import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource; import org.junit.jupiter.params.provider.ValueSource;
import org.springframework.data.annotation.Id; import org.springframework.data.annotation.Id;
import org.springframework.data.mongodb.core.index.VectorIndex.SimilarityFunction; import org.springframework.data.mongodb.core.index.VectorIndex.SimilarityFunction;
import org.springframework.data.mongodb.core.mapping.Field; import org.springframework.data.mongodb.core.mapping.Field;
@ -34,6 +35,7 @@ import org.springframework.data.mongodb.test.util.AtlasContainer;
import org.springframework.data.mongodb.test.util.MongoTestTemplate; import org.springframework.data.mongodb.test.util.MongoTestTemplate;
import org.springframework.data.mongodb.test.util.MongoTestUtils; import org.springframework.data.mongodb.test.util.MongoTestUtils;
import org.springframework.lang.Nullable; import org.springframework.lang.Nullable;
import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers; import org.testcontainers.junit.jupiter.Testcontainers;
@ -49,7 +51,7 @@ import com.mongodb.client.AggregateIterable;
@Testcontainers(disabledWithoutDocker = true) @Testcontainers(disabledWithoutDocker = true)
class VectorIndexIntegrationTests { class VectorIndexIntegrationTests {
private static @Container AtlasContainer atlasLocal = AtlasContainer.bestMatch(); private static final @Container AtlasContainer atlasLocal = AtlasContainer.bestMatch();
MongoTestTemplate template = new MongoTestTemplate(cfg -> { MongoTestTemplate template = new MongoTestTemplate(cfg -> {
cfg.configureDatabaseFactory(ctx -> { cfg.configureDatabaseFactory(ctx -> {
@ -82,7 +84,7 @@ class VectorIndexIntegrationTests {
VectorIndex idx = new VectorIndex("vector_index").addVector("plotEmbedding", VectorIndex idx = new VectorIndex("vector_index").addVector("plotEmbedding",
builder -> builder.dimensions(1536).similarity(similarityFunction)); builder -> builder.dimensions(1536).similarity(similarityFunction));
indexOps.ensureIndex(idx); indexOps.createIndex(idx);
await().untilAsserted(() -> { await().untilAsserted(() -> {
Document raw = readRawIndexInfo(idx.getName()); Document raw = readRawIndexInfo(idx.getName());
@ -101,7 +103,7 @@ class VectorIndexIntegrationTests {
VectorIndex idx = new VectorIndex("vector_index").addVector("plotEmbedding", VectorIndex idx = new VectorIndex("vector_index").addVector("plotEmbedding",
builder -> builder.dimensions(1536).similarity("cosine")); builder -> builder.dimensions(1536).similarity("cosine"));
indexOps.ensureIndex(idx); indexOps.createIndex(idx);
template.awaitIndexCreation(Movie.class, idx.getName()); template.awaitIndexCreation(Movie.class, idx.getName());
@ -111,7 +113,7 @@ class VectorIndexIntegrationTests {
} }
@Test // GH-4706 @Test // GH-4706
void statusChanges() { void statusChanges() throws InterruptedException {
String indexName = "vector_index"; String indexName = "vector_index";
assertThat(indexOps.status(indexName)).isEqualTo(SearchIndexStatus.DOES_NOT_EXIST); assertThat(indexOps.status(indexName)).isEqualTo(SearchIndexStatus.DOES_NOT_EXIST);
@ -119,14 +121,17 @@ class VectorIndexIntegrationTests {
VectorIndex idx = new VectorIndex(indexName).addVector("plotEmbedding", VectorIndex idx = new VectorIndex(indexName).addVector("plotEmbedding",
builder -> builder.dimensions(1536).similarity("cosine")); builder -> builder.dimensions(1536).similarity("cosine"));
indexOps.ensureIndex(idx); indexOps.createIndex(idx);
// without synchronization, the container might crash.
Thread.sleep(500);
assertThat(indexOps.status(indexName)).isIn(SearchIndexStatus.PENDING, SearchIndexStatus.BUILDING, assertThat(indexOps.status(indexName)).isIn(SearchIndexStatus.PENDING, SearchIndexStatus.BUILDING,
SearchIndexStatus.READY); SearchIndexStatus.READY);
} }
@Test // GH-4706 @Test // GH-4706
void exists() { void exists() throws InterruptedException {
String indexName = "vector_index"; String indexName = "vector_index";
assertThat(indexOps.exists(indexName)).isFalse(); assertThat(indexOps.exists(indexName)).isFalse();
@ -134,19 +139,25 @@ class VectorIndexIntegrationTests {
VectorIndex idx = new VectorIndex(indexName).addVector("plotEmbedding", VectorIndex idx = new VectorIndex(indexName).addVector("plotEmbedding",
builder -> builder.dimensions(1536).similarity("cosine")); builder -> builder.dimensions(1536).similarity("cosine"));
indexOps.ensureIndex(idx); indexOps.createIndex(idx);
// without synchronization, the container might crash.
Thread.sleep(500);
assertThat(indexOps.exists(indexName)).isTrue(); assertThat(indexOps.exists(indexName)).isTrue();
} }
@Test // GH-4706 @Test // GH-4706
void updatesVectorIndex() { void updatesVectorIndex() throws InterruptedException {
String indexName = "vector_index"; String indexName = "vector_index";
VectorIndex idx = new VectorIndex(indexName).addVector("plotEmbedding", VectorIndex idx = new VectorIndex(indexName).addVector("plotEmbedding",
builder -> builder.dimensions(1536).similarity("cosine")); builder -> builder.dimensions(1536).similarity("cosine"));
indexOps.ensureIndex(idx); indexOps.createIndex(idx);
// without synchronization, the container might crash.
Thread.sleep(500);
await().untilAsserted(() -> { await().untilAsserted(() -> {
Document raw = readRawIndexInfo(idx.getName()); Document raw = readRawIndexInfo(idx.getName());
@ -166,13 +177,16 @@ class VectorIndexIntegrationTests {
} }
@Test // GH-4706 @Test // GH-4706
void createsVectorIndexWithFilters() { void createsVectorIndexWithFilters() throws InterruptedException {
VectorIndex idx = new VectorIndex("vector_index") VectorIndex idx = new VectorIndex("vector_index")
.addVector("plotEmbedding", builder -> builder.dimensions(1536).cosine()).addFilter("description") .addVector("plotEmbedding", builder -> builder.dimensions(1536).cosine()).addFilter("description")
.addFilter("year"); .addFilter("year");
indexOps.ensureIndex(idx); indexOps.createIndex(idx);
// without synchronization, the container might crash.
Thread.sleep(500);
await().untilAsserted(() -> { await().untilAsserted(() -> {
Document raw = readRawIndexInfo(idx.getName()); Document raw = readRawIndexInfo(idx.getName());

86
spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/AtlasContainer.java

@ -15,96 +15,44 @@
*/ */
package org.springframework.data.mongodb.test.util; package org.springframework.data.mongodb.test.util;
import java.util.List;
import org.bson.Document;
import org.springframework.core.env.StandardEnvironment; import org.springframework.core.env.StandardEnvironment;
import org.springframework.data.util.Lazy;
import org.springframework.util.StringUtils;
import org.testcontainers.containers.GenericContainer;
import org.testcontainers.containers.wait.strategy.DockerHealthcheckWaitStrategy;
import org.testcontainers.containers.wait.strategy.WaitStrategy;
import org.testcontainers.utility.DockerImageName;
import com.mongodb.ConnectionString; import org.testcontainers.mongodb.MongoDBAtlasLocalContainer;
import com.mongodb.client.MongoClient; import org.testcontainers.utility.DockerImageName;
import com.mongodb.client.MongoCollection;
/** /**
* Extension to MongoDBAtlasLocalContainer.
*
* @author Christoph Strobl * @author Christoph Strobl
*/ */
public class AtlasContainer extends GenericContainer<AtlasContainer> { public class AtlasContainer extends MongoDBAtlasLocalContainer {
private static final DockerImageName DEFAULT_IMAGE_NAME = DockerImageName.parse("mongodb/mongodb-atlas-local"); private static final DockerImageName DEFAULT_IMAGE_NAME = DockerImageName.parse("mongodb/mongodb-atlas-local");
private static final String DEFAULT_TAG = "latest"; private static final String DEFAULT_TAG = "8.0.0";
private static final String MONGODB_DATABASE_NAME_DEFAULT = "test"; private static final String LATEST = "latest";
private static final String READY_DB = "__db_ready_check";
private final Lazy<MongoClient> client; private AtlasContainer(String dockerImageName) {
super(DockerImageName.parse(dockerImageName));
}
private AtlasContainer(DockerImageName dockerImageName) {
super(dockerImageName);
}
public static AtlasContainer bestMatch() { public static AtlasContainer bestMatch() {
return tagged(new StandardEnvironment().getProperty("mongodb.atlas.version", DEFAULT_TAG)); return tagged(new StandardEnvironment().getProperty("mongodb.atlas.version", DEFAULT_TAG));
} }
public static AtlasContainer latest() { public static AtlasContainer latest() {
return tagged(DEFAULT_TAG); return tagged(LATEST);
} }
public static AtlasContainer version8() { public static AtlasContainer version8() {
return tagged("8.0.0"); return tagged(DEFAULT_TAG);
} }
public static AtlasContainer tagged(String tag) { public static AtlasContainer tagged(String tag) {
return new AtlasContainer(DEFAULT_IMAGE_NAME.withTag(tag)); return new AtlasContainer(DEFAULT_IMAGE_NAME.withTag(tag));
} }
public AtlasContainer(String dockerImageName) {
this(DockerImageName.parse(dockerImageName));
}
public AtlasContainer(DockerImageName dockerImageName) {
super(dockerImageName);
dockerImageName.assertCompatibleWith(DEFAULT_IMAGE_NAME);
setExposedPorts(List.of(27017));
client = Lazy.of(() -> MongoTestUtils.client(new ConnectionString(getConnectionString())));
}
public String getConnectionString() {
return getConnectionString(MONGODB_DATABASE_NAME_DEFAULT);
}
/**
* Gets a connection string url.
*
* @return a connection url pointing to a mongodb instance
*/
public String getConnectionString(String database) {
return String.format("mongodb://%s:%d/%s?directConnection=true", getHost(), getMappedPort(27017),
StringUtils.hasText(database) ? database : MONGODB_DATABASE_NAME_DEFAULT);
}
@Override
public boolean isHealthy() {
MongoClient mongoClient = client.get();
MongoCollection<Document> ready = MongoTestUtils.createOrReplaceCollection(READY_DB, "ready", mongoClient);
boolean isReady = false;
try {
ready.aggregate(List.of(new Document("$listSearchIndexes", new Document()))).first();
isReady = true;
} catch (Exception e) {
// ok so the search service is not ready yet - sigh
}
if (isReady) {
mongoClient.getDatabase(READY_DB).drop();
mongoClient.close();
}
return isReady;
}
@Override
protected WaitStrategy getWaitStrategy() {
return new DockerHealthcheckWaitStrategy();
}
} }

11
spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoTestTemplate.java

@ -63,14 +63,11 @@ public class MongoTestTemplate extends MongoTemplate {
public MongoTestTemplate(Consumer<MongoTestTemplateConfiguration> cfg) { public MongoTestTemplate(Consumer<MongoTestTemplateConfiguration> cfg) {
this(new Supplier<MongoTestTemplateConfiguration>() { this(() -> {
@Override
public MongoTestTemplateConfiguration get() {
MongoTestTemplateConfiguration config = new MongoTestTemplateConfiguration(); MongoTestTemplateConfiguration config = new MongoTestTemplateConfiguration();
cfg.accept(config); cfg.accept(config);
return config; return config;
}
}); });
} }
@ -115,7 +112,7 @@ public class MongoTestTemplate extends MongoTemplate {
} }
public void flush(Class<?>... entities) { public void flush(Class<?>... entities) {
flush(Arrays.asList(entities).stream().map(this::getCollectionName).collect(Collectors.toList())); flush(Arrays.stream(entities).map(this::getCollectionName).collect(Collectors.toList()));
} }
public void flush(String... collections) { public void flush(String... collections) {
@ -124,7 +121,7 @@ public class MongoTestTemplate extends MongoTemplate {
public void flush(Object... objects) { public void flush(Object... objects) {
flush(Arrays.asList(objects).stream().map(it -> { flush(Arrays.stream(objects).map(it -> {
if (it instanceof String) { if (it instanceof String) {
return (String) it; return (String) it;
@ -167,7 +164,7 @@ public class MongoTestTemplate extends MongoTemplate {
Awaitility.await().atMost(timeout).pollInterval(Duration.ofMillis(200)).until(() -> { Awaitility.await().atMost(timeout).pollInterval(Duration.ofMillis(200)).until(() -> {
ArrayList<Document> execute = this.execute(collectionName, List<Document> execute = this.execute(collectionName,
coll -> coll coll -> coll
.aggregate(List.of(Document.parse("{'$listSearchIndexes': { 'name' : '%s'}}".formatted(indexName)))) .aggregate(List.of(Document.parse("{'$listSearchIndexes': { 'name' : '%s'}}".formatted(indexName))))
.into(new ArrayList<>())); .into(new ArrayList<>()));

Loading…
Cancel
Save