diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/VectorSearchTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/VectorSearchTests.java index 18991c176..d5285701a 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/VectorSearchTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/VectorSearchTests.java @@ -167,8 +167,8 @@ public class VectorSearchTests { template.searchIndexOps(WithVectorFields.class).createIndex(rawIndex); template.searchIndexOps(WithVectorFields.class).createIndex(wrapperIndex); - template.awaitIndexCreation(WithVectorFields.class, rawIndex.getName()); - template.awaitIndexCreation(WithVectorFields.class, wrapperIndex.getName()); + template.awaitSearchIndexCreation(WithVectorFields.class, rawIndex.getName()); + template.awaitSearchIndexCreation(WithVectorFields.class, wrapperIndex.getName()); } private static void assertScoreIsDecreasing(Iterable documents) { diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/index/VectorIndexIntegrationTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/index/VectorIndexIntegrationTests.java index 387f075cb..a0719eb46 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/index/VectorIndexIntegrationTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/index/VectorIndexIntegrationTests.java @@ -15,10 +15,11 @@ */ package org.springframework.data.mongodb.core.index; -import static org.assertj.core.api.Assertions.*; -import static org.awaitility.Awaitility.*; +import static org.assertj.core.api.Assertions.assertThatRuntimeException; +import static org.awaitility.Awaitility.await; import static org.springframework.data.mongodb.test.util.Assertions.assertThat; +import java.time.Duration; import java.util.List; import org.bson.Document; @@ -26,16 +27,17 @@ import org.jspecify.annotations.Nullable; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; - import org.springframework.data.annotation.Id; import org.springframework.data.mongodb.core.index.VectorIndex.SimilarityFunction; import org.springframework.data.mongodb.core.mapping.Field; import org.springframework.data.mongodb.test.util.AtlasContainer; +import org.springframework.data.mongodb.test.util.EnableIfVectorSearchAvailable; +import org.springframework.data.mongodb.test.util.MongoServerCondition; import org.springframework.data.mongodb.test.util.MongoTestTemplate; import org.springframework.data.mongodb.test.util.MongoTestUtils; - import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; @@ -48,6 +50,7 @@ import com.mongodb.client.AggregateIterable; * @author Christoph Strobl * @author Mark Paluch */ +@ExtendWith(MongoServerCondition.class) @Testcontainers(disabledWithoutDocker = true) class VectorIndexIntegrationTests { @@ -66,19 +69,22 @@ class VectorIndexIntegrationTests { @BeforeEach void init() { - template.createCollection(Movie.class); + + template.createCollectionIfNotExists(Movie.class); indexOps = template.searchIndexOps(Movie.class); } @AfterEach void cleanup() { + template.flush(Movie.class); template.searchIndexOps(Movie.class).dropAllIndexes(); - template.dropCollection(Movie.class); + template.awaitNoSearchIndexAvailable(Movie.class, Duration.ofSeconds(30)); } @ParameterizedTest // GH-4706 @ValueSource(strings = { "euclidean", "cosine", "dotProduct" }) + @EnableIfVectorSearchAvailable(collection = Movie.class) void createsSimpleVectorIndex(String similarityFunction) { VectorIndex idx = new VectorIndex("vector_index").addVector("plotEmbedding", @@ -98,6 +104,7 @@ class VectorIndexIntegrationTests { } @Test // GH-4706 + @EnableIfVectorSearchAvailable(collection = Movie.class) void dropIndex() { VectorIndex idx = new VectorIndex("vector_index").addVector("plotEmbedding", @@ -105,7 +112,7 @@ class VectorIndexIntegrationTests { indexOps.createIndex(idx); - template.awaitIndexCreation(Movie.class, idx.getName()); + template.awaitSearchIndexCreation(Movie.class, idx.getName()); indexOps.dropIndex(idx.getName()); @@ -113,6 +120,7 @@ class VectorIndexIntegrationTests { } @Test // GH-4706 + @EnableIfVectorSearchAvailable(collection = Movie.class) void statusChanges() throws InterruptedException { String indexName = "vector_index"; @@ -131,6 +139,7 @@ class VectorIndexIntegrationTests { } @Test // GH-4706 + @EnableIfVectorSearchAvailable(collection = Movie.class) void exists() throws InterruptedException { String indexName = "vector_index"; @@ -148,6 +157,7 @@ class VectorIndexIntegrationTests { } @Test // GH-4706 + @EnableIfVectorSearchAvailable(collection = Movie.class) void updatesVectorIndex() throws InterruptedException { String indexName = "vector_index"; @@ -177,6 +187,7 @@ class VectorIndexIntegrationTests { } @Test // GH-4706 + @EnableIfVectorSearchAvailable(collection = Movie.class) void createsVectorIndexWithFilters() throws InterruptedException { VectorIndex idx = new VectorIndex("vector_index") diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/ReactiveVectorSearchTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/ReactiveVectorSearchTests.java index 14a4749c8..15fe22cf1 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/ReactiveVectorSearchTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/ReactiveVectorSearchTests.java @@ -167,9 +167,9 @@ public class ReactiveVectorSearchTests { template.searchIndexOps(WithVectorFields.class).createIndex(cosIndex); template.searchIndexOps(WithVectorFields.class).createIndex(euclideanIndex); template.searchIndexOps(WithVectorFields.class).createIndex(inner); - template.awaitIndexCreation(WithVectorFields.class, cosIndex.getName()); - template.awaitIndexCreation(WithVectorFields.class, euclideanIndex.getName()); - template.awaitIndexCreation(WithVectorFields.class, inner.getName()); + template.awaitSearchIndexCreation(WithVectorFields.class, cosIndex.getName()); + template.awaitSearchIndexCreation(WithVectorFields.class, euclideanIndex.getName()); + template.awaitSearchIndexCreation(WithVectorFields.class, inner.getName()); } interface ReactiveVectorSearchRepository extends CrudRepository { diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/VectorSearchTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/VectorSearchTests.java index a224481da..bd9f6165f 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/VectorSearchTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/VectorSearchTests.java @@ -211,9 +211,9 @@ public class VectorSearchTests { template.searchIndexOps(WithVectorFields.class).createIndex(cosIndex); template.searchIndexOps(WithVectorFields.class).createIndex(euclideanIndex); template.searchIndexOps(WithVectorFields.class).createIndex(inner); - template.awaitIndexCreation(WithVectorFields.class, cosIndex.getName()); - template.awaitIndexCreation(WithVectorFields.class, euclideanIndex.getName()); - template.awaitIndexCreation(WithVectorFields.class, inner.getName()); + template.awaitSearchIndexCreation(WithVectorFields.class, cosIndex.getName()); + template.awaitSearchIndexCreation(WithVectorFields.class, euclideanIndex.getName()); + template.awaitSearchIndexCreation(WithVectorFields.class, inner.getName()); } interface VectorSearchRepository extends CrudRepository { diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/AtlasContainer.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/AtlasContainer.java index c3a97a03b..71fecd29b 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/AtlasContainer.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/AtlasContainer.java @@ -16,12 +16,14 @@ package org.springframework.data.mongodb.test.util; import org.springframework.core.env.StandardEnvironment; - import org.testcontainers.mongodb.MongoDBAtlasLocalContainer; import org.testcontainers.utility.DockerImageName; +import com.github.dockerjava.api.command.InspectContainerResponse; + /** - * Extension to MongoDBAtlasLocalContainer. + * Extension to {@link MongoDBAtlasLocalContainer}. Registers mapped host an port as system properties + * ({@link #ATLAS_HOST}, {@link #ATLAS_PORT}). * * @author Christoph Strobl */ @@ -31,6 +33,9 @@ public class AtlasContainer extends MongoDBAtlasLocalContainer { private static final String DEFAULT_TAG = "8.0.0"; private static final String LATEST = "latest"; + public static final String ATLAS_HOST = "docker.mongodb.atlas.host"; + public static final String ATLAS_PORT = "docker.mongodb.atlas.port"; + private AtlasContainer(String dockerImageName) { super(DockerImageName.parse(dockerImageName)); } @@ -55,4 +60,20 @@ public class AtlasContainer extends MongoDBAtlasLocalContainer { return new AtlasContainer(DEFAULT_IMAGE_NAME.withTag(tag)); } + @Override + protected void containerIsStarted(InspectContainerResponse containerInfo) { + + super.containerIsStarted(containerInfo); + + System.setProperty(ATLAS_HOST, getHost()); + System.setProperty(ATLAS_PORT, getMappedPort(27017).toString()); + } + + @Override + protected void containerIsStopping(InspectContainerResponse containerInfo) { + + System.clearProperty(ATLAS_HOST); + System.clearProperty(ATLAS_PORT); + super.containerIsStopping(containerInfo); + } } diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/EnableIfVectorSearchAvailable.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/EnableIfVectorSearchAvailable.java index da008d9ee..c81e197fe 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/EnableIfVectorSearchAvailable.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/EnableIfVectorSearchAvailable.java @@ -25,13 +25,30 @@ import org.junit.jupiter.api.Tag; import org.junit.jupiter.api.extension.ExtendWith; /** + * {@link EnableIfVectorSearchAvailable} indicates a specific method can only be run in an environment that has a search + * server available. This means that not only the mongodb instance needs to have a + * {@literal searchIndexManagementHostAndPort} configured, but also that the search index sever is actually up and + * running, responding to a {@literal $listSearchIndexes} aggregation. + * * @author Christoph Strobl + * @since 5.0 + * @see Tag */ -@Target({ ElementType.TYPE, ElementType.METHOD }) +@Target({ ElementType.METHOD }) @Retention(RetentionPolicy.RUNTIME) @Documented @Tag("vector-search") @ExtendWith(MongoServerCondition.class) public @interface EnableIfVectorSearchAvailable { + /** + * @return the name of the collection used to run the {@literal $listSearchIndexes} aggregation. + */ + String collectionName() default ""; + + /** + * @return the type for resolving the name of the collection used to run the {@literal $listSearchIndexes} + * aggregation. The {@link #collectionName()} has precedence over the type. + */ + Class collection() default Object.class; } diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoExtensions.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoExtensions.java index c90f7e999..864bb6aa5 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoExtensions.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoExtensions.java @@ -31,7 +31,7 @@ class MongoExtensions { static final String REACTIVE_REPLSET_KEY = "mongo.client.replset.reactive"; } - static class Termplate { + static class Template { static final Namespace NAMESPACE = Namespace.create(MongoTemplateExtension.class); static final String SYNC = "mongo.template.sync"; diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoServerCondition.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoServerCondition.java index d811e0a1e..35ca65c30 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoServerCondition.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoServerCondition.java @@ -15,12 +15,21 @@ */ package org.springframework.data.mongodb.test.util; +import java.time.Duration; + import org.junit.jupiter.api.extension.ConditionEvaluationResult; import org.junit.jupiter.api.extension.ExecutionCondition; import org.junit.jupiter.api.extension.ExtensionContext; import org.junit.jupiter.api.extension.ExtensionContext.Namespace; import org.springframework.core.annotation.AnnotatedElementUtils; +import org.springframework.data.mongodb.MongoCollectionUtils; import org.springframework.data.util.Version; +import org.springframework.util.NumberUtils; +import org.springframework.util.StringUtils; +import org.testcontainers.shaded.org.awaitility.Awaitility; + +import com.mongodb.Function; +import com.mongodb.client.MongoClient; /** * @author Christoph Strobl @@ -42,10 +51,13 @@ public class MongoServerCondition implements ExecutionCondition { } } - if(context.getTags().contains("vector-search")) { - if(!atlasEnvironment(context)) { + if (context.getTags().contains("vector-search")) { + if (!atlasEnvironment(context)) { return ConditionEvaluationResult.disabled("Disabled for servers not supporting Vector Search."); } + if (!isSearchIndexAvailable(context)) { + return ConditionEvaluationResult.disabled("Search index unavailable."); + } } if (context.getTags().contains("version-specific") && context.getElement().isPresent()) { @@ -90,8 +102,55 @@ public class MongoServerCondition implements ExecutionCondition { Version.class); } + private boolean isSearchIndexAvailable(ExtensionContext context) { + + EnableIfVectorSearchAvailable vectorSearchAvailable = AnnotatedElementUtils + .findMergedAnnotation(context.getElement().get(), EnableIfVectorSearchAvailable.class); + + if (vectorSearchAvailable == null) { + return true; + } + + String collectionName = StringUtils.hasText(vectorSearchAvailable.collectionName()) + ? vectorSearchAvailable.collectionName() + : MongoCollectionUtils.getPreferredCollectionName(vectorSearchAvailable.collection()); + + return context.getStore(NAMESPACE).getOrComputeIfAbsent("search-index-%s-available".formatted(collectionName), + (key) -> { + try { + doWithClient(client -> { + Awaitility.await().atMost(Duration.ofSeconds(60)).pollInterval(Duration.ofMillis(200)).until(() -> { + return MongoTestUtils.isSearchIndexReady(client, null, collectionName); + }); + return "done waiting for search index"; + }); + } catch (Exception e) { + return false; + } + return true; + }, Boolean.class); + + } + private boolean atlasEnvironment(ExtensionContext context) { - return context.getStore(NAMESPACE).getOrComputeIfAbsent(Version.class, (key) -> MongoTestUtils.isVectorSearchEnabled(), - Boolean.class); + + return context.getStore(NAMESPACE).getOrComputeIfAbsent("mongodb-atlas", + (key) -> doWithClient(MongoTestUtils::isVectorSearchEnabled), Boolean.class); + } + + private T doWithClient(Function function) { + + String host = System.getProperty(AtlasContainer.ATLAS_HOST); + String port = System.getProperty(AtlasContainer.ATLAS_PORT); + + if (StringUtils.hasText(host) && StringUtils.hasText(port)) { + try (MongoClient client = MongoTestUtils.client(host, NumberUtils.parseNumber(port, Integer.class))) { + return function.apply(client); + } + } + + try (MongoClient client = MongoTestUtils.client()) { + return function.apply(client); + } } } diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoTemplateExtension.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoTemplateExtension.java index 301d1ef49..23e4a3db7 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoTemplateExtension.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoTemplateExtension.java @@ -33,7 +33,7 @@ import org.junit.platform.commons.util.StringUtils; import org.springframework.data.mongodb.core.MongoOperations; import org.springframework.data.mongodb.core.ReactiveMongoOperations; -import org.springframework.data.mongodb.test.util.MongoExtensions.Termplate; +import org.springframework.data.mongodb.test.util.MongoExtensions.Template; import org.springframework.data.util.ParsingUtils; import org.springframework.util.ClassUtils; @@ -41,7 +41,7 @@ import org.springframework.util.ClassUtils; * JUnit {@link Extension} providing parameter resolution for synchronous and reactive MongoDB Template API objects. * * @author Christoph Strobl - * @see Template + * @see org.springframework.data.mongodb.test.util.Template * @see MongoTestTemplate * @see ReactiveMongoTestTemplate */ @@ -65,32 +65,32 @@ public class MongoTemplateExtension extends MongoClientExtension implements Test @Override public boolean supportsParameter(ParameterContext parameterContext, ExtensionContext extensionContext) throws ParameterResolutionException { - return super.supportsParameter(parameterContext, extensionContext) || parameterContext.isAnnotated(Template.class); + return super.supportsParameter(parameterContext, extensionContext) || parameterContext.isAnnotated(org.springframework.data.mongodb.test.util.Template.class); } @Override public Object resolveParameter(ParameterContext parameterContext, ExtensionContext extensionContext) throws ParameterResolutionException { - if (parameterContext.getParameter().getAnnotation(Template.class) == null) { + if (parameterContext.getParameter().getAnnotation(org.springframework.data.mongodb.test.util.Template.class) == null) { return super.resolveParameter(parameterContext, extensionContext); } Class parameterType = parameterContext.getParameter().getType(); - return getMongoTemplate(parameterType, parameterContext.getParameter().getAnnotation(Template.class), + return getMongoTemplate(parameterType, parameterContext.getParameter().getAnnotation(org.springframework.data.mongodb.test.util.Template.class), extensionContext); } private void injectFields(ExtensionContext context, Object testInstance, Predicate predicate) { - AnnotationUtils.findAnnotatedFields(context.getRequiredTestClass(), Template.class, predicate).forEach(field -> { + AnnotationUtils.findAnnotatedFields(context.getRequiredTestClass(), org.springframework.data.mongodb.test.util.Template.class, predicate).forEach(field -> { assertValidFieldCandidate(field); try { ReflectionUtils.makeAccessible(field).set(testInstance, - getMongoTemplate(field.getType(), field.getAnnotation(Template.class), context)); + getMongoTemplate(field.getType(), field.getAnnotation(org.springframework.data.mongodb.test.util.Template.class), context)); } catch (Throwable t) { ExceptionUtils.throwAsUncheckedException(t); } @@ -107,14 +107,14 @@ public class MongoTemplateExtension extends MongoClientExtension implements Test if (!ClassUtils.isAssignable(MongoOperations.class, type) && !ClassUtils.isAssignable(ReactiveMongoOperations.class, type)) { throw new ExtensionConfigurationException( - String.format("Can only resolve @%s %s of type %s or %s but was: %s", Template.class.getSimpleName(), target, + String.format("Can only resolve @%s %s of type %s or %s but was: %s", org.springframework.data.mongodb.test.util.Template.class.getSimpleName(), target, MongoOperations.class.getName(), ReactiveMongoOperations.class.getName(), type.getName())); } } - private Object getMongoTemplate(Class type, Template options, ExtensionContext extensionContext) { + private Object getMongoTemplate(Class type, org.springframework.data.mongodb.test.util.Template options, ExtensionContext extensionContext) { - Store templateStore = extensionContext.getStore(MongoExtensions.Termplate.NAMESPACE); + Store templateStore = extensionContext.getStore(Template.NAMESPACE); boolean replSetClient = holdsReplSetClient(extensionContext) || options.replicaSet(); @@ -126,7 +126,7 @@ public class MongoTemplateExtension extends MongoClientExtension implements Test if (ClassUtils.isAssignable(MongoOperations.class, type)) { - String key = Termplate.SYNC + "-" + dbName; + String key = Template.SYNC + "-" + dbName; return templateStore.getOrComputeIfAbsent(key, it -> { com.mongodb.client.MongoClient client = (com.mongodb.client.MongoClient) getMongoClient( @@ -137,7 +137,7 @@ public class MongoTemplateExtension extends MongoClientExtension implements Test if (ClassUtils.isAssignable(ReactiveMongoOperations.class, type)) { - String key = Termplate.REACTIVE + "-" + dbName; + String key = Template.REACTIVE + "-" + dbName; return templateStore.getOrComputeIfAbsent(key, it -> { com.mongodb.reactivestreams.client.MongoClient client = (com.mongodb.reactivestreams.client.MongoClient) getMongoClient( diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoTestTemplate.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoTestTemplate.java index 771c17c4a..4e619c609 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoTestTemplate.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoTestTemplate.java @@ -23,7 +23,6 @@ import java.util.function.Consumer; import java.util.function.Supplier; import java.util.stream.Collectors; -import com.mongodb.client.MongoClients; import org.bson.Document; import org.springframework.context.ApplicationContext; import org.springframework.data.mapping.callback.EntityCallbacks; @@ -32,8 +31,12 @@ import org.springframework.data.mongodb.core.MongoTemplate; import org.testcontainers.shaded.org.awaitility.Awaitility; import com.mongodb.MongoWriteException; +import com.mongodb.ReadPreference; +import com.mongodb.WriteConcern; import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoClients; import com.mongodb.client.MongoCollection; +import com.mongodb.client.MongoDatabase; /** * A {@link MongoTemplate} with configuration hooks and extension suitable for tests. @@ -141,6 +144,21 @@ public class MongoTestTemplate extends MongoTemplate { }).collect(Collectors.toList())); } + public void createCollectionIfNotExists(Class type) { + createCollectionIfNotExists(getCollectionName(type)); + } + + public void createCollectionIfNotExists(String collectionName) { + + MongoDatabase database = getDb().withWriteConcern(WriteConcern.MAJORITY) + .withReadPreference(ReadPreference.primary()); + + boolean collectionExists = database.listCollections().filter(new Document("name", collectionName)).first() != null; + if (!collectionExists) { + createCollection(collectionName); + } + } + public void dropDatabase() { getDb().drop(); } @@ -164,11 +182,11 @@ public class MongoTestTemplate extends MongoTemplate { })); } - public void awaitIndexCreation(Class type, String indexName) { - awaitIndexCreation(getCollectionName(type), indexName, Duration.ofSeconds(10)); + public void awaitSearchIndexCreation(Class type, String indexName) { + awaitSearchIndexCreation(getCollectionName(type), indexName, Duration.ofSeconds(30)); } - public void awaitIndexCreation(String collectionName, String indexName, Duration timeout) { + public void awaitSearchIndexCreation(String collectionName, String indexName, Duration timeout) { Awaitility.await().atMost(timeout).pollInterval(Duration.ofMillis(200)).until(() -> { @@ -184,4 +202,35 @@ public class MongoTestTemplate extends MongoTemplate { return false; }); } + + public void awaitIndexDeletion(String collectionName, String indexName, Duration timeout) { + + Awaitility.await().atMost(timeout).pollInterval(Duration.ofMillis(200)).until(() -> { + + List execute = this.execute(collectionName, + coll -> coll + .aggregate(List.of(Document.parse("{'$listSearchIndexes': { 'name' : '%s'}}".formatted(indexName)))) + .into(new ArrayList<>())); + for (Document doc : execute) { + if (doc.getString("name").equals(indexName)) { + return false; + } + } + return true; + }); + } + + public void awaitNoSearchIndexAvailable(String collectionName, Duration timeout) { + + Awaitility.await().atMost(timeout).pollInterval(Duration.ofMillis(200)).until(() -> { + + return this.execute(collectionName, coll -> coll.aggregate(List.of(Document.parse("{'$listSearchIndexes': {}}"))) + .into(new ArrayList<>()).isEmpty()); + + }); + } + + public void awaitNoSearchIndexAvailable(Class type, Duration timeout) { + awaitNoSearchIndexAvailable(getCollectionName(type), timeout); + } } diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoTestUtils.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoTestUtils.java index f88caf80d..742fd5b44 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoTestUtils.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoTestUtils.java @@ -15,12 +15,15 @@ */ package org.springframework.data.mongodb.test.util; +import org.jspecify.annotations.Nullable; +import org.springframework.util.StringUtils; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; import reactor.util.retry.Retry; import java.time.Duration; import java.util.List; +import java.util.concurrent.TimeUnit; import org.bson.Document; import org.springframework.core.env.Environment; @@ -30,6 +33,7 @@ import org.springframework.data.util.Version; import org.springframework.util.ObjectUtils; import com.mongodb.ConnectionString; +import com.mongodb.MongoClientSettings; import com.mongodb.ReadPreference; import com.mongodb.WriteConcern; import com.mongodb.client.MongoClient; @@ -68,6 +72,10 @@ public class MongoTestUtils { } public static MongoClient client(ConnectionString connectionString) { + MongoClientSettings settings = MongoClientSettings.builder().applyConnectionString(connectionString) + .applyToSocketSettings(builder -> { + builder.connectTimeout(120, TimeUnit.SECONDS); + }).build(); return com.mongodb.client.MongoClients.create(connectionString, SpringDataMongoDB.driverInformation()); } @@ -176,11 +184,10 @@ public class MongoTestUtils { * @param collectionName must not be {@literal null}. * @param client must not be {@literal null}. */ - public static void dropCollectionNow(String dbName, String collectionName, - com.mongodb.client.MongoClient client) { + public static void dropCollectionNow(String dbName, String collectionName, com.mongodb.client.MongoClient client) { - com.mongodb.client.MongoDatabase database = client.getDatabase(dbName) - .withWriteConcern(WriteConcern.MAJORITY).withReadPreference(ReadPreference.primary()); + com.mongodb.client.MongoDatabase database = client.getDatabase(dbName).withWriteConcern(WriteConcern.MAJORITY) + .withReadPreference(ReadPreference.primary()); database.getCollection(collectionName).drop(); } @@ -205,11 +212,10 @@ public class MongoTestUtils { .verifyComplete(); } - public static void flushCollection(String dbName, String collectionName, - com.mongodb.client.MongoClient client) { + public static void flushCollection(String dbName, String collectionName, com.mongodb.client.MongoClient client) { - com.mongodb.client.MongoDatabase database = client.getDatabase(dbName) - .withWriteConcern(WriteConcern.MAJORITY).withReadPreference(ReadPreference.primary()); + com.mongodb.client.MongoDatabase database = client.getDatabase(dbName).withWriteConcern(WriteConcern.MAJORITY) + .withReadPreference(ReadPreference.primary()); database.getCollection(collectionName).deleteMany(new Document()); } @@ -267,19 +273,36 @@ public class MongoTestUtils { @SuppressWarnings("unchecked") public static boolean isVectorSearchEnabled() { try (MongoClient client = MongoTestUtils.client()) { + return isVectorSearchEnabled(client); + } + } + public static boolean isVectorSearchEnabled(MongoClient client) { + try { return client.getDatabase("admin").runCommand(new Document("getCmdLineOpts", "1")).get("argv", List.class) - .stream().anyMatch(it -> { - if(it instanceof String cfgString) { - return cfgString.startsWith("searchIndexManagementHostAndPort"); - } - return false; - }); + .stream().anyMatch(it -> { + if (it instanceof String cfgString) { + return cfgString.startsWith("searchIndexManagementHostAndPort"); + } + return false; + }); } catch (Exception e) { return false; } } + public static boolean isSearchIndexReady(MongoClient client, @Nullable String database, String collectionName) { + + try { + MongoCollection collection = client.getDatabase(StringUtils.hasText(database) ? database : "test").getCollection(collectionName); + collection.aggregate(List.of(new Document("$listSearchIndexes", new Document()))); + } catch (Exception e) { + return false; + } + return true; + + } + public static Duration getTimeout() { return ObjectUtils.nullSafeEquals("jenkins", ENV.getProperty("user.name")) ? Duration.ofMillis(100) @@ -297,10 +320,11 @@ public class MongoTestUtils { public static CollectionInfo readCollectionInfo(MongoDatabase db, String collectionName) { - List list = db.runCommand(new Document().append("listCollections", 1).append("filter", new Document("name", collectionName))) + List list = db + .runCommand(new Document().append("listCollections", 1).append("filter", new Document("name", collectionName))) .get("cursor", Document.class).get("firstBatch", List.class); - if(list.isEmpty()) { + if (list.isEmpty()) { throw new IllegalStateException(String.format("Collection %s not found.", collectionName)); } return CollectionInfo.from(list.get(0)); diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/ReplSetClient.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/ReplSetClient.java index 8342c5b5e..ede3687f7 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/ReplSetClient.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/ReplSetClient.java @@ -21,6 +21,8 @@ import java.lang.annotation.Retention; import java.lang.annotation.RetentionPolicy; import java.lang.annotation.Target; +import org.junit.jupiter.api.extension.ExtendWith; + /** * Marks a field or method as to be autowired by JUnit's dependency injection facilities for injection of a MongoDB * client instance connected to a replica set. Depends on {@link MongoClientExtension}. @@ -34,6 +36,7 @@ import java.lang.annotation.Target; @Target({ ElementType.FIELD, ElementType.PARAMETER }) @Retention(RetentionPolicy.RUNTIME) @Documented +@ExtendWith(MongoClientExtension.class) public @interface ReplSetClient { }