Browse Source

Add support for generating VectorSearch queries during AOT.

See #5004
Original pull request: #5005
pull/5026/head
Christoph Strobl 6 months ago committed by Mark Paluch
parent
commit
053158b8d8
No known key found for this signature in database
GPG Key ID: 55BC6374BAA9D973
  1. 18
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/AotQueryCreator.java
  2. 86
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoAotRepositoryFragmentSupport.java
  3. 31
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java
  4. 32
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java
  5. 11
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/QueryBlocks.java
  6. 48
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/SearchInteraction.java
  7. 211
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/VectorSearchBocks.java
  8. 3
      spring-data-mongodb/src/test/java/example/aot/User.java
  9. 30
      spring-data-mongodb/src/test/java/example/aot/UserRepository.java
  10. 3
      spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/index/VectorIndexIntegrationTests.java
  11. 172
      spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java
  12. 144
      spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/QueryMethodContributionUnitTests.java
  13. 2
      spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/EnableIfVectorSearchAvailable.java
  14. 5
      spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoServerCondition.java
  15. 39
      spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoTestTemplate.java
  16. 36
      spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoTestUtils.java

18
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/AotQueryCreator.java

@ -15,6 +15,7 @@ @@ -15,6 +15,7 @@
*/
package org.springframework.data.mongodb.repository.aot;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
@ -46,6 +47,7 @@ import org.springframework.data.mongodb.core.query.CriteriaDefinition.Placeholde @@ -46,6 +47,7 @@ import org.springframework.data.mongodb.core.query.CriteriaDefinition.Placeholde
import org.springframework.data.mongodb.core.query.Query;
import org.springframework.data.mongodb.core.query.TextCriteria;
import org.springframework.data.mongodb.core.query.UpdateDefinition;
import org.springframework.data.mongodb.repository.VectorSearch;
import org.springframework.data.mongodb.repository.query.ConvertingParameterAccessor;
import org.springframework.data.mongodb.repository.query.MongoParameterAccessor;
import org.springframework.data.mongodb.repository.query.MongoQueryCreator;
@ -79,14 +81,16 @@ class AotQueryCreator { @@ -79,14 +81,16 @@ class AotQueryCreator {
}
@SuppressWarnings("NullAway")
StringQuery createQuery(PartTree partTree, QueryMethod queryMethod) {
StringQuery createQuery(PartTree partTree, QueryMethod queryMethod, Method source) {
boolean geoNear = queryMethod instanceof MongoQueryMethod mqm ? mqm.isGeoNearQuery() : false;
boolean searchQuery = queryMethod instanceof MongoQueryMethod mqm
? mqm.isSearchQuery() || source.isAnnotationPresent(VectorSearch.class)
: source.isAnnotationPresent(VectorSearch.class);
Query query = new MongoQueryCreator(partTree,
new PlaceholderConvertingParameterAccessor(new PlaceholderParameterAccessor(queryMethod)), mappingContext, geoNear, queryMethod.isSearchQuery())
.createQuery();
new PlaceholderConvertingParameterAccessor(new PlaceholderParameterAccessor(queryMethod)), mappingContext,
geoNear, searchQuery).createQuery();
if (partTree.isLimiting()) {
query.limit(partTree.getMaxResults());
@ -141,8 +145,7 @@ class AotQueryCreator { @@ -141,8 +145,7 @@ class AotQueryCreator {
for (Parameter parameter : parameters.toList()) {
if (ClassUtils.isAssignable(GeoJson.class, parameter.getType())) {
placeholders.add(parameter.getIndex(), new GeoJsonPlaceholder(parameter.getIndex(), ""));
}
else if (ClassUtils.isAssignable(Point.class, parameter.getType())) {
} else if (ClassUtils.isAssignable(Point.class, parameter.getType())) {
placeholders.add(parameter.getIndex(), new PointPlaceholder(parameter.getIndex()));
} else if (ClassUtils.isAssignable(Circle.class, parameter.getType())) {
placeholders.add(parameter.getIndex(), new CirclePlaceholder(parameter.getIndex()));
@ -152,8 +155,7 @@ class AotQueryCreator { @@ -152,8 +155,7 @@ class AotQueryCreator {
placeholders.add(parameter.getIndex(), new SpherePlaceholder(parameter.getIndex()));
} else if (ClassUtils.isAssignable(Polygon.class, parameter.getType())) {
placeholders.add(parameter.getIndex(), new PolygonPlaceholder(parameter.getIndex()));
}
else {
} else {
placeholders.add(parameter.getIndex(), Placeholder.indexed(parameter.getIndex()));
}
}

86
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoAotRepositoryFragmentSupport.java

@ -16,12 +16,17 @@ @@ -16,12 +16,17 @@
package org.springframework.data.mongodb.repository.aot;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.function.Consumer;
import org.bson.Document;
import org.jspecify.annotations.Nullable;
import org.springframework.data.domain.Range;
import org.springframework.data.domain.Score;
import org.springframework.data.domain.ScoringFunction;
import org.springframework.data.expression.ValueEvaluationContext;
import org.springframework.data.expression.ValueExpression;
import org.springframework.data.mapping.model.ValueExpressionEvaluator;
@ -33,6 +38,7 @@ import org.springframework.data.mongodb.core.convert.MongoConverter; @@ -33,6 +38,7 @@ import org.springframework.data.mongodb.core.convert.MongoConverter;
import org.springframework.data.mongodb.core.mapping.FieldName;
import org.springframework.data.mongodb.core.query.BasicQuery;
import org.springframework.data.mongodb.core.query.Collation;
import org.springframework.data.mongodb.core.query.Criteria;
import org.springframework.data.mongodb.repository.query.MongoParameters;
import org.springframework.data.mongodb.util.json.ParameterBindingContext;
import org.springframework.data.mongodb.util.json.ParameterBindingDocumentCodec;
@ -42,7 +48,9 @@ import org.springframework.data.repository.core.RepositoryMetadata; @@ -42,7 +48,9 @@ import org.springframework.data.repository.core.RepositoryMetadata;
import org.springframework.data.repository.core.support.RepositoryFactoryBeanSupport;
import org.springframework.data.repository.query.ValueExpressionDelegate;
import org.springframework.expression.EvaluationContext;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
import org.springframework.util.CollectionUtils;
import org.springframework.util.ObjectUtils;
/**
@ -108,7 +116,27 @@ public class MongoAotRepositoryFragmentSupport { @@ -108,7 +116,27 @@ public class MongoAotRepositoryFragmentSupport {
return new ParameterBindingDocumentCodec().decode(source, bindingContext);
}
protected Object evaluate(String source, Map<String, Object> parameters) {
protected Object[] arguments(Object... arguments) {
return arguments;
}
protected Map<String, Object> argumentMap(Object... parameters) {
Assert.state(parameters.length % 2 == 0, "even number of args required");
LinkedHashMap<String, Object> argumentMap = CollectionUtils.newLinkedHashMap(parameters.length / 2);
for (int i = 0; i < parameters.length; i += 2) {
if (!(parameters[i] instanceof String key)) {
throw new IllegalArgumentException("key must be a String");
}
argumentMap.put(key, parameters[i + 1]);
}
return argumentMap;
}
protected @Nullable Object evaluate(String source, Map<String, Object> parameters) {
ValueEvaluationContext valueEvaluationContext = this.valueExpressionDelegate.getEvaluationContextAccessor()
.create(new NoMongoParameters()).getEvaluationContext(parameters.values());
@ -120,9 +148,63 @@ public class MongoAotRepositoryFragmentSupport { @@ -120,9 +148,63 @@ public class MongoAotRepositoryFragmentSupport {
return parse.evaluate(valueEvaluationContext);
}
protected Consumer<Criteria> scoreBetween(Range.Bound<? extends Score> lower, Range.Bound<? extends Score> upper) {
return criteria -> {
if (lower.isBounded()) {
double value = lower.getValue().get().getValue();
if (lower.isInclusive()) {
criteria.gte(value);
} else {
criteria.gt(value);
}
}
if (upper.isBounded()) {
double value = upper.getValue().get().getValue();
if (upper.isInclusive()) {
criteria.lte(value);
} else {
criteria.lt(value);
}
}
};
}
protected ScoringFunction scoringFunction(Range<? extends Score> scoreRange) {
if (scoreRange != null) {
if (scoreRange.getUpperBound().isBounded()) {
return scoreRange.getUpperBound().getValue().get().getFunction();
}
if (scoreRange.getLowerBound().isBounded()) {
return scoreRange.getLowerBound().getValue().get().getFunction();
}
}
return ScoringFunction.unspecified();
}
// Range<Score> scoreRange = accessor.getScoreRange();
//
// if (scoreRange != null) {
// if (scoreRange.getUpperBound().isBounded()) {
// return scoreRange.getUpperBound().getValue().get().getFunction();
// }
//
// if (scoreRange.getLowerBound().isBounded()) {
// return scoreRange.getLowerBound().getValue().get().getFunction();
// }
// }
//
// return ScoringFunction.unspecified();
protected Collation collationOf(@Nullable Object source) {
if(source == null) {
if (source == null) {
return Collation.simple();
}
if (source instanceof String) {

31
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java

@ -37,6 +37,7 @@ import org.springframework.data.mongodb.repository.query.MongoQueryMethod; @@ -37,6 +37,7 @@ import org.springframework.data.mongodb.repository.query.MongoQueryMethod;
import org.springframework.data.repository.aot.generate.AotQueryMethodGenerationContext;
import org.springframework.javapoet.CodeBlock;
import org.springframework.javapoet.CodeBlock.Builder;
import org.springframework.util.NumberUtils;
import org.springframework.util.StringUtils;
/**
@ -49,6 +50,7 @@ class MongoCodeBlocks { @@ -49,6 +50,7 @@ class MongoCodeBlocks {
private static final Pattern PARAMETER_BINDING_PATTERN = Pattern.compile("\\?(\\d+)");
private static final Pattern EXPRESSION_BINDING_PATTERN = Pattern.compile("[\\?:][#$]\\{.*\\}");
private static final Pattern VALUE_EXPRESSION_PATTERN = Pattern.compile("^#\\{.*}$");
/**
* Builder for generating query parsing {@link CodeBlock}.
@ -179,7 +181,7 @@ class MongoCodeBlocks { @@ -179,7 +181,7 @@ class MongoCodeBlocks {
} else {
builder.add("$T $L = bindParameters($S, ", Document.class, variableName, source);
if (containsNamedPlaceholder(source)) {
renderArgumentMap(arguments);
builder.add(renderArgumentMap(arguments));
} else {
builder.add(renderArgumentArray(arguments));
}
@ -191,7 +193,7 @@ class MongoCodeBlocks { @@ -191,7 +193,7 @@ class MongoCodeBlocks {
static CodeBlock renderArgumentMap(Map<String, CodeBlock> arguments) {
Builder builder = CodeBlock.builder();
builder.add("$T.of(", Map.class);
builder.add("argumentMap(");
Iterator<Entry<String, CodeBlock>> iterator = arguments.entrySet().iterator();
while (iterator.hasNext()) {
Entry<String, CodeBlock> next = iterator.next();
@ -208,24 +210,41 @@ class MongoCodeBlocks { @@ -208,24 +210,41 @@ class MongoCodeBlocks {
static CodeBlock renderArgumentArray(Map<String, CodeBlock> arguments) {
Builder builder = CodeBlock.builder();
builder.add("new $T[]{ ", Object.class);
builder.add("arguments(");
Iterator<CodeBlock> iterator = arguments.values().iterator();
while (iterator.hasNext()) {
builder.add(iterator.next());
if (iterator.hasNext()) {
builder.add(", ");
} else {
builder.add(" ");
}
}
builder.add("}");
builder.add(")");
return builder.build();
}
static CodeBlock evaluateNumberPotentially(String value, Class<? extends Number> targetType,
Map<String, CodeBlock> arguments) {
try {
Number number = NumberUtils.parseNumber(value, targetType);
return CodeBlock.of("$L", number);
} catch (IllegalArgumentException e) {
Builder builder = CodeBlock.builder();
builder.add("($T) evaluate($S, ", targetType, value);
builder.add(MongoCodeBlocks.renderArgumentMap(arguments));
builder.add(")");
return builder.build();
}
}
static boolean containsPlaceholder(String source) {
return containsIndexedPlaceholder(source) || containsNamedPlaceholder(source);
}
static boolean containsExpression(String source) {
return VALUE_EXPRESSION_PATTERN.matcher(source).find();
}
static boolean containsNamedPlaceholder(String source) {
return EXPRESSION_BINDING_PATTERN.matcher(source).find();
}

32
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java

@ -37,6 +37,7 @@ import org.springframework.data.mongodb.core.aggregation.AggregationUpdate; @@ -37,6 +37,7 @@ import org.springframework.data.mongodb.core.aggregation.AggregationUpdate;
import org.springframework.data.mongodb.core.mapping.MongoMappingContext;
import org.springframework.data.mongodb.repository.Query;
import org.springframework.data.mongodb.repository.Update;
import org.springframework.data.mongodb.repository.VectorSearch;
import org.springframework.data.mongodb.repository.query.MongoQueryMethod;
import org.springframework.data.repository.aot.generate.AotRepositoryClassBuilder;
import org.springframework.data.repository.aot.generate.AotRepositoryConstructorBuilder;
@ -107,7 +108,11 @@ public class MongoRepositoryContributor extends RepositoryContributor { @@ -107,7 +108,11 @@ public class MongoRepositoryContributor extends RepositoryContributor {
}
QueryInteraction query = createStringQuery(getRepositoryInformation(), queryMethod,
AnnotatedElementUtils.findMergedAnnotation(method, Query.class));
AnnotatedElementUtils.findMergedAnnotation(method, Query.class), method);
if (queryMethod.isSearchQuery() || method.isAnnotationPresent(VectorSearch.class)) {
return searchMethodContributor(queryMethod, new SearchInteraction(query.getQuery()));
}
if (queryMethod.isGeoNearQuery() || (queryMethod.getParameters().getMaxDistanceIndex() != -1
&& queryMethod.getReturnType().isCollectionLike())) {
@ -126,8 +131,8 @@ public class MongoRepositoryContributor extends RepositoryContributor { @@ -126,8 +131,8 @@ public class MongoRepositoryContributor extends RepositoryContributor {
UpdateInteraction update = new UpdateInteraction(query, null, updateIndex);
return updateMethodContributor(queryMethod, update);
} else {
Update updateSource = queryMethod.getUpdateSource();
if (StringUtils.hasText(updateSource.value())) {
UpdateInteraction update = new UpdateInteraction(query, new StringUpdate(updateSource.value()), null);
@ -146,7 +151,7 @@ public class MongoRepositoryContributor extends RepositoryContributor { @@ -146,7 +151,7 @@ public class MongoRepositoryContributor extends RepositoryContributor {
@SuppressWarnings("NullAway")
private QueryInteraction createStringQuery(RepositoryInformation repositoryInformation, MongoQueryMethod queryMethod,
@Nullable Query queryAnnotation) {
@Nullable Query queryAnnotation, Method source) {
QueryInteraction query;
if (queryMethod.hasAnnotatedQuery() && queryAnnotation != null) {
@ -155,8 +160,8 @@ public class MongoRepositoryContributor extends RepositoryContributor { @@ -155,8 +160,8 @@ public class MongoRepositoryContributor extends RepositoryContributor {
} else {
PartTree partTree = new PartTree(queryMethod.getName(), repositoryInformation.getDomainType());
query = new QueryInteraction(queryCreator.createQuery(partTree, queryMethod), partTree.isCountProjection(),
partTree.isDelete(), partTree.isExistsProjection());
query = new QueryInteraction(queryCreator.createQuery(partTree, queryMethod, source),
partTree.isCountProjection(), partTree.isDelete(), partTree.isExistsProjection());
}
if (queryAnnotation != null && StringUtils.hasText(queryAnnotation.sort())) {
@ -172,7 +177,7 @@ public class MongoRepositoryContributor extends RepositoryContributor { @@ -172,7 +177,7 @@ public class MongoRepositoryContributor extends RepositoryContributor {
private static boolean backoff(MongoQueryMethod method) {
// TODO: namedQuery, Regex queries, queries accepting Shapes (e.g. within) or returning arrays.
boolean skip = method.isSearchQuery() || method.getReturnType().getType().isArray();
boolean skip = method.getReturnType().getType().isArray();
if (skip && logger.isDebugEnabled()) {
logger.debug("Skipping AOT generation for [%s]. Method is either returning an array or a geo-near, regex query"
@ -220,6 +225,21 @@ public class MongoRepositoryContributor extends RepositoryContributor { @@ -220,6 +225,21 @@ public class MongoRepositoryContributor extends RepositoryContributor {
});
}
static MethodContributor<MongoQueryMethod> searchMethodContributor(MongoQueryMethod queryMethod,
SearchInteraction interaction) {
return MethodContributor.forQueryMethod(queryMethod).withMetadata(interaction).contribute(context -> {
CodeBlock.Builder builder = CodeBlock.builder();
String variableName = "search";
builder.add(new VectorSearchBocks.VectorSearchQueryCodeBlockBuilder(context, queryMethod)
.usingVariableName(variableName).withFilter(interaction.getFilter()).build());
return builder.build();
});
}
static MethodContributor<MongoQueryMethod> updateMethodContributor(MongoQueryMethod queryMethod,
UpdateInteraction update) {

11
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/QueryBlocks.java

@ -211,8 +211,7 @@ class QueryBlocks { @@ -211,8 +211,7 @@ class QueryBlocks {
Builder builder = CodeBlock.builder();
builder.add("\n");
builder.add(renderExpressionToQuery(source.getQuery().getQueryString(), queryVariableName));
builder.add(buildJustTheQuery());
if (StringUtils.hasText(source.getQuery().getFieldsString())) {
@ -289,6 +288,14 @@ class QueryBlocks { @@ -289,6 +288,14 @@ class QueryBlocks {
return builder.build();
}
CodeBlock buildJustTheQuery() {
Builder builder = CodeBlock.builder();
builder.add("\n");
builder.add(renderExpressionToQuery(source.getQuery().getQueryString(), queryVariableName));
return builder.build();
}
private CodeBlock renderExpressionToQuery(@Nullable String source, String variableName) {
Builder builder = CodeBlock.builder();

48
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/SearchInteraction.java

@ -0,0 +1,48 @@ @@ -0,0 +1,48 @@
/*
* Copyright 2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.data.mongodb.repository.aot;
import java.util.Map;
import org.jspecify.annotations.Nullable;
import org.springframework.data.repository.aot.generate.QueryMetadata;
/**
* @author Christoph Strobl
*/
public class SearchInteraction extends MongoInteraction implements QueryMetadata {
StringQuery filter;
public SearchInteraction(StringQuery filter) {
this.filter = filter;
}
public StringQuery getFilter() {
return filter;
}
@Override
InteractionType getExecutionType() {
return InteractionType.AGGREGATION;
}
@Override
public Map<String, Object> serialize() {
return Map.of("FIXME", "please!");
}
}

211
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/VectorSearchBocks.java

@ -0,0 +1,211 @@ @@ -0,0 +1,211 @@
/*
* Copyright 2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.data.mongodb.repository.aot;
import java.lang.reflect.Field;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.bson.Document;
import org.springframework.core.annotation.MergedAnnotation;
import org.springframework.data.domain.Limit;
import org.springframework.data.domain.ScoringFunction;
import org.springframework.data.domain.Sort;
import org.springframework.data.domain.Vector;
import org.springframework.data.mongodb.core.MongoOperations;
import org.springframework.data.mongodb.core.aggregation.Aggregation;
import org.springframework.data.mongodb.core.aggregation.AggregationOperation;
import org.springframework.data.mongodb.core.aggregation.AggregationOperationContext;
import org.springframework.data.mongodb.core.aggregation.AggregationPipeline;
import org.springframework.data.mongodb.core.aggregation.VectorSearchOperation;
import org.springframework.data.mongodb.core.aggregation.VectorSearchOperation.SearchType;
import org.springframework.data.mongodb.repository.VectorSearch;
import org.springframework.data.mongodb.repository.query.MongoQueryExecution.VectorSearchExecution;
import org.springframework.data.mongodb.repository.query.MongoQueryMethod;
import org.springframework.data.repository.aot.generate.AotQueryMethodGenerationContext;
import org.springframework.data.util.TypeInformation;
import org.springframework.javapoet.CodeBlock;
import org.springframework.javapoet.CodeBlock.Builder;
import org.springframework.util.StringUtils;
/**
* @author Christoph Strobl
* @since 5.0
*/
class VectorSearchBocks {
static class VectorSearchQueryCodeBlockBuilder {
private final AotQueryMethodGenerationContext context;
private final MongoQueryMethod queryMethod;
private String searchQueryVariableName;
private StringQuery filter;
private final Map<String, CodeBlock> arguments;
VectorSearchQueryCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) {
this.context = context;
this.queryMethod = queryMethod;
this.arguments = new LinkedHashMap<>();
context.getBindableParameterNames().forEach(it -> arguments.put(it, CodeBlock.of(it)));
}
VectorSearchQueryCodeBlockBuilder usingVariableName(String searchQueryVariableName) {
this.searchQueryVariableName = searchQueryVariableName;
return this;
}
CodeBlock build() {
Builder builder = CodeBlock.builder();
String vectorParameterName = context.getVectorParameterName();
MergedAnnotation<VectorSearch> annotation = context.getAnnotation(VectorSearch.class);
String searchPath = annotation.getString("path");
String indexName = annotation.getString("indexName");
String numCandidates = annotation.getString("numCandidates");
SearchType searchType = annotation.getEnum("searchType", SearchType.class);
String limit = annotation.getString("limit");
if (!StringUtils.hasText(searchPath)) { // FIXME: somehow duplicate logic of AnnotatedQueryFactory
Field[] declaredFields = context.getRepositoryInformation().getDomainType().getDeclaredFields();
for (Field field : declaredFields) {
if (Vector.class.isAssignableFrom(field.getType())) {
searchPath = field.getName();
break;
}
}
}
String vectorSearchVar = context.localVariable("$vectorSearch");
builder.add("$T $L = $T.vectorSearch($S).path($S).vector($L)", VectorSearchOperation.class, vectorSearchVar,
Aggregation.class, indexName, searchPath, vectorParameterName);
if (StringUtils.hasText(context.getLimitParameterName())) {
builder.add(".limit($L);\n", context.getLimitParameterName());
} else if (filter.isLimited()) {
builder.add(".limit($L);\n", filter.getLimit());
} else if (StringUtils.hasText(limit)) {
if (MongoCodeBlocks.containsPlaceholder(limit) || MongoCodeBlocks.containsExpression(limit)) {
builder.add(".limit(");
builder.add(MongoCodeBlocks.evaluateNumberPotentially(limit, Integer.class, arguments));
builder.add(");\n");
} else {
builder.add(".limit($L);\n", limit);
}
} else {
builder.add(".limit($T.unlimited());\n", Limit.class);
}
if (!searchType.equals(SearchType.DEFAULT)) {
builder.addStatement("$1L = $1L.searchType($2T.$3L)", vectorSearchVar, SearchType.class, searchType.name());
}
if (StringUtils.hasText(numCandidates)) {
builder.add("$1L = $1L.numCandidates(", vectorSearchVar);
builder.add(MongoCodeBlocks.evaluateNumberPotentially(numCandidates, Integer.class, arguments));
builder.add(");\n");
} else if (searchType == VectorSearchOperation.SearchType.ANN
|| searchType == VectorSearchOperation.SearchType.DEFAULT) {
builder.add(
"// MongoDB: We recommend that you specify a number at least 20 times higher than the number of documents to return\n");
if (StringUtils.hasText(context.getLimitParameterName())) {
builder.addStatement("$1L = $1L.numCandidates($2L.max() * 20)", vectorSearchVar,
context.getLimitParameterName());
} else if (StringUtils.hasText(limit)) {
if (MongoCodeBlocks.containsPlaceholder(limit) || MongoCodeBlocks.containsExpression(limit)) {
builder.add("$1L = $1L.numCandidates((", vectorSearchVar);
builder.add(MongoCodeBlocks.evaluateNumberPotentially(limit, Integer.class, arguments));
builder.add(") * 20);\n");
} else {
builder.addStatement("$1L = $1L.numCandidates($2L * 20)", vectorSearchVar, limit);
}
} else {
builder.addStatement("$1L = $1L.numCandidates($2L)", vectorSearchVar, filter.getLimit() * 20);
}
}
builder.addStatement("$1L = $1L.withSearchScore(\"__score__\")", vectorSearchVar);
if (StringUtils.hasText(context.getScoreParameterName())) {
String scoreCriteriaVar = context.localVariable("criteria");
builder.addStatement("$1L = $1L.withFilterBySore($2L -> { $2L.gt($3L.getValue()); })", vectorSearchVar,
scoreCriteriaVar, context.getScoreParameterName());
} else if (StringUtils.hasText(context.getScoreRangeParameterName())) {
builder.addStatement("$1L = $1L.withFilterBySore(scoreBetween($2L.getLowerBound(), $2L.getUpperBound()))",
vectorSearchVar, context.getScoreRangeParameterName());
}
if (StringUtils.hasText(filter.getQueryString())) {
String filterVar = context.localVariable("filter");
builder.add(MongoCodeBlocks.queryBlockBuilder(context, queryMethod).usingQueryVariableName("filter")
.filter(new QueryInteraction(this.filter, false, false, false)).buildJustTheQuery());
builder.addStatement("$1L = $1L.filter($2L.getQueryObject())", vectorSearchVar, filterVar);
builder.add("\n");
}
String sortStageVar = context.localVariable("$sort");
if(filter.isSorted()) {
builder.add("$T $L = (_ctx) -> {\n", AggregationOperation.class, sortStageVar);
builder.indent();
builder.addStatement("$1T _mappedSort = _ctx.getMappedObject($1T.parse($2S), $3T.class)", Document.class, filter.getSortString(), context.getActualReturnType().getType());
builder.addStatement("return new $T($S, _mappedSort.append(\"__score__\", -1))", Document.class, "$sort");
builder.unindent();
builder.add("};");
} else {
builder.addStatement("var $L = $T.sort($T.Direction.DESC, $S)", sortStageVar, Aggregation.class, Sort.class, "__score__");
}
builder.add("\n");
builder.addStatement("$1T $2L = new $1T($3T.of($4L, $5L))", AggregationPipeline.class, searchQueryVariableName,
List.class, vectorSearchVar, sortStageVar);
String scoringFunctionVar = context.localVariable("scoringFunction");
builder.add("$1T $2L = ", ScoringFunction.class, scoringFunctionVar);
if (StringUtils.hasText(context.getScoreParameterName())) {
builder.add("$L.getFunction();\n", context.getScoreParameterName());
} else if (StringUtils.hasText(context.getScoreRangeParameterName())) {
builder.add("scoringFunction($L);\n", context.getScoreRangeParameterName());
} else {
builder.add("$1T.unspecified();\n", ScoringFunction.class);
}
builder.addStatement(
"return ($5T) new $1T($2L, $3T.class, $2L.getCollectionName($3T.class), $4T.of($5T.class), $6L, $7L).execute(null)",
VectorSearchExecution.class, context.fieldNameOf(MongoOperations.class),
context.getRepositoryInformation().getDomainType(), TypeInformation.class,
queryMethod.getReturnType().getType(), searchQueryVariableName, scoringFunctionVar);
return builder.build();
}
public VectorSearchQueryCodeBlockBuilder withFilter(StringQuery filter) {
this.filter = filter;
return this;
}
}
}

3
spring-data-mongodb/src/test/java/example/aot/User.java

@ -17,6 +17,7 @@ package example.aot; @@ -17,6 +17,7 @@ package example.aot;
import java.time.Instant;
import org.springframework.data.domain.Vector;
import org.springframework.data.mongodb.core.mapping.Field;
/**
@ -38,6 +39,8 @@ public class User { @@ -38,6 +39,8 @@ public class User {
Instant lastSeen;
Long visits;
Vector embedding;
public String getId() {
return id;
}

30
spring-data-mongodb/src/test/java/example/aot/UserRepository.java

@ -30,9 +30,13 @@ import org.springframework.data.domain.Limit; @@ -30,9 +30,13 @@ import org.springframework.data.domain.Limit;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.Pageable;
import org.springframework.data.domain.Range;
import org.springframework.data.domain.Score;
import org.springframework.data.domain.ScrollPosition;
import org.springframework.data.domain.SearchResults;
import org.springframework.data.domain.Similarity;
import org.springframework.data.domain.Slice;
import org.springframework.data.domain.Sort;
import org.springframework.data.domain.Vector;
import org.springframework.data.domain.Window;
import org.springframework.data.geo.Box;
import org.springframework.data.geo.Circle;
@ -43,6 +47,7 @@ import org.springframework.data.geo.GeoResults; @@ -43,6 +47,7 @@ import org.springframework.data.geo.GeoResults;
import org.springframework.data.geo.Point;
import org.springframework.data.geo.Polygon;
import org.springframework.data.mongodb.core.aggregation.AggregationResults;
import org.springframework.data.mongodb.core.aggregation.VectorSearchOperation;
import org.springframework.data.mongodb.core.geo.GeoJson;
import org.springframework.data.mongodb.core.geo.GeoJsonPolygon;
import org.springframework.data.mongodb.core.geo.Sphere;
@ -51,6 +56,7 @@ import org.springframework.data.mongodb.repository.Hint; @@ -51,6 +56,7 @@ import org.springframework.data.mongodb.repository.Hint;
import org.springframework.data.mongodb.repository.Query;
import org.springframework.data.mongodb.repository.ReadPreference;
import org.springframework.data.mongodb.repository.Update;
import org.springframework.data.mongodb.repository.VectorSearch;
import org.springframework.data.repository.CrudRepository;
import org.springframework.data.repository.query.Param;
@ -291,6 +297,30 @@ public interface UserRepository extends CrudRepository<User, String> { @@ -291,6 +297,30 @@ public interface UserRepository extends CrudRepository<User, String> {
"{ '$project': { '_id' : '$last_name' } }" }, collation = "no_collation")
List<String> findAllLastnamesWithCollation();
// Vector Search
@VectorSearch(indexName = "embedding.vector_cos", filter = "{lastname: ?0}", numCandidates = "#{10+10}",
searchType = VectorSearchOperation.SearchType.ANN)
SearchResults<User> annotatedVectorSearch(String lastname, Vector vector, Score distance, Limit limit);
@VectorSearch(indexName = "embedding.vector_cos")
SearchResults<User> searchCosineByLastnameAndEmbeddingNear(String lastname, Vector vector, Score similarity,
Limit limit);
@VectorSearch(indexName = "embedding.vector_cos")
List<User> searchAsListByLastnameAndEmbeddingNear(String lastname, Vector vector, Limit limit);
@VectorSearch(indexName = "embedding.vector_cos", limit = "10")
SearchResults<User> searchByLastnameAndEmbeddingWithin(String lastname, Vector vector, Range<Similarity> distance);
@VectorSearch(indexName = "embedding.vector_cos", limit = "10")
SearchResults<User> searchByLastnameAndEmbeddingWithinOrderByFirstname(String lastname, Vector vector,
Range<Similarity> distance);
@VectorSearch(indexName = "embedding.vector_cos")
SearchResults<User> searchTop1ByLastnameAndEmbeddingWithin(String lastname, Vector vector,
Range<Similarity> distance);
class UserAggregate {
@Id //

3
spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/index/VectorIndexIntegrationTests.java

@ -77,9 +77,8 @@ class VectorIndexIntegrationTests { @@ -77,9 +77,8 @@ class VectorIndexIntegrationTests {
@AfterEach
void cleanup() {
template.tryToDropSearchIndexes(Movie.class, Duration.ofSeconds(30));
template.flush(Movie.class);
template.searchIndexOps(Movie.class).dropAllIndexes();
template.awaitNoSearchIndexAvailable(Movie.class, Duration.ofSeconds(30));
}
@ParameterizedTest // GH-4706

172
spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java

@ -24,12 +24,15 @@ import example.aot.UserProjection; @@ -24,12 +24,15 @@ import example.aot.UserProjection;
import example.aot.UserRepository;
import example.aot.UserRepository.UserAggregate;
import java.time.Duration;
import java.time.Instant;
import java.util.List;
import java.util.Optional;
import java.util.regex.Pattern;
import org.bson.BsonString;
import org.bson.Document;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
@ -42,9 +45,13 @@ import org.springframework.data.domain.OffsetScrollPosition; @@ -42,9 +45,13 @@ import org.springframework.data.domain.OffsetScrollPosition;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.PageRequest;
import org.springframework.data.domain.Range;
import org.springframework.data.domain.Score;
import org.springframework.data.domain.ScrollPosition;
import org.springframework.data.domain.SearchResults;
import org.springframework.data.domain.Similarity;
import org.springframework.data.domain.Slice;
import org.springframework.data.domain.Sort;
import org.springframework.data.domain.Vector;
import org.springframework.data.domain.Window;
import org.springframework.data.geo.Box;
import org.springframework.data.geo.Circle;
@ -60,13 +67,22 @@ import org.springframework.data.mongodb.core.MongoTemplate; @@ -60,13 +67,22 @@ import org.springframework.data.mongodb.core.MongoTemplate;
import org.springframework.data.mongodb.core.aggregation.AggregationResults;
import org.springframework.data.mongodb.core.geo.GeoJsonPoint;
import org.springframework.data.mongodb.core.geo.GeoJsonPolygon;
import org.springframework.data.mongodb.test.util.Client;
import org.springframework.data.mongodb.test.util.AtlasContainer;
import org.springframework.data.mongodb.test.util.EnableIfVectorSearchAvailable;
import org.springframework.data.mongodb.test.util.MongoTestUtils;
import org.springframework.test.context.junit.jupiter.SpringJUnitConfig;
import org.springframework.util.StringUtils;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;
import org.testcontainers.shaded.org.awaitility.Awaitility;
import com.mongodb.client.MongoClient;
import com.mongodb.client.MongoClients;
import com.mongodb.client.MongoCollection;
import com.mongodb.client.MongoCursor;
import com.mongodb.client.model.IndexOptions;
import com.mongodb.client.model.SearchIndexModel;
import com.mongodb.client.model.SearchIndexType;
/**
* Integration tests for the {@link UserRepository} AOT fragment.
@ -74,13 +90,15 @@ import com.mongodb.client.model.IndexOptions; @@ -74,13 +90,15 @@ import com.mongodb.client.model.IndexOptions;
* @author Christoph Strobl
* @author Mark Paluch
*/
@Testcontainers(disabledWithoutDocker = true)
@SpringJUnitConfig(classes = MongoRepositoryContributorTests.MongoRepositoryContributorConfiguration.class)
class MongoRepositoryContributorTests {
private static final @Container AtlasContainer atlasLocal = AtlasContainer.bestMatch();
private static final String DB_NAME = "aot-repo-tests";
private static final String COLLECTION_NAME = "user";
@Client static MongoClient client;
static MongoClient client;
@Autowired UserRepository fragment;
@Configuration
@ -97,16 +115,63 @@ class MongoRepositoryContributorTests { @@ -97,16 +115,63 @@ class MongoRepositoryContributorTests {
}
@BeforeAll
static void beforeAll() {
client.getDatabase(DB_NAME).getCollection("user").createIndex(new Document("location.coordinates", "2d"),
new IndexOptions());
static void beforeAll() throws InterruptedException {
client = MongoClients.create(atlasLocal.getConnectionString());
MongoCollection<Document> userCollection = client.getDatabase(DB_NAME).getCollection(COLLECTION_NAME);
userCollection.createIndex(new Document("location.coordinates", "2d"), new IndexOptions());
userCollection.createIndex(new Document("location.coordinates", "2dsphere"), new IndexOptions());
Thread.sleep(250); // just wait a little or the index will be broken
}
/**
* Create the vector search index and wait till it is queryable and actually serving data. Since this may slow down
* tests quite a bit, better call it only when needed to run certain tests.
*/
private static void initializeVectorIndex() {
String indexName = "embedding.vector_cos";
Document searchIndex = new Document("fields",
List.of(new Document("type", "vector").append("path", "embedding").append("numDimensions", 5)
.append("similarity", "cosine"), new Document("type", "filter").append("path", "last_name")));
MongoCollection<Document> userCollection = client.getDatabase(DB_NAME).getCollection(COLLECTION_NAME);
userCollection.createSearchIndexes(
List.of(new SearchIndexModel(indexName, searchIndex, SearchIndexType.of(new BsonString("vectorSearch")))));
// wait for search index to be queryable
Awaitility.await().atMost(Duration.ofSeconds(120)).pollInterval(Duration.ofMillis(200)).until(() -> {
return MongoTestUtils.isSearchIndexReady(indexName, client, DB_NAME, COLLECTION_NAME);
});
Document $vectorSearch = new Document("$vectorSearch",
new Document("index", indexName).append("limit", 1).append("numCandidates", 20).append("path", "embedding")
.append("queryVector", List.of(1.0, 1.12345, 2.23456, 3.34567, 4.45678)));
// wait for search index to serve data
Awaitility.await().atLeast(Duration.ofMillis(50)).atMost(Duration.ofSeconds(120)).ignoreExceptions()
.pollInterval(Duration.ofMillis(250)).until(() -> {
try (MongoCursor<Document> cursor = userCollection.aggregate(List.of($vectorSearch)).iterator()) {
if (cursor.hasNext()) {
Document next = cursor.next();
return true;
}
return false;
}
});
}
@BeforeEach
void beforeEach() {
void beforeEach() throws InterruptedException {
initUsers();
}
@AfterEach
void afterEach() {
MongoTestUtils.flushCollection(DB_NAME, "user", client);
initUsers();
}
@Test
@ -747,10 +812,86 @@ class MongoRepositoryContributorTests { @@ -747,10 +812,86 @@ class MongoRepositoryContributorTests {
assertThat(page2.hasNext()).isFalse();
}
@Test
@EnableIfVectorSearchAvailable(database = DB_NAME, collection = User.class)
void vectorSearchFromAnnotation() {
initializeVectorIndex();
Vector vector = Vector.of(1.00000d, 1.12345d, 2.23456d, 3.34567d, 4.45678d);
SearchResults<User> results = fragment.annotatedVectorSearch("Skywalker", vector, Score.of(0.99), Limit.of(10));
assertThat(results).hasSize(1);
}
@Test
@EnableIfVectorSearchAvailable(database = DB_NAME, collection = User.class)
void vectorSearchWithDerivedQuery() {
initializeVectorIndex();
Vector vector = Vector.of(1.00000d, 1.12345d, 2.23456d, 3.34567d, 4.45678d);
SearchResults<User> results = fragment.searchCosineByLastnameAndEmbeddingNear("Skywalker", vector, Score.of(0.98),
Limit.of(10));
assertThat(results).hasSize(1);
}
@Test
@EnableIfVectorSearchAvailable(database = DB_NAME, collection = User.class)
void vectorSearchReturningResultsAsList() {
initializeVectorIndex();
Vector vector = Vector.of(1.00000d, 1.12345d, 2.23456d, 3.34567d, 4.45678d);
List<User> results = fragment.searchAsListByLastnameAndEmbeddingNear("Skywalker", vector, Limit.of(10));
assertThat(results).hasSize(2);
}
@Test
@EnableIfVectorSearchAvailable(database = DB_NAME, collection = User.class)
void vectorSearchWithLimitFromAnnotation() {
initializeVectorIndex();
Vector vector = Vector.of(1.00000d, 1.12345d, 2.23456d, 3.34567d, 4.45678d);
SearchResults<User> results = fragment.searchByLastnameAndEmbeddingWithin("Skywalker", vector,
Similarity.between(0.4, 0.99));
assertThat(results).hasSize(1);
}
@Test
@EnableIfVectorSearchAvailable(database = DB_NAME, collection = User.class)
void vectorSearchWithSorting() {
initializeVectorIndex();
Vector vector = Vector.of(1.00000d, 1.12345d, 2.23456d, 3.34567d, 4.45678d);
SearchResults<User> results = fragment.searchByLastnameAndEmbeddingWithinOrderByFirstname("Skywalker", vector,
Similarity.between(0.4, 1.0));
assertThat(results).hasSize(2);
}
@Test
@EnableIfVectorSearchAvailable(database = DB_NAME, collection = User.class)
void vectorSearchWithLimitFromDerivedQuery() {
initializeVectorIndex();
Vector vector = Vector.of(1.00000d, 1.12345d, 2.23456d, 3.34567d, 4.45678d);
SearchResults<User> results = fragment.searchTop1ByLastnameAndEmbeddingWithin("Skywalker", vector,
Similarity.between(0.4, 1.0));
assertThat(results).hasSize(1);
}
/**
* GeoResults<Person> results = repository.findPersonByLocationNear(new Point(-73.99, 40.73), range);
*/
private static void initUsers() {
private static void initUsers() throws InterruptedException {
Document luke = Document.parse("""
{
@ -770,6 +911,7 @@ class MongoRepositoryContributorTests { @@ -770,6 +911,7 @@ class MongoRepositoryContributorTests {
}
}
],
"embedding" : [1.00000, 1.12345, 2.23456, 3.34567, 4.45678],
"_class": "example.springdata.aot.User"
}""");
@ -785,6 +927,7 @@ class MongoRepositoryContributorTests { @@ -785,6 +927,7 @@ class MongoRepositoryContributorTests {
"x" : -73.99171, "y" : 40.738868
}
},
"embedding" : [1.0001, 2.12345, 3.23456, 4.34567, 5.45678],
"_class": "example.springdata.aot.User"
}""");
@ -802,6 +945,7 @@ class MongoRepositoryContributorTests { @@ -802,6 +945,7 @@ class MongoRepositoryContributorTests {
}
}
],
"embedding" : [2.0002, 3.12345, 4.23456, 5.34567, 6.45678],
"_class": "example.springdata.aot.User"
}""");
@ -812,6 +956,7 @@ class MongoRepositoryContributorTests { @@ -812,6 +956,7 @@ class MongoRepositoryContributorTests {
"lastSeen" : {
"$date": "2025-01-01T00:00:00.000Z"
},
"embedding" : [3.0003, 4.12345, 5.23456, 6.34567, 7.45678],
"_class": "example.springdata.aot.User"
}""");
@ -834,7 +979,8 @@ class MongoRepositoryContributorTests { @@ -834,7 +979,8 @@ class MongoRepositoryContributorTests {
"$date": "2025-01-15T13:53:33.855Z"
}
}
]
],
"embedding" : [4.0004, 5.12345, 6.23456, 7.34567, 8.45678]
}""");
Document vader = Document.parse("""
@ -857,7 +1003,8 @@ class MongoRepositoryContributorTests { @@ -857,7 +1003,8 @@ class MongoRepositoryContributorTests {
"$date": "2025-01-15T13:46:33.855Z"
}
}
]
],
"embedding" : [5.0005, 6.12345, 7.23456, 8.34567, 9.45678]
}""");
Document kylo = Document.parse("""
@ -865,7 +1012,8 @@ class MongoRepositoryContributorTests { @@ -865,7 +1012,8 @@ class MongoRepositoryContributorTests {
"_id": "id-7",
"username": "kylo",
"first_name": "Ben",
"last_name": "Solo"
"last_name": "Solo",
"embedding" : [6.0006, 7.12345, 8.23456, 9.34567, 10.45678]
}
""");

144
spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/QueryMethodContributionUnitTests.java

@ -29,8 +29,13 @@ import javax.lang.model.element.Modifier; @@ -29,8 +29,13 @@ import javax.lang.model.element.Modifier;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.springframework.data.domain.Limit;
import org.springframework.data.domain.Pageable;
import org.springframework.data.domain.Range;
import org.springframework.data.domain.Score;
import org.springframework.data.domain.SearchResults;
import org.springframework.data.domain.Similarity;
import org.springframework.data.domain.Vector;
import org.springframework.data.geo.Box;
import org.springframework.data.geo.Circle;
import org.springframework.data.geo.Distance;
@ -43,6 +48,7 @@ import org.springframework.data.mongodb.core.geo.GeoJsonPolygon; @@ -43,6 +48,7 @@ import org.springframework.data.mongodb.core.geo.GeoJsonPolygon;
import org.springframework.data.mongodb.core.geo.Sphere;
import org.springframework.data.mongodb.repository.Hint;
import org.springframework.data.mongodb.repository.ReadPreference;
import org.springframework.data.mongodb.repository.VectorSearch;
import org.springframework.data.repository.Repository;
import org.springframework.data.repository.aot.generate.AotQueryMethodGenerationContext;
import org.springframework.data.repository.aot.generate.AotRepositoryFragmentMetadata;
@ -65,7 +71,7 @@ public class QueryMethodContributionUnitTests { @@ -65,7 +71,7 @@ public class QueryMethodContributionUnitTests {
assertThat(methodSpec.toString()) //
.contains("{'location.coordinates':{'$near':?0}}") //
.contains("Object[]{ location }") //
.contains("arguments(location)") //
.contains("return finder.matching(filterQuery).all()");
}
@ -124,7 +130,7 @@ public class QueryMethodContributionUnitTests { @@ -124,7 +130,7 @@ public class QueryMethodContributionUnitTests {
assertThat(methodSpec.toString()) //
.contains("{'location.coordinates':{'$geoWithin':{'$geometry':?0}}") //
.contains("Object[]{ polygon }") //
.contains("arguments(polygon)") //
.contains("return finder.matching(filterQuery).all()");
}
@ -182,7 +188,7 @@ public class QueryMethodContributionUnitTests { @@ -182,7 +188,7 @@ public class QueryMethodContributionUnitTests {
assertThat(methodSpec.toString()) //
.contains("NearQuery.near(point)") //
.contains("nearQuery.maxDistance(maxDistance).in(maxDistance.getMetric())") //
.contains("filterQuery = createQuery(\"{'lastname':?0}\", new java.lang.Object[]{ lastname })") //
.contains("filterQuery = createQuery(\"{'lastname':?0}\", arguments(lastname))") //
.contains("nearQuery.query(filterQuery)") //
.contains(".near(nearQuery)") //
.contains("return nearFinder.all()");
@ -194,8 +200,7 @@ public class QueryMethodContributionUnitTests { @@ -194,8 +200,7 @@ public class QueryMethodContributionUnitTests {
MethodSpec methodSpec = codeOf(UserRepository.class, "findWithExpressionUsingParameterIndex", String.class);
assertThat(methodSpec.toString()) //
.contains("createQuery(\"{ firstname : ?#{[0]} }\"") //
.contains("Map.of(\"firstname\", firstname)");
.contains("createQuery(\"{ firstname : ?#{[0]} }\", argumentMap(\"firstname\", firstname))");
}
@Test // GH-5006
@ -204,8 +209,7 @@ public class QueryMethodContributionUnitTests { @@ -204,8 +209,7 @@ public class QueryMethodContributionUnitTests {
MethodSpec methodSpec = codeOf(UserRepository.class, "findWithExpressionUsingParameterName", String.class);
assertThat(methodSpec.toString()) //
.contains("createQuery(\"{ firstname : :#{#firstname} }\"") //
.contains("Map.of(\"firstname\", firstname)");
.contains("createQuery(\"{ firstname : :#{#firstname} }\", argumentMap(\"firstname\", firstname))");
}
@Test // GH-4939
@ -214,8 +218,7 @@ public class QueryMethodContributionUnitTests { @@ -214,8 +218,7 @@ public class QueryMethodContributionUnitTests {
MethodSpec methodSpec = codeOf(UserRepository.class, "findByFirstnameRegex", Pattern.class);
assertThat(methodSpec.toString()) //
.contains("createQuery(\"{'firstname':{'$regex':?0}}\"") //
.contains("Object[]{ pattern }");
.contains("createQuery(\"{'firstname':{'$regex':?0}}\", arguments(pattern))");
}
@Test // GH-4939
@ -233,7 +236,7 @@ public class QueryMethodContributionUnitTests { @@ -233,7 +236,7 @@ public class QueryMethodContributionUnitTests {
MethodSpec methodSpec = codeOf(UserRepoWithMeta.class, "findByFirstname", String.class);
assertThat(methodSpec.toString()) //
.containsPattern(".*\\.collation\\(.*Collation\\.parse\\(\"en_US\"\\)\\)");
.containsSubsequence(".collation(", "Collation.parse(\"en_US\"))");
}
@Test // GH-4939
@ -243,7 +246,122 @@ public class QueryMethodContributionUnitTests { @@ -243,7 +246,122 @@ public class QueryMethodContributionUnitTests {
assertThat(methodSpec.toString()) //
.containsIgnoringWhitespaces(
"collationOf(evaluate(\"?#{[1]}\", java.util.Map.of(\"firstname\", firstname, \"locale\", locale)))");
"collationOf(evaluate(\"?#{[1]}\", argumentMap(\"firstname\", firstname, \"locale\", locale)))");
}
@Test
void rendersVectorSearchFilterFromAnnotatedQuery() throws NoSuchMethodException {
MethodSpec methodSpec = codeOf(UserRepository.class, "annotatedVectorSearch", String.class, Vector.class,
Score.class, Limit.class);
assertThat(methodSpec.toString()) //
.containsSubsequence("$vectorSearch =",
"Aggregation.vectorSearch(\"embedding.vector_cos\").path(\"embedding\").vector(vector).limit(limit);")
.contains("filter = createQuery(\"{lastname: ?0}\", arguments(lastname, distance))")
.contains("$vectorSearch.filter(filter.getQueryObject())");
}
@Test
void rendersVectorSearchNumCandidatesExpression() throws NoSuchMethodException {
MethodSpec methodSpec = codeOf(UserRepository.class, "annotatedVectorSearch", String.class, Vector.class,
Score.class, Limit.class);
assertThat(methodSpec.toString()) //
.containsSubsequence("$vectorSearch.numCandidates",
"evaluate(\"#{10+10}\", argumentMap(\"lastname\", lastname, \"distance\", distance)))");
}
@Test
void rendersVectorSearchScoringFunctionFromScore() throws NoSuchMethodException {
MethodSpec methodSpec = codeOf(UserRepository.class, "annotatedVectorSearch", String.class, Vector.class,
Score.class, Limit.class);
assertThat(methodSpec.toString()) //
.contains("ScoringFunction scoringFunction = distance.getFunction()");
}
@Test
void rendersVectorSearchSearchTypeFromAnnotation() throws NoSuchMethodException {
MethodSpec methodSpec = codeOf(UserRepository.class, "annotatedVectorSearch", String.class, Vector.class,
Score.class, Limit.class);
assertThat(methodSpec.toString()) //
.containsSubsequence("$vectorSearch.searchType(", "VectorSearchOperation.SearchType.ANN)");
}
@Test
void rendersVectorSearchQueryFromMethodName() throws NoSuchMethodException {
MethodSpec methodSpec = codeOf(UserRepository.class, "searchCosineByLastnameAndEmbeddingNear", String.class,
Vector.class, Score.class, Limit.class);
assertThat(methodSpec.toString()) //
.contains("filter = createQuery(\"{'lastname':?0}\", arguments(lastname, similarity))");
}
@Test
void rendersVectorSearchNumCandidatesFromLimitIfNotExplicitlyDefined() throws NoSuchMethodException {
MethodSpec methodSpec = codeOf(UserRepository.class, "searchCosineByLastnameAndEmbeddingNear", String.class,
Vector.class, Score.class, Limit.class);
assertThat(methodSpec.toString()) //
.contains("$vectorSearch.numCandidates(limit.max() * 20)");
}
@Test
void rendersVectorSearchLimitFromAnnotation() throws NoSuchMethodException {
MethodSpec methodSpec = codeOf(UserRepository.class, "searchByLastnameAndEmbeddingWithin", String.class,
Vector.class, Range.class);
assertThat(methodSpec.toString()) //
.contains("Aggregation.vectorSearch(\"embedding.vector_cos\").path(\"embedding\").vector(vector).limit(10)")
.contains("$vectorSearch.numCandidates(10 * 20)");
}
@Test
void rendersVectorSearchLimitFromExpression() throws NoSuchMethodException {
MethodSpec methodSpec = codeOf(UserRepoWithMeta.class,
"searchWithLimitAsExpressionByLastnameAndEmbeddingWithinOrderByFirstname", String.class, Vector.class,
Range.class);
assertThat(methodSpec.toString()) //
.containsSubsequence(
"Aggregation.vectorSearch(\"embedding.vector_cos\").path(\"embedding\").vector(vector).limit(",
"evaluate(\"#{5+5}\", argumentMap(\"lastname\", lastname, \"distance\", distance)")
.containsSubsequence("$vectorSearch.numCandidates(",
"evaluate(\"#{5+5}\", argumentMap(\"lastname\", lastname, \"distance\", distance))) * 20)");
}
@Test
void rendersVectorSearchOrderByScoreAsDefault() throws NoSuchMethodException {
MethodSpec methodSpec = codeOf(UserRepository.class, "searchCosineByLastnameAndEmbeddingNear", String.class,
Vector.class, Score.class, Limit.class);
assertThat(methodSpec.toString()) //
.contains("$vectorSearch.withSearchScore(\"__score__\")")
.containsSubsequence("$sort = ", "Aggregation.sort(", "DESC, \"__score__\")")
.containsSubsequence("AggregationPipeline(", "List.of($vectorSearch, $sort))");
}
@Test
void rendersVectorSearchOrderByWithScoreLast() throws NoSuchMethodException {
MethodSpec methodSpec = codeOf(UserRepository.class, "searchByLastnameAndEmbeddingWithinOrderByFirstname",
String.class, Vector.class, Range.class);
assertThat(methodSpec.toString()) //
.containsSubsequence("AggregationOperation $sort = (_ctx) -> {", //
"_mappedSort = _ctx.getMappedObject(", //
"Document.parse(\"{'firstname':{'$numberInt':'1'}}\")", //
"Document(\"$sort\", _mappedSort.append(\"__score__\", -1))");
}
private static MethodSpec codeOf(Class<?> repository, String methodName, Class<?>... args)
@ -287,5 +405,9 @@ public class QueryMethodContributionUnitTests { @@ -287,5 +405,9 @@ public class QueryMethodContributionUnitTests {
@ReadPreference("NEAREST")
GeoResults<User> findByLocationCoordinatesNear(Point point, Distance maxDistance);
@VectorSearch(indexName = "embedding.vector_cos", limit = "#{5+5}")
SearchResults<User> searchWithLimitAsExpressionByLastnameAndEmbeddingWithinOrderByFirstname(String lastname,
Vector vector, Range<Similarity> distance);
}
}

2
spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/EnableIfVectorSearchAvailable.java

@ -43,6 +43,8 @@ import org.junit.jupiter.api.extension.ExtendWith; @@ -43,6 +43,8 @@ import org.junit.jupiter.api.extension.ExtendWith;
@ExtendWith(MongoServerCondition.class)
public @interface EnableIfVectorSearchAvailable {
String database() default "";
/**
* @return the name of the collection used to run the {@literal $listSearchIndexes} aggregation.
*/

5
spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoServerCondition.java

@ -117,12 +117,15 @@ class MongoServerCondition implements ExecutionCondition { @@ -117,12 +117,15 @@ class MongoServerCondition implements ExecutionCondition {
? vectorSearchAvailable.collectionName()
: MongoCollectionUtils.getPreferredCollectionName(vectorSearchAvailable.collection());
String databaseName = StringUtils.hasText(vectorSearchAvailable.database())
? vectorSearchAvailable.database() : null;
return context.getStore(NAMESPACE).getOrComputeIfAbsent("search-index-%s-available".formatted(collectionName),
(key) -> {
try {
doWithClient(client -> {
Awaitility.await().atMost(Duration.ofSeconds(60)).pollInterval(Duration.ofMillis(200)).until(() -> {
return MongoTestUtils.isSearchIndexReady(client, null, collectionName);
return MongoTestUtils.isSearchIndexReady(client, databaseName, collectionName);
});
return "done waiting for search index";
});

39
spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoTestTemplate.java

@ -233,4 +233,43 @@ public class MongoTestTemplate extends MongoTemplate { @@ -233,4 +233,43 @@ public class MongoTestTemplate extends MongoTemplate {
public void awaitNoSearchIndexAvailable(Class<?> type, Duration timeout) {
awaitNoSearchIndexAvailable(getCollectionName(type), timeout);
}
public void tryToDropSearchIndexes(Class<?> type, Duration duration) {
tryToDropSearchIndexes(getCollectionName(type), duration);
}
public void tryToDropSearchIndexes(String collectionName, Duration timeout) {
Awaitility.await().atMost(timeout).pollInterval(Duration.ofMillis(200)).until(() -> {
try {
this.execute(collectionName, coll -> {
ArrayList<Document> indexDocuments = coll.aggregate(List.of(Document.parse("{'$listSearchIndexes': { } }")))
.into(new ArrayList<>());
indexDocuments.forEach(indexDocument -> {
if (indexDocument.containsKey("name")) {
boolean toBeDeleted = true;
if (indexDocument.containsKey("status")) {
String status = indexDocument.getString("status");
if (status.equals("DELETING") || status.equals("DOES_NOT_EXIST")) {
toBeDeleted = false;
}
}
if (toBeDeleted) {
coll.dropSearchIndex(indexDocument.getString("name"));
}
}
});
return "done with that";
});
} catch (Exception e) {
return false;
}
return true;
});
awaitNoSearchIndexAvailable(collectionName, timeout);
}
}

36
spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoTestUtils.java

@ -15,8 +15,6 @@ @@ -15,8 +15,6 @@
*/
package org.springframework.data.mongodb.test.util;
import org.jspecify.annotations.Nullable;
import org.springframework.util.StringUtils;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;
import reactor.util.retry.Retry;
@ -26,18 +24,22 @@ import java.util.List; @@ -26,18 +24,22 @@ import java.util.List;
import java.util.concurrent.TimeUnit;
import org.bson.Document;
import org.jspecify.annotations.Nullable;
import org.springframework.core.env.Environment;
import org.springframework.core.env.StandardEnvironment;
import org.springframework.data.mongodb.SpringDataMongoDB;
import org.springframework.data.util.Version;
import org.springframework.util.ObjectUtils;
import org.springframework.util.StringUtils;
import com.mongodb.ConnectionString;
import com.mongodb.MongoClientSettings;
import com.mongodb.ReadPreference;
import com.mongodb.WriteConcern;
import com.mongodb.client.AggregateIterable;
import com.mongodb.client.MongoClient;
import com.mongodb.client.MongoCollection;
import com.mongodb.client.MongoCursor;
import com.mongodb.client.MongoDatabase;
import com.mongodb.reactivestreams.client.MongoClients;
@ -292,15 +294,37 @@ public class MongoTestUtils { @@ -292,15 +294,37 @@ public class MongoTestUtils {
}
public static boolean isSearchIndexReady(MongoClient client, @Nullable String database, String collectionName) {
return isSearchIndexReady(null, client, database, collectionName);
}
public static boolean isSearchIndexReady(@Nullable String indexName, MongoClient client, @Nullable String database,
String collectionName) {
try {
MongoCollection<Document> collection = client.getDatabase(StringUtils.hasText(database) ? database : "test").getCollection(collectionName);
collection.aggregate(List.of(new Document("$listSearchIndexes", new Document())));
MongoCollection<Document> collection = client.getDatabase(StringUtils.hasText(database) ? database : "test")
.getCollection(collectionName);
Document filter = StringUtils.hasText(indexName) ? new Document("name", indexName) : new Document();
AggregateIterable<Document> aggregate = collection.aggregate(List.of(new Document("$listSearchIndexes", filter)));
try (MongoCursor<Document> cursor = aggregate.cursor()) {
if (filter.isEmpty()) {
return true;
}
while (cursor.hasNext()) {
Document doc = cursor.next();
if (doc.getString("name").equals(indexName)) {
return doc.getString("status").equals("READY");
}
}
}
return false;
} catch (Exception e) {
return false;
}
return true;
}
public static Duration getTimeout() {

Loading…
Cancel
Save