Browse Source

Introduce `@EnableIfVectorSearchAvailable` to wait and conditionally skip tests.

We now wait until a search index becomes available. If the search index doesn't come alive within 60 seconds, we skip that test (or test class).

Closes: #5013
Original pull request: #5014
pull/5016/head
Christoph Strobl 6 months ago committed by Mark Paluch
parent
commit
d43d6b062e
No known key found for this signature in database
GPG Key ID: 55BC6374BAA9D973
  1. 4
      spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/VectorSearchTests.java
  2. 25
      spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/index/VectorIndexIntegrationTests.java
  3. 6
      spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/ReactiveVectorSearchTests.java
  4. 6
      spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/VectorSearchTests.java
  5. 25
      spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/AtlasContainer.java
  6. 19
      spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/EnableIfVectorSearchAvailable.java
  7. 2
      spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoExtensions.java
  8. 67
      spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoServerCondition.java
  9. 24
      spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoTemplateExtension.java
  10. 57
      spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoTestTemplate.java
  11. 56
      spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoTestUtils.java
  12. 3
      spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/ReplSetClient.java

4
spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/VectorSearchTests.java

@ -167,8 +167,8 @@ public class VectorSearchTests {
template.searchIndexOps(WithVectorFields.class).createIndex(rawIndex); template.searchIndexOps(WithVectorFields.class).createIndex(rawIndex);
template.searchIndexOps(WithVectorFields.class).createIndex(wrapperIndex); template.searchIndexOps(WithVectorFields.class).createIndex(wrapperIndex);
template.awaitIndexCreation(WithVectorFields.class, rawIndex.getName()); template.awaitSearchIndexCreation(WithVectorFields.class, rawIndex.getName());
template.awaitIndexCreation(WithVectorFields.class, wrapperIndex.getName()); template.awaitSearchIndexCreation(WithVectorFields.class, wrapperIndex.getName());
} }
private static void assertScoreIsDecreasing(Iterable<Document> documents) { private static void assertScoreIsDecreasing(Iterable<Document> documents) {

25
spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/index/VectorIndexIntegrationTests.java

@ -15,10 +15,11 @@
*/ */
package org.springframework.data.mongodb.core.index; package org.springframework.data.mongodb.core.index;
import static org.assertj.core.api.Assertions.*; import static org.assertj.core.api.Assertions.assertThatRuntimeException;
import static org.awaitility.Awaitility.*; import static org.awaitility.Awaitility.await;
import static org.springframework.data.mongodb.test.util.Assertions.assertThat; import static org.springframework.data.mongodb.test.util.Assertions.assertThat;
import java.time.Duration;
import java.util.List; import java.util.List;
import org.bson.Document; import org.bson.Document;
@ -26,16 +27,17 @@ import org.jspecify.annotations.Nullable;
import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource; import org.junit.jupiter.params.provider.ValueSource;
import org.springframework.data.annotation.Id; import org.springframework.data.annotation.Id;
import org.springframework.data.mongodb.core.index.VectorIndex.SimilarityFunction; import org.springframework.data.mongodb.core.index.VectorIndex.SimilarityFunction;
import org.springframework.data.mongodb.core.mapping.Field; import org.springframework.data.mongodb.core.mapping.Field;
import org.springframework.data.mongodb.test.util.AtlasContainer; import org.springframework.data.mongodb.test.util.AtlasContainer;
import org.springframework.data.mongodb.test.util.EnableIfVectorSearchAvailable;
import org.springframework.data.mongodb.test.util.MongoServerCondition;
import org.springframework.data.mongodb.test.util.MongoTestTemplate; import org.springframework.data.mongodb.test.util.MongoTestTemplate;
import org.springframework.data.mongodb.test.util.MongoTestUtils; import org.springframework.data.mongodb.test.util.MongoTestUtils;
import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers; import org.testcontainers.junit.jupiter.Testcontainers;
@ -48,6 +50,7 @@ import com.mongodb.client.AggregateIterable;
* @author Christoph Strobl * @author Christoph Strobl
* @author Mark Paluch * @author Mark Paluch
*/ */
@ExtendWith(MongoServerCondition.class)
@Testcontainers(disabledWithoutDocker = true) @Testcontainers(disabledWithoutDocker = true)
class VectorIndexIntegrationTests { class VectorIndexIntegrationTests {
@ -66,19 +69,22 @@ class VectorIndexIntegrationTests {
@BeforeEach @BeforeEach
void init() { void init() {
template.createCollection(Movie.class);
template.createCollectionIfNotExists(Movie.class);
indexOps = template.searchIndexOps(Movie.class); indexOps = template.searchIndexOps(Movie.class);
} }
@AfterEach @AfterEach
void cleanup() { void cleanup() {
template.flush(Movie.class);
template.searchIndexOps(Movie.class).dropAllIndexes(); template.searchIndexOps(Movie.class).dropAllIndexes();
template.dropCollection(Movie.class); template.awaitNoSearchIndexAvailable(Movie.class, Duration.ofSeconds(30));
} }
@ParameterizedTest // GH-4706 @ParameterizedTest // GH-4706
@ValueSource(strings = { "euclidean", "cosine", "dotProduct" }) @ValueSource(strings = { "euclidean", "cosine", "dotProduct" })
@EnableIfVectorSearchAvailable(collection = Movie.class)
void createsSimpleVectorIndex(String similarityFunction) { void createsSimpleVectorIndex(String similarityFunction) {
VectorIndex idx = new VectorIndex("vector_index").addVector("plotEmbedding", VectorIndex idx = new VectorIndex("vector_index").addVector("plotEmbedding",
@ -98,6 +104,7 @@ class VectorIndexIntegrationTests {
} }
@Test // GH-4706 @Test // GH-4706
@EnableIfVectorSearchAvailable(collection = Movie.class)
void dropIndex() { void dropIndex() {
VectorIndex idx = new VectorIndex("vector_index").addVector("plotEmbedding", VectorIndex idx = new VectorIndex("vector_index").addVector("plotEmbedding",
@ -105,7 +112,7 @@ class VectorIndexIntegrationTests {
indexOps.createIndex(idx); indexOps.createIndex(idx);
template.awaitIndexCreation(Movie.class, idx.getName()); template.awaitSearchIndexCreation(Movie.class, idx.getName());
indexOps.dropIndex(idx.getName()); indexOps.dropIndex(idx.getName());
@ -113,6 +120,7 @@ class VectorIndexIntegrationTests {
} }
@Test // GH-4706 @Test // GH-4706
@EnableIfVectorSearchAvailable(collection = Movie.class)
void statusChanges() throws InterruptedException { void statusChanges() throws InterruptedException {
String indexName = "vector_index"; String indexName = "vector_index";
@ -131,6 +139,7 @@ class VectorIndexIntegrationTests {
} }
@Test // GH-4706 @Test // GH-4706
@EnableIfVectorSearchAvailable(collection = Movie.class)
void exists() throws InterruptedException { void exists() throws InterruptedException {
String indexName = "vector_index"; String indexName = "vector_index";
@ -148,6 +157,7 @@ class VectorIndexIntegrationTests {
} }
@Test // GH-4706 @Test // GH-4706
@EnableIfVectorSearchAvailable(collection = Movie.class)
void updatesVectorIndex() throws InterruptedException { void updatesVectorIndex() throws InterruptedException {
String indexName = "vector_index"; String indexName = "vector_index";
@ -177,6 +187,7 @@ class VectorIndexIntegrationTests {
} }
@Test // GH-4706 @Test // GH-4706
@EnableIfVectorSearchAvailable(collection = Movie.class)
void createsVectorIndexWithFilters() throws InterruptedException { void createsVectorIndexWithFilters() throws InterruptedException {
VectorIndex idx = new VectorIndex("vector_index") VectorIndex idx = new VectorIndex("vector_index")

6
spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/ReactiveVectorSearchTests.java

@ -167,9 +167,9 @@ public class ReactiveVectorSearchTests {
template.searchIndexOps(WithVectorFields.class).createIndex(cosIndex); template.searchIndexOps(WithVectorFields.class).createIndex(cosIndex);
template.searchIndexOps(WithVectorFields.class).createIndex(euclideanIndex); template.searchIndexOps(WithVectorFields.class).createIndex(euclideanIndex);
template.searchIndexOps(WithVectorFields.class).createIndex(inner); template.searchIndexOps(WithVectorFields.class).createIndex(inner);
template.awaitIndexCreation(WithVectorFields.class, cosIndex.getName()); template.awaitSearchIndexCreation(WithVectorFields.class, cosIndex.getName());
template.awaitIndexCreation(WithVectorFields.class, euclideanIndex.getName()); template.awaitSearchIndexCreation(WithVectorFields.class, euclideanIndex.getName());
template.awaitIndexCreation(WithVectorFields.class, inner.getName()); template.awaitSearchIndexCreation(WithVectorFields.class, inner.getName());
} }
interface ReactiveVectorSearchRepository extends CrudRepository<WithVectorFields, String> { interface ReactiveVectorSearchRepository extends CrudRepository<WithVectorFields, String> {

6
spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/VectorSearchTests.java

@ -211,9 +211,9 @@ public class VectorSearchTests {
template.searchIndexOps(WithVectorFields.class).createIndex(cosIndex); template.searchIndexOps(WithVectorFields.class).createIndex(cosIndex);
template.searchIndexOps(WithVectorFields.class).createIndex(euclideanIndex); template.searchIndexOps(WithVectorFields.class).createIndex(euclideanIndex);
template.searchIndexOps(WithVectorFields.class).createIndex(inner); template.searchIndexOps(WithVectorFields.class).createIndex(inner);
template.awaitIndexCreation(WithVectorFields.class, cosIndex.getName()); template.awaitSearchIndexCreation(WithVectorFields.class, cosIndex.getName());
template.awaitIndexCreation(WithVectorFields.class, euclideanIndex.getName()); template.awaitSearchIndexCreation(WithVectorFields.class, euclideanIndex.getName());
template.awaitIndexCreation(WithVectorFields.class, inner.getName()); template.awaitSearchIndexCreation(WithVectorFields.class, inner.getName());
} }
interface VectorSearchRepository extends CrudRepository<WithVectorFields, String> { interface VectorSearchRepository extends CrudRepository<WithVectorFields, String> {

25
spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/AtlasContainer.java

@ -16,12 +16,14 @@
package org.springframework.data.mongodb.test.util; package org.springframework.data.mongodb.test.util;
import org.springframework.core.env.StandardEnvironment; import org.springframework.core.env.StandardEnvironment;
import org.testcontainers.mongodb.MongoDBAtlasLocalContainer; import org.testcontainers.mongodb.MongoDBAtlasLocalContainer;
import org.testcontainers.utility.DockerImageName; import org.testcontainers.utility.DockerImageName;
import com.github.dockerjava.api.command.InspectContainerResponse;
/** /**
* Extension to MongoDBAtlasLocalContainer. * Extension to {@link MongoDBAtlasLocalContainer}. Registers mapped host an port as system properties
* ({@link #ATLAS_HOST}, {@link #ATLAS_PORT}).
* *
* @author Christoph Strobl * @author Christoph Strobl
*/ */
@ -31,6 +33,9 @@ public class AtlasContainer extends MongoDBAtlasLocalContainer {
private static final String DEFAULT_TAG = "8.0.0"; private static final String DEFAULT_TAG = "8.0.0";
private static final String LATEST = "latest"; private static final String LATEST = "latest";
public static final String ATLAS_HOST = "docker.mongodb.atlas.host";
public static final String ATLAS_PORT = "docker.mongodb.atlas.port";
private AtlasContainer(String dockerImageName) { private AtlasContainer(String dockerImageName) {
super(DockerImageName.parse(dockerImageName)); super(DockerImageName.parse(dockerImageName));
} }
@ -55,4 +60,20 @@ public class AtlasContainer extends MongoDBAtlasLocalContainer {
return new AtlasContainer(DEFAULT_IMAGE_NAME.withTag(tag)); return new AtlasContainer(DEFAULT_IMAGE_NAME.withTag(tag));
} }
@Override
protected void containerIsStarted(InspectContainerResponse containerInfo) {
super.containerIsStarted(containerInfo);
System.setProperty(ATLAS_HOST, getHost());
System.setProperty(ATLAS_PORT, getMappedPort(27017).toString());
}
@Override
protected void containerIsStopping(InspectContainerResponse containerInfo) {
System.clearProperty(ATLAS_HOST);
System.clearProperty(ATLAS_PORT);
super.containerIsStopping(containerInfo);
}
} }

19
spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/EnableIfVectorSearchAvailable.java

@ -25,13 +25,30 @@ import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
/** /**
* {@link EnableIfVectorSearchAvailable} indicates a specific method can only be run in an environment that has a search
* server available. This means that not only the mongodb instance needs to have a
* {@literal searchIndexManagementHostAndPort} configured, but also that the search index sever is actually up and
* running, responding to a {@literal $listSearchIndexes} aggregation.
*
* @author Christoph Strobl * @author Christoph Strobl
* @since 5.0
* @see Tag
*/ */
@Target({ ElementType.TYPE, ElementType.METHOD }) @Target({ ElementType.METHOD })
@Retention(RetentionPolicy.RUNTIME) @Retention(RetentionPolicy.RUNTIME)
@Documented @Documented
@Tag("vector-search") @Tag("vector-search")
@ExtendWith(MongoServerCondition.class) @ExtendWith(MongoServerCondition.class)
public @interface EnableIfVectorSearchAvailable { public @interface EnableIfVectorSearchAvailable {
/**
* @return the name of the collection used to run the {@literal $listSearchIndexes} aggregation.
*/
String collectionName() default "";
/**
* @return the type for resolving the name of the collection used to run the {@literal $listSearchIndexes}
* aggregation. The {@link #collectionName()} has precedence over the type.
*/
Class<?> collection() default Object.class;
} }

2
spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoExtensions.java

@ -31,7 +31,7 @@ class MongoExtensions {
static final String REACTIVE_REPLSET_KEY = "mongo.client.replset.reactive"; static final String REACTIVE_REPLSET_KEY = "mongo.client.replset.reactive";
} }
static class Termplate { static class Template {
static final Namespace NAMESPACE = Namespace.create(MongoTemplateExtension.class); static final Namespace NAMESPACE = Namespace.create(MongoTemplateExtension.class);
static final String SYNC = "mongo.template.sync"; static final String SYNC = "mongo.template.sync";

67
spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoServerCondition.java

@ -15,12 +15,21 @@
*/ */
package org.springframework.data.mongodb.test.util; package org.springframework.data.mongodb.test.util;
import java.time.Duration;
import org.junit.jupiter.api.extension.ConditionEvaluationResult; import org.junit.jupiter.api.extension.ConditionEvaluationResult;
import org.junit.jupiter.api.extension.ExecutionCondition; import org.junit.jupiter.api.extension.ExecutionCondition;
import org.junit.jupiter.api.extension.ExtensionContext; import org.junit.jupiter.api.extension.ExtensionContext;
import org.junit.jupiter.api.extension.ExtensionContext.Namespace; import org.junit.jupiter.api.extension.ExtensionContext.Namespace;
import org.springframework.core.annotation.AnnotatedElementUtils; import org.springframework.core.annotation.AnnotatedElementUtils;
import org.springframework.data.mongodb.MongoCollectionUtils;
import org.springframework.data.util.Version; import org.springframework.data.util.Version;
import org.springframework.util.NumberUtils;
import org.springframework.util.StringUtils;
import org.testcontainers.shaded.org.awaitility.Awaitility;
import com.mongodb.Function;
import com.mongodb.client.MongoClient;
/** /**
* @author Christoph Strobl * @author Christoph Strobl
@ -42,10 +51,13 @@ public class MongoServerCondition implements ExecutionCondition {
} }
} }
if(context.getTags().contains("vector-search")) { if (context.getTags().contains("vector-search")) {
if(!atlasEnvironment(context)) { if (!atlasEnvironment(context)) {
return ConditionEvaluationResult.disabled("Disabled for servers not supporting Vector Search."); return ConditionEvaluationResult.disabled("Disabled for servers not supporting Vector Search.");
} }
if (!isSearchIndexAvailable(context)) {
return ConditionEvaluationResult.disabled("Search index unavailable.");
}
} }
if (context.getTags().contains("version-specific") && context.getElement().isPresent()) { if (context.getTags().contains("version-specific") && context.getElement().isPresent()) {
@ -90,8 +102,55 @@ public class MongoServerCondition implements ExecutionCondition {
Version.class); Version.class);
} }
private boolean isSearchIndexAvailable(ExtensionContext context) {
EnableIfVectorSearchAvailable vectorSearchAvailable = AnnotatedElementUtils
.findMergedAnnotation(context.getElement().get(), EnableIfVectorSearchAvailable.class);
if (vectorSearchAvailable == null) {
return true;
}
String collectionName = StringUtils.hasText(vectorSearchAvailable.collectionName())
? vectorSearchAvailable.collectionName()
: MongoCollectionUtils.getPreferredCollectionName(vectorSearchAvailable.collection());
return context.getStore(NAMESPACE).getOrComputeIfAbsent("search-index-%s-available".formatted(collectionName),
(key) -> {
try {
doWithClient(client -> {
Awaitility.await().atMost(Duration.ofSeconds(60)).pollInterval(Duration.ofMillis(200)).until(() -> {
return MongoTestUtils.isSearchIndexReady(client, null, collectionName);
});
return "done waiting for search index";
});
} catch (Exception e) {
return false;
}
return true;
}, Boolean.class);
}
private boolean atlasEnvironment(ExtensionContext context) { private boolean atlasEnvironment(ExtensionContext context) {
return context.getStore(NAMESPACE).getOrComputeIfAbsent(Version.class, (key) -> MongoTestUtils.isVectorSearchEnabled(),
Boolean.class); return context.getStore(NAMESPACE).getOrComputeIfAbsent("mongodb-atlas",
(key) -> doWithClient(MongoTestUtils::isVectorSearchEnabled), Boolean.class);
}
private <T> T doWithClient(Function<MongoClient, T> function) {
String host = System.getProperty(AtlasContainer.ATLAS_HOST);
String port = System.getProperty(AtlasContainer.ATLAS_PORT);
if (StringUtils.hasText(host) && StringUtils.hasText(port)) {
try (MongoClient client = MongoTestUtils.client(host, NumberUtils.parseNumber(port, Integer.class))) {
return function.apply(client);
}
}
try (MongoClient client = MongoTestUtils.client()) {
return function.apply(client);
}
} }
} }

24
spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoTemplateExtension.java

@ -33,7 +33,7 @@ import org.junit.platform.commons.util.StringUtils;
import org.springframework.data.mongodb.core.MongoOperations; import org.springframework.data.mongodb.core.MongoOperations;
import org.springframework.data.mongodb.core.ReactiveMongoOperations; import org.springframework.data.mongodb.core.ReactiveMongoOperations;
import org.springframework.data.mongodb.test.util.MongoExtensions.Termplate; import org.springframework.data.mongodb.test.util.MongoExtensions.Template;
import org.springframework.data.util.ParsingUtils; import org.springframework.data.util.ParsingUtils;
import org.springframework.util.ClassUtils; import org.springframework.util.ClassUtils;
@ -41,7 +41,7 @@ import org.springframework.util.ClassUtils;
* JUnit {@link Extension} providing parameter resolution for synchronous and reactive MongoDB Template API objects. * JUnit {@link Extension} providing parameter resolution for synchronous and reactive MongoDB Template API objects.
* *
* @author Christoph Strobl * @author Christoph Strobl
* @see Template * @see org.springframework.data.mongodb.test.util.Template
* @see MongoTestTemplate * @see MongoTestTemplate
* @see ReactiveMongoTestTemplate * @see ReactiveMongoTestTemplate
*/ */
@ -65,32 +65,32 @@ public class MongoTemplateExtension extends MongoClientExtension implements Test
@Override @Override
public boolean supportsParameter(ParameterContext parameterContext, ExtensionContext extensionContext) public boolean supportsParameter(ParameterContext parameterContext, ExtensionContext extensionContext)
throws ParameterResolutionException { throws ParameterResolutionException {
return super.supportsParameter(parameterContext, extensionContext) || parameterContext.isAnnotated(Template.class); return super.supportsParameter(parameterContext, extensionContext) || parameterContext.isAnnotated(org.springframework.data.mongodb.test.util.Template.class);
} }
@Override @Override
public Object resolveParameter(ParameterContext parameterContext, ExtensionContext extensionContext) public Object resolveParameter(ParameterContext parameterContext, ExtensionContext extensionContext)
throws ParameterResolutionException { throws ParameterResolutionException {
if (parameterContext.getParameter().getAnnotation(Template.class) == null) { if (parameterContext.getParameter().getAnnotation(org.springframework.data.mongodb.test.util.Template.class) == null) {
return super.resolveParameter(parameterContext, extensionContext); return super.resolveParameter(parameterContext, extensionContext);
} }
Class<?> parameterType = parameterContext.getParameter().getType(); Class<?> parameterType = parameterContext.getParameter().getType();
return getMongoTemplate(parameterType, parameterContext.getParameter().getAnnotation(Template.class), return getMongoTemplate(parameterType, parameterContext.getParameter().getAnnotation(org.springframework.data.mongodb.test.util.Template.class),
extensionContext); extensionContext);
} }
private void injectFields(ExtensionContext context, Object testInstance, Predicate<Field> predicate) { private void injectFields(ExtensionContext context, Object testInstance, Predicate<Field> predicate) {
AnnotationUtils.findAnnotatedFields(context.getRequiredTestClass(), Template.class, predicate).forEach(field -> { AnnotationUtils.findAnnotatedFields(context.getRequiredTestClass(), org.springframework.data.mongodb.test.util.Template.class, predicate).forEach(field -> {
assertValidFieldCandidate(field); assertValidFieldCandidate(field);
try { try {
ReflectionUtils.makeAccessible(field).set(testInstance, ReflectionUtils.makeAccessible(field).set(testInstance,
getMongoTemplate(field.getType(), field.getAnnotation(Template.class), context)); getMongoTemplate(field.getType(), field.getAnnotation(org.springframework.data.mongodb.test.util.Template.class), context));
} catch (Throwable t) { } catch (Throwable t) {
ExceptionUtils.throwAsUncheckedException(t); ExceptionUtils.throwAsUncheckedException(t);
} }
@ -107,14 +107,14 @@ public class MongoTemplateExtension extends MongoClientExtension implements Test
if (!ClassUtils.isAssignable(MongoOperations.class, type) if (!ClassUtils.isAssignable(MongoOperations.class, type)
&& !ClassUtils.isAssignable(ReactiveMongoOperations.class, type)) { && !ClassUtils.isAssignable(ReactiveMongoOperations.class, type)) {
throw new ExtensionConfigurationException( throw new ExtensionConfigurationException(
String.format("Can only resolve @%s %s of type %s or %s but was: %s", Template.class.getSimpleName(), target, String.format("Can only resolve @%s %s of type %s or %s but was: %s", org.springframework.data.mongodb.test.util.Template.class.getSimpleName(), target,
MongoOperations.class.getName(), ReactiveMongoOperations.class.getName(), type.getName())); MongoOperations.class.getName(), ReactiveMongoOperations.class.getName(), type.getName()));
} }
} }
private Object getMongoTemplate(Class<?> type, Template options, ExtensionContext extensionContext) { private Object getMongoTemplate(Class<?> type, org.springframework.data.mongodb.test.util.Template options, ExtensionContext extensionContext) {
Store templateStore = extensionContext.getStore(MongoExtensions.Termplate.NAMESPACE); Store templateStore = extensionContext.getStore(Template.NAMESPACE);
boolean replSetClient = holdsReplSetClient(extensionContext) || options.replicaSet(); boolean replSetClient = holdsReplSetClient(extensionContext) || options.replicaSet();
@ -126,7 +126,7 @@ public class MongoTemplateExtension extends MongoClientExtension implements Test
if (ClassUtils.isAssignable(MongoOperations.class, type)) { if (ClassUtils.isAssignable(MongoOperations.class, type)) {
String key = Termplate.SYNC + "-" + dbName; String key = Template.SYNC + "-" + dbName;
return templateStore.getOrComputeIfAbsent(key, it -> { return templateStore.getOrComputeIfAbsent(key, it -> {
com.mongodb.client.MongoClient client = (com.mongodb.client.MongoClient) getMongoClient( com.mongodb.client.MongoClient client = (com.mongodb.client.MongoClient) getMongoClient(
@ -137,7 +137,7 @@ public class MongoTemplateExtension extends MongoClientExtension implements Test
if (ClassUtils.isAssignable(ReactiveMongoOperations.class, type)) { if (ClassUtils.isAssignable(ReactiveMongoOperations.class, type)) {
String key = Termplate.REACTIVE + "-" + dbName; String key = Template.REACTIVE + "-" + dbName;
return templateStore.getOrComputeIfAbsent(key, it -> { return templateStore.getOrComputeIfAbsent(key, it -> {
com.mongodb.reactivestreams.client.MongoClient client = (com.mongodb.reactivestreams.client.MongoClient) getMongoClient( com.mongodb.reactivestreams.client.MongoClient client = (com.mongodb.reactivestreams.client.MongoClient) getMongoClient(

57
spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoTestTemplate.java

@ -23,7 +23,6 @@ import java.util.function.Consumer;
import java.util.function.Supplier; import java.util.function.Supplier;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import com.mongodb.client.MongoClients;
import org.bson.Document; import org.bson.Document;
import org.springframework.context.ApplicationContext; import org.springframework.context.ApplicationContext;
import org.springframework.data.mapping.callback.EntityCallbacks; import org.springframework.data.mapping.callback.EntityCallbacks;
@ -32,8 +31,12 @@ import org.springframework.data.mongodb.core.MongoTemplate;
import org.testcontainers.shaded.org.awaitility.Awaitility; import org.testcontainers.shaded.org.awaitility.Awaitility;
import com.mongodb.MongoWriteException; import com.mongodb.MongoWriteException;
import com.mongodb.ReadPreference;
import com.mongodb.WriteConcern;
import com.mongodb.client.MongoClient; import com.mongodb.client.MongoClient;
import com.mongodb.client.MongoClients;
import com.mongodb.client.MongoCollection; import com.mongodb.client.MongoCollection;
import com.mongodb.client.MongoDatabase;
/** /**
* A {@link MongoTemplate} with configuration hooks and extension suitable for tests. * A {@link MongoTemplate} with configuration hooks and extension suitable for tests.
@ -141,6 +144,21 @@ public class MongoTestTemplate extends MongoTemplate {
}).collect(Collectors.toList())); }).collect(Collectors.toList()));
} }
public void createCollectionIfNotExists(Class<?> type) {
createCollectionIfNotExists(getCollectionName(type));
}
public void createCollectionIfNotExists(String collectionName) {
MongoDatabase database = getDb().withWriteConcern(WriteConcern.MAJORITY)
.withReadPreference(ReadPreference.primary());
boolean collectionExists = database.listCollections().filter(new Document("name", collectionName)).first() != null;
if (!collectionExists) {
createCollection(collectionName);
}
}
public void dropDatabase() { public void dropDatabase() {
getDb().drop(); getDb().drop();
} }
@ -164,11 +182,11 @@ public class MongoTestTemplate extends MongoTemplate {
})); }));
} }
public void awaitIndexCreation(Class<?> type, String indexName) { public void awaitSearchIndexCreation(Class<?> type, String indexName) {
awaitIndexCreation(getCollectionName(type), indexName, Duration.ofSeconds(10)); awaitSearchIndexCreation(getCollectionName(type), indexName, Duration.ofSeconds(30));
} }
public void awaitIndexCreation(String collectionName, String indexName, Duration timeout) { public void awaitSearchIndexCreation(String collectionName, String indexName, Duration timeout) {
Awaitility.await().atMost(timeout).pollInterval(Duration.ofMillis(200)).until(() -> { Awaitility.await().atMost(timeout).pollInterval(Duration.ofMillis(200)).until(() -> {
@ -184,4 +202,35 @@ public class MongoTestTemplate extends MongoTemplate {
return false; return false;
}); });
} }
public void awaitIndexDeletion(String collectionName, String indexName, Duration timeout) {
Awaitility.await().atMost(timeout).pollInterval(Duration.ofMillis(200)).until(() -> {
List<Document> execute = this.execute(collectionName,
coll -> coll
.aggregate(List.of(Document.parse("{'$listSearchIndexes': { 'name' : '%s'}}".formatted(indexName))))
.into(new ArrayList<>()));
for (Document doc : execute) {
if (doc.getString("name").equals(indexName)) {
return false;
}
}
return true;
});
}
public void awaitNoSearchIndexAvailable(String collectionName, Duration timeout) {
Awaitility.await().atMost(timeout).pollInterval(Duration.ofMillis(200)).until(() -> {
return this.execute(collectionName, coll -> coll.aggregate(List.of(Document.parse("{'$listSearchIndexes': {}}")))
.into(new ArrayList<>()).isEmpty());
});
}
public void awaitNoSearchIndexAvailable(Class<?> type, Duration timeout) {
awaitNoSearchIndexAvailable(getCollectionName(type), timeout);
}
} }

56
spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoTestUtils.java

@ -15,12 +15,15 @@
*/ */
package org.springframework.data.mongodb.test.util; package org.springframework.data.mongodb.test.util;
import org.jspecify.annotations.Nullable;
import org.springframework.util.StringUtils;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
import reactor.test.StepVerifier; import reactor.test.StepVerifier;
import reactor.util.retry.Retry; import reactor.util.retry.Retry;
import java.time.Duration; import java.time.Duration;
import java.util.List; import java.util.List;
import java.util.concurrent.TimeUnit;
import org.bson.Document; import org.bson.Document;
import org.springframework.core.env.Environment; import org.springframework.core.env.Environment;
@ -30,6 +33,7 @@ import org.springframework.data.util.Version;
import org.springframework.util.ObjectUtils; import org.springframework.util.ObjectUtils;
import com.mongodb.ConnectionString; import com.mongodb.ConnectionString;
import com.mongodb.MongoClientSettings;
import com.mongodb.ReadPreference; import com.mongodb.ReadPreference;
import com.mongodb.WriteConcern; import com.mongodb.WriteConcern;
import com.mongodb.client.MongoClient; import com.mongodb.client.MongoClient;
@ -68,6 +72,10 @@ public class MongoTestUtils {
} }
public static MongoClient client(ConnectionString connectionString) { public static MongoClient client(ConnectionString connectionString) {
MongoClientSettings settings = MongoClientSettings.builder().applyConnectionString(connectionString)
.applyToSocketSettings(builder -> {
builder.connectTimeout(120, TimeUnit.SECONDS);
}).build();
return com.mongodb.client.MongoClients.create(connectionString, SpringDataMongoDB.driverInformation()); return com.mongodb.client.MongoClients.create(connectionString, SpringDataMongoDB.driverInformation());
} }
@ -176,11 +184,10 @@ public class MongoTestUtils {
* @param collectionName must not be {@literal null}. * @param collectionName must not be {@literal null}.
* @param client must not be {@literal null}. * @param client must not be {@literal null}.
*/ */
public static void dropCollectionNow(String dbName, String collectionName, public static void dropCollectionNow(String dbName, String collectionName, com.mongodb.client.MongoClient client) {
com.mongodb.client.MongoClient client) {
com.mongodb.client.MongoDatabase database = client.getDatabase(dbName) com.mongodb.client.MongoDatabase database = client.getDatabase(dbName).withWriteConcern(WriteConcern.MAJORITY)
.withWriteConcern(WriteConcern.MAJORITY).withReadPreference(ReadPreference.primary()); .withReadPreference(ReadPreference.primary());
database.getCollection(collectionName).drop(); database.getCollection(collectionName).drop();
} }
@ -205,11 +212,10 @@ public class MongoTestUtils {
.verifyComplete(); .verifyComplete();
} }
public static void flushCollection(String dbName, String collectionName, public static void flushCollection(String dbName, String collectionName, com.mongodb.client.MongoClient client) {
com.mongodb.client.MongoClient client) {
com.mongodb.client.MongoDatabase database = client.getDatabase(dbName) com.mongodb.client.MongoDatabase database = client.getDatabase(dbName).withWriteConcern(WriteConcern.MAJORITY)
.withWriteConcern(WriteConcern.MAJORITY).withReadPreference(ReadPreference.primary()); .withReadPreference(ReadPreference.primary());
database.getCollection(collectionName).deleteMany(new Document()); database.getCollection(collectionName).deleteMany(new Document());
} }
@ -267,19 +273,36 @@ public class MongoTestUtils {
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
public static boolean isVectorSearchEnabled() { public static boolean isVectorSearchEnabled() {
try (MongoClient client = MongoTestUtils.client()) { try (MongoClient client = MongoTestUtils.client()) {
return isVectorSearchEnabled(client);
}
}
public static boolean isVectorSearchEnabled(MongoClient client) {
try {
return client.getDatabase("admin").runCommand(new Document("getCmdLineOpts", "1")).get("argv", List.class) return client.getDatabase("admin").runCommand(new Document("getCmdLineOpts", "1")).get("argv", List.class)
.stream().anyMatch(it -> { .stream().anyMatch(it -> {
if(it instanceof String cfgString) { if (it instanceof String cfgString) {
return cfgString.startsWith("searchIndexManagementHostAndPort"); return cfgString.startsWith("searchIndexManagementHostAndPort");
} }
return false; return false;
}); });
} catch (Exception e) { } catch (Exception e) {
return false; return false;
} }
} }
public static boolean isSearchIndexReady(MongoClient client, @Nullable String database, String collectionName) {
try {
MongoCollection<Document> collection = client.getDatabase(StringUtils.hasText(database) ? database : "test").getCollection(collectionName);
collection.aggregate(List.of(new Document("$listSearchIndexes", new Document())));
} catch (Exception e) {
return false;
}
return true;
}
public static Duration getTimeout() { public static Duration getTimeout() {
return ObjectUtils.nullSafeEquals("jenkins", ENV.getProperty("user.name")) ? Duration.ofMillis(100) return ObjectUtils.nullSafeEquals("jenkins", ENV.getProperty("user.name")) ? Duration.ofMillis(100)
@ -297,10 +320,11 @@ public class MongoTestUtils {
public static CollectionInfo readCollectionInfo(MongoDatabase db, String collectionName) { public static CollectionInfo readCollectionInfo(MongoDatabase db, String collectionName) {
List<Document> list = db.runCommand(new Document().append("listCollections", 1).append("filter", new Document("name", collectionName))) List<Document> list = db
.runCommand(new Document().append("listCollections", 1).append("filter", new Document("name", collectionName)))
.get("cursor", Document.class).get("firstBatch", List.class); .get("cursor", Document.class).get("firstBatch", List.class);
if(list.isEmpty()) { if (list.isEmpty()) {
throw new IllegalStateException(String.format("Collection %s not found.", collectionName)); throw new IllegalStateException(String.format("Collection %s not found.", collectionName));
} }
return CollectionInfo.from(list.get(0)); return CollectionInfo.from(list.get(0));

3
spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/ReplSetClient.java

@ -21,6 +21,8 @@ import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy; import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target; import java.lang.annotation.Target;
import org.junit.jupiter.api.extension.ExtendWith;
/** /**
* Marks a field or method as to be autowired by JUnit's dependency injection facilities for injection of a MongoDB * Marks a field or method as to be autowired by JUnit's dependency injection facilities for injection of a MongoDB
* client instance connected to a replica set. Depends on {@link MongoClientExtension}. * client instance connected to a replica set. Depends on {@link MongoClientExtension}.
@ -34,6 +36,7 @@ import java.lang.annotation.Target;
@Target({ ElementType.FIELD, ElementType.PARAMETER }) @Target({ ElementType.FIELD, ElementType.PARAMETER })
@Retention(RetentionPolicy.RUNTIME) @Retention(RetentionPolicy.RUNTIME)
@Documented @Documented
@ExtendWith(MongoClientExtension.class)
public @interface ReplSetClient { public @interface ReplSetClient {
} }

Loading…
Cancel
Save