diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/AotQueryCreator.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/AotQueryCreator.java index 11d6e8bdd..219f90348 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/AotQueryCreator.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/AotQueryCreator.java @@ -15,6 +15,7 @@ */ package org.springframework.data.mongodb.repository.aot; +import java.lang.reflect.Method; import java.util.ArrayList; import java.util.Iterator; import java.util.List; @@ -46,6 +47,7 @@ import org.springframework.data.mongodb.core.query.CriteriaDefinition.Placeholde import org.springframework.data.mongodb.core.query.Query; import org.springframework.data.mongodb.core.query.TextCriteria; import org.springframework.data.mongodb.core.query.UpdateDefinition; +import org.springframework.data.mongodb.repository.VectorSearch; import org.springframework.data.mongodb.repository.query.ConvertingParameterAccessor; import org.springframework.data.mongodb.repository.query.MongoParameterAccessor; import org.springframework.data.mongodb.repository.query.MongoQueryCreator; @@ -79,14 +81,16 @@ class AotQueryCreator { } @SuppressWarnings("NullAway") - StringQuery createQuery(PartTree partTree, QueryMethod queryMethod) { - + StringQuery createQuery(PartTree partTree, QueryMethod queryMethod, Method source) { boolean geoNear = queryMethod instanceof MongoQueryMethod mqm ? mqm.isGeoNearQuery() : false; + boolean searchQuery = queryMethod instanceof MongoQueryMethod mqm + ? mqm.isSearchQuery() || source.isAnnotationPresent(VectorSearch.class) + : source.isAnnotationPresent(VectorSearch.class); Query query = new MongoQueryCreator(partTree, - new PlaceholderConvertingParameterAccessor(new PlaceholderParameterAccessor(queryMethod)), mappingContext, geoNear, queryMethod.isSearchQuery()) - .createQuery(); + new PlaceholderConvertingParameterAccessor(new PlaceholderParameterAccessor(queryMethod)), mappingContext, + geoNear, searchQuery).createQuery(); if (partTree.isLimiting()) { query.limit(partTree.getMaxResults()); @@ -141,8 +145,7 @@ class AotQueryCreator { for (Parameter parameter : parameters.toList()) { if (ClassUtils.isAssignable(GeoJson.class, parameter.getType())) { placeholders.add(parameter.getIndex(), new GeoJsonPlaceholder(parameter.getIndex(), "")); - } - else if (ClassUtils.isAssignable(Point.class, parameter.getType())) { + } else if (ClassUtils.isAssignable(Point.class, parameter.getType())) { placeholders.add(parameter.getIndex(), new PointPlaceholder(parameter.getIndex())); } else if (ClassUtils.isAssignable(Circle.class, parameter.getType())) { placeholders.add(parameter.getIndex(), new CirclePlaceholder(parameter.getIndex())); @@ -152,8 +155,7 @@ class AotQueryCreator { placeholders.add(parameter.getIndex(), new SpherePlaceholder(parameter.getIndex())); } else if (ClassUtils.isAssignable(Polygon.class, parameter.getType())) { placeholders.add(parameter.getIndex(), new PolygonPlaceholder(parameter.getIndex())); - } - else { + } else { placeholders.add(parameter.getIndex(), Placeholder.indexed(parameter.getIndex())); } } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoAotRepositoryFragmentSupport.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoAotRepositoryFragmentSupport.java index 84de3bb83..86b3217b0 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoAotRepositoryFragmentSupport.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoAotRepositoryFragmentSupport.java @@ -16,12 +16,17 @@ package org.springframework.data.mongodb.repository.aot; import java.util.ArrayList; +import java.util.LinkedHashMap; import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.function.Consumer; import org.bson.Document; import org.jspecify.annotations.Nullable; +import org.springframework.data.domain.Range; +import org.springframework.data.domain.Score; +import org.springframework.data.domain.ScoringFunction; import org.springframework.data.expression.ValueEvaluationContext; import org.springframework.data.expression.ValueExpression; import org.springframework.data.mapping.model.ValueExpressionEvaluator; @@ -33,6 +38,7 @@ import org.springframework.data.mongodb.core.convert.MongoConverter; import org.springframework.data.mongodb.core.mapping.FieldName; import org.springframework.data.mongodb.core.query.BasicQuery; import org.springframework.data.mongodb.core.query.Collation; +import org.springframework.data.mongodb.core.query.Criteria; import org.springframework.data.mongodb.repository.query.MongoParameters; import org.springframework.data.mongodb.util.json.ParameterBindingContext; import org.springframework.data.mongodb.util.json.ParameterBindingDocumentCodec; @@ -42,7 +48,9 @@ import org.springframework.data.repository.core.RepositoryMetadata; import org.springframework.data.repository.core.support.RepositoryFactoryBeanSupport; import org.springframework.data.repository.query.ValueExpressionDelegate; import org.springframework.expression.EvaluationContext; +import org.springframework.util.Assert; import org.springframework.util.ClassUtils; +import org.springframework.util.CollectionUtils; import org.springframework.util.ObjectUtils; /** @@ -108,7 +116,27 @@ public class MongoAotRepositoryFragmentSupport { return new ParameterBindingDocumentCodec().decode(source, bindingContext); } - protected Object evaluate(String source, Map parameters) { + protected Object[] arguments(Object... arguments) { + return arguments; + } + + protected Map argumentMap(Object... parameters) { + + Assert.state(parameters.length % 2 == 0, "even number of args required"); + + LinkedHashMap argumentMap = CollectionUtils.newLinkedHashMap(parameters.length / 2); + for (int i = 0; i < parameters.length; i += 2) { + + if (!(parameters[i] instanceof String key)) { + throw new IllegalArgumentException("key must be a String"); + } + argumentMap.put(key, parameters[i + 1]); + } + + return argumentMap; + } + + protected @Nullable Object evaluate(String source, Map parameters) { ValueEvaluationContext valueEvaluationContext = this.valueExpressionDelegate.getEvaluationContextAccessor() .create(new NoMongoParameters()).getEvaluationContext(parameters.values()); @@ -120,9 +148,63 @@ public class MongoAotRepositoryFragmentSupport { return parse.evaluate(valueEvaluationContext); } + protected Consumer scoreBetween(Range.Bound lower, Range.Bound upper) { + + return criteria -> { + if (lower.isBounded()) { + double value = lower.getValue().get().getValue(); + if (lower.isInclusive()) { + criteria.gte(value); + } else { + criteria.gt(value); + } + } + + if (upper.isBounded()) { + + double value = upper.getValue().get().getValue(); + if (upper.isInclusive()) { + criteria.lte(value); + } else { + criteria.lt(value); + } + } + + }; + } + + protected ScoringFunction scoringFunction(Range scoreRange) { + + if (scoreRange != null) { + if (scoreRange.getUpperBound().isBounded()) { + return scoreRange.getUpperBound().getValue().get().getFunction(); + } + + if (scoreRange.getLowerBound().isBounded()) { + return scoreRange.getLowerBound().getValue().get().getFunction(); + } + } + + return ScoringFunction.unspecified(); + } + + // Range scoreRange = accessor.getScoreRange(); + // + // if (scoreRange != null) { + // if (scoreRange.getUpperBound().isBounded()) { + // return scoreRange.getUpperBound().getValue().get().getFunction(); + // } + // + // if (scoreRange.getLowerBound().isBounded()) { + // return scoreRange.getLowerBound().getValue().get().getFunction(); + // } + // } + // + // return ScoringFunction.unspecified(); + protected Collation collationOf(@Nullable Object source) { - if(source == null) { + if (source == null) { return Collation.simple(); } if (source instanceof String) { diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java index 388199443..4125139bd 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java @@ -37,6 +37,7 @@ import org.springframework.data.mongodb.repository.query.MongoQueryMethod; import org.springframework.data.repository.aot.generate.AotQueryMethodGenerationContext; import org.springframework.javapoet.CodeBlock; import org.springframework.javapoet.CodeBlock.Builder; +import org.springframework.util.NumberUtils; import org.springframework.util.StringUtils; /** @@ -49,6 +50,7 @@ class MongoCodeBlocks { private static final Pattern PARAMETER_BINDING_PATTERN = Pattern.compile("\\?(\\d+)"); private static final Pattern EXPRESSION_BINDING_PATTERN = Pattern.compile("[\\?:][#$]\\{.*\\}"); + private static final Pattern VALUE_EXPRESSION_PATTERN = Pattern.compile("^#\\{.*}$"); /** * Builder for generating query parsing {@link CodeBlock}. @@ -179,7 +181,7 @@ class MongoCodeBlocks { } else { builder.add("$T $L = bindParameters($S, ", Document.class, variableName, source); if (containsNamedPlaceholder(source)) { - renderArgumentMap(arguments); + builder.add(renderArgumentMap(arguments)); } else { builder.add(renderArgumentArray(arguments)); } @@ -191,7 +193,7 @@ class MongoCodeBlocks { static CodeBlock renderArgumentMap(Map arguments) { Builder builder = CodeBlock.builder(); - builder.add("$T.of(", Map.class); + builder.add("argumentMap("); Iterator> iterator = arguments.entrySet().iterator(); while (iterator.hasNext()) { Entry next = iterator.next(); @@ -208,24 +210,41 @@ class MongoCodeBlocks { static CodeBlock renderArgumentArray(Map arguments) { Builder builder = CodeBlock.builder(); - builder.add("new $T[]{ ", Object.class); + builder.add("arguments("); Iterator iterator = arguments.values().iterator(); while (iterator.hasNext()) { builder.add(iterator.next()); if (iterator.hasNext()) { builder.add(", "); - } else { - builder.add(" "); } } - builder.add("}"); + builder.add(")"); return builder.build(); } + static CodeBlock evaluateNumberPotentially(String value, Class targetType, + Map arguments) { + try { + Number number = NumberUtils.parseNumber(value, targetType); + return CodeBlock.of("$L", number); + } catch (IllegalArgumentException e) { + + Builder builder = CodeBlock.builder(); + builder.add("($T) evaluate($S, ", targetType, value); + builder.add(MongoCodeBlocks.renderArgumentMap(arguments)); + builder.add(")"); + return builder.build(); + } + } + static boolean containsPlaceholder(String source) { return containsIndexedPlaceholder(source) || containsNamedPlaceholder(source); } + static boolean containsExpression(String source) { + return VALUE_EXPRESSION_PATTERN.matcher(source).find(); + } + static boolean containsNamedPlaceholder(String source) { return EXPRESSION_BINDING_PATTERN.matcher(source).find(); } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java index 95b1108f9..524c5e8f2 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java @@ -37,6 +37,7 @@ import org.springframework.data.mongodb.core.aggregation.AggregationUpdate; import org.springframework.data.mongodb.core.mapping.MongoMappingContext; import org.springframework.data.mongodb.repository.Query; import org.springframework.data.mongodb.repository.Update; +import org.springframework.data.mongodb.repository.VectorSearch; import org.springframework.data.mongodb.repository.query.MongoQueryMethod; import org.springframework.data.repository.aot.generate.AotRepositoryClassBuilder; import org.springframework.data.repository.aot.generate.AotRepositoryConstructorBuilder; @@ -107,7 +108,11 @@ public class MongoRepositoryContributor extends RepositoryContributor { } QueryInteraction query = createStringQuery(getRepositoryInformation(), queryMethod, - AnnotatedElementUtils.findMergedAnnotation(method, Query.class)); + AnnotatedElementUtils.findMergedAnnotation(method, Query.class), method); + + if (queryMethod.isSearchQuery() || method.isAnnotationPresent(VectorSearch.class)) { + return searchMethodContributor(queryMethod, new SearchInteraction(query.getQuery())); + } if (queryMethod.isGeoNearQuery() || (queryMethod.getParameters().getMaxDistanceIndex() != -1 && queryMethod.getReturnType().isCollectionLike())) { @@ -126,8 +131,8 @@ public class MongoRepositoryContributor extends RepositoryContributor { UpdateInteraction update = new UpdateInteraction(query, null, updateIndex); return updateMethodContributor(queryMethod, update); - } else { + Update updateSource = queryMethod.getUpdateSource(); if (StringUtils.hasText(updateSource.value())) { UpdateInteraction update = new UpdateInteraction(query, new StringUpdate(updateSource.value()), null); @@ -146,7 +151,7 @@ public class MongoRepositoryContributor extends RepositoryContributor { @SuppressWarnings("NullAway") private QueryInteraction createStringQuery(RepositoryInformation repositoryInformation, MongoQueryMethod queryMethod, - @Nullable Query queryAnnotation) { + @Nullable Query queryAnnotation, Method source) { QueryInteraction query; if (queryMethod.hasAnnotatedQuery() && queryAnnotation != null) { @@ -155,8 +160,8 @@ public class MongoRepositoryContributor extends RepositoryContributor { } else { PartTree partTree = new PartTree(queryMethod.getName(), repositoryInformation.getDomainType()); - query = new QueryInteraction(queryCreator.createQuery(partTree, queryMethod), partTree.isCountProjection(), - partTree.isDelete(), partTree.isExistsProjection()); + query = new QueryInteraction(queryCreator.createQuery(partTree, queryMethod, source), + partTree.isCountProjection(), partTree.isDelete(), partTree.isExistsProjection()); } if (queryAnnotation != null && StringUtils.hasText(queryAnnotation.sort())) { @@ -172,7 +177,7 @@ public class MongoRepositoryContributor extends RepositoryContributor { private static boolean backoff(MongoQueryMethod method) { // TODO: namedQuery, Regex queries, queries accepting Shapes (e.g. within) or returning arrays. - boolean skip = method.isSearchQuery() || method.getReturnType().getType().isArray(); + boolean skip = method.getReturnType().getType().isArray(); if (skip && logger.isDebugEnabled()) { logger.debug("Skipping AOT generation for [%s]. Method is either returning an array or a geo-near, regex query" @@ -220,6 +225,21 @@ public class MongoRepositoryContributor extends RepositoryContributor { }); } + static MethodContributor searchMethodContributor(MongoQueryMethod queryMethod, + SearchInteraction interaction) { + return MethodContributor.forQueryMethod(queryMethod).withMetadata(interaction).contribute(context -> { + + CodeBlock.Builder builder = CodeBlock.builder(); + + String variableName = "search"; + + builder.add(new VectorSearchBocks.VectorSearchQueryCodeBlockBuilder(context, queryMethod) + .usingVariableName(variableName).withFilter(interaction.getFilter()).build()); + + return builder.build(); + }); + } + static MethodContributor updateMethodContributor(MongoQueryMethod queryMethod, UpdateInteraction update) { diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/QueryBlocks.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/QueryBlocks.java index e01462547..7ad0c25b1 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/QueryBlocks.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/QueryBlocks.java @@ -211,8 +211,7 @@ class QueryBlocks { Builder builder = CodeBlock.builder(); - builder.add("\n"); - builder.add(renderExpressionToQuery(source.getQuery().getQueryString(), queryVariableName)); + builder.add(buildJustTheQuery()); if (StringUtils.hasText(source.getQuery().getFieldsString())) { @@ -289,6 +288,14 @@ class QueryBlocks { return builder.build(); } + CodeBlock buildJustTheQuery() { + + Builder builder = CodeBlock.builder(); + builder.add("\n"); + builder.add(renderExpressionToQuery(source.getQuery().getQueryString(), queryVariableName)); + return builder.build(); + } + private CodeBlock renderExpressionToQuery(@Nullable String source, String variableName) { Builder builder = CodeBlock.builder(); diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/SearchInteraction.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/SearchInteraction.java new file mode 100644 index 000000000..a94ff1082 --- /dev/null +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/SearchInteraction.java @@ -0,0 +1,48 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.mongodb.repository.aot; + +import java.util.Map; + +import org.jspecify.annotations.Nullable; +import org.springframework.data.repository.aot.generate.QueryMetadata; + +/** + * @author Christoph Strobl + */ +public class SearchInteraction extends MongoInteraction implements QueryMetadata { + + StringQuery filter; + + public SearchInteraction(StringQuery filter) { + this.filter = filter; + } + + public StringQuery getFilter() { + return filter; + } + + @Override + InteractionType getExecutionType() { + return InteractionType.AGGREGATION; + } + + @Override + public Map serialize() { + + return Map.of("FIXME", "please!"); + } +} diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/VectorSearchBocks.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/VectorSearchBocks.java new file mode 100644 index 000000000..3efdc080b --- /dev/null +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/VectorSearchBocks.java @@ -0,0 +1,211 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.mongodb.repository.aot; + +import java.lang.reflect.Field; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + +import org.bson.Document; +import org.springframework.core.annotation.MergedAnnotation; +import org.springframework.data.domain.Limit; +import org.springframework.data.domain.ScoringFunction; +import org.springframework.data.domain.Sort; +import org.springframework.data.domain.Vector; +import org.springframework.data.mongodb.core.MongoOperations; +import org.springframework.data.mongodb.core.aggregation.Aggregation; +import org.springframework.data.mongodb.core.aggregation.AggregationOperation; +import org.springframework.data.mongodb.core.aggregation.AggregationOperationContext; +import org.springframework.data.mongodb.core.aggregation.AggregationPipeline; +import org.springframework.data.mongodb.core.aggregation.VectorSearchOperation; +import org.springframework.data.mongodb.core.aggregation.VectorSearchOperation.SearchType; +import org.springframework.data.mongodb.repository.VectorSearch; +import org.springframework.data.mongodb.repository.query.MongoQueryExecution.VectorSearchExecution; +import org.springframework.data.mongodb.repository.query.MongoQueryMethod; +import org.springframework.data.repository.aot.generate.AotQueryMethodGenerationContext; +import org.springframework.data.util.TypeInformation; +import org.springframework.javapoet.CodeBlock; +import org.springframework.javapoet.CodeBlock.Builder; +import org.springframework.util.StringUtils; + +/** + * @author Christoph Strobl + * @since 5.0 + */ +class VectorSearchBocks { + + static class VectorSearchQueryCodeBlockBuilder { + + private final AotQueryMethodGenerationContext context; + private final MongoQueryMethod queryMethod; + private String searchQueryVariableName; + private StringQuery filter; + private final Map arguments; + + VectorSearchQueryCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) { + + this.context = context; + this.queryMethod = queryMethod; + this.arguments = new LinkedHashMap<>(); + context.getBindableParameterNames().forEach(it -> arguments.put(it, CodeBlock.of(it))); + } + + VectorSearchQueryCodeBlockBuilder usingVariableName(String searchQueryVariableName) { + + this.searchQueryVariableName = searchQueryVariableName; + return this; + } + + CodeBlock build() { + + Builder builder = CodeBlock.builder(); + + String vectorParameterName = context.getVectorParameterName(); + + MergedAnnotation annotation = context.getAnnotation(VectorSearch.class); + String searchPath = annotation.getString("path"); + String indexName = annotation.getString("indexName"); + String numCandidates = annotation.getString("numCandidates"); + SearchType searchType = annotation.getEnum("searchType", SearchType.class); + String limit = annotation.getString("limit"); + + if (!StringUtils.hasText(searchPath)) { // FIXME: somehow duplicate logic of AnnotatedQueryFactory + + Field[] declaredFields = context.getRepositoryInformation().getDomainType().getDeclaredFields(); + for (Field field : declaredFields) { + if (Vector.class.isAssignableFrom(field.getType())) { + searchPath = field.getName(); + break; + } + } + + } + + String vectorSearchVar = context.localVariable("$vectorSearch"); + builder.add("$T $L = $T.vectorSearch($S).path($S).vector($L)", VectorSearchOperation.class, vectorSearchVar, + Aggregation.class, indexName, searchPath, vectorParameterName); + + if (StringUtils.hasText(context.getLimitParameterName())) { + builder.add(".limit($L);\n", context.getLimitParameterName()); + } else if (filter.isLimited()) { + builder.add(".limit($L);\n", filter.getLimit()); + } else if (StringUtils.hasText(limit)) { + if (MongoCodeBlocks.containsPlaceholder(limit) || MongoCodeBlocks.containsExpression(limit)) { + builder.add(".limit("); + builder.add(MongoCodeBlocks.evaluateNumberPotentially(limit, Integer.class, arguments)); + builder.add(");\n"); + } else { + builder.add(".limit($L);\n", limit); + } + } else { + builder.add(".limit($T.unlimited());\n", Limit.class); + } + + if (!searchType.equals(SearchType.DEFAULT)) { + builder.addStatement("$1L = $1L.searchType($2T.$3L)", vectorSearchVar, SearchType.class, searchType.name()); + } + + if (StringUtils.hasText(numCandidates)) { + builder.add("$1L = $1L.numCandidates(", vectorSearchVar); + builder.add(MongoCodeBlocks.evaluateNumberPotentially(numCandidates, Integer.class, arguments)); + builder.add(");\n"); + } else if (searchType == VectorSearchOperation.SearchType.ANN + || searchType == VectorSearchOperation.SearchType.DEFAULT) { + + builder.add( + "// MongoDB: We recommend that you specify a number at least 20 times higher than the number of documents to return\n"); + if (StringUtils.hasText(context.getLimitParameterName())) { + builder.addStatement("$1L = $1L.numCandidates($2L.max() * 20)", vectorSearchVar, + context.getLimitParameterName()); + } else if (StringUtils.hasText(limit)) { + if (MongoCodeBlocks.containsPlaceholder(limit) || MongoCodeBlocks.containsExpression(limit)) { + + builder.add("$1L = $1L.numCandidates((", vectorSearchVar); + builder.add(MongoCodeBlocks.evaluateNumberPotentially(limit, Integer.class, arguments)); + builder.add(") * 20);\n"); + } else { + builder.addStatement("$1L = $1L.numCandidates($2L * 20)", vectorSearchVar, limit); + } + } else { + builder.addStatement("$1L = $1L.numCandidates($2L)", vectorSearchVar, filter.getLimit() * 20); + } + } + + builder.addStatement("$1L = $1L.withSearchScore(\"__score__\")", vectorSearchVar); + if (StringUtils.hasText(context.getScoreParameterName())) { + + String scoreCriteriaVar = context.localVariable("criteria"); + builder.addStatement("$1L = $1L.withFilterBySore($2L -> { $2L.gt($3L.getValue()); })", vectorSearchVar, + scoreCriteriaVar, context.getScoreParameterName()); + } else if (StringUtils.hasText(context.getScoreRangeParameterName())) { + builder.addStatement("$1L = $1L.withFilterBySore(scoreBetween($2L.getLowerBound(), $2L.getUpperBound()))", + vectorSearchVar, context.getScoreRangeParameterName()); + } + + if (StringUtils.hasText(filter.getQueryString())) { + + String filterVar = context.localVariable("filter"); + builder.add(MongoCodeBlocks.queryBlockBuilder(context, queryMethod).usingQueryVariableName("filter") + .filter(new QueryInteraction(this.filter, false, false, false)).buildJustTheQuery()); + builder.addStatement("$1L = $1L.filter($2L.getQueryObject())", vectorSearchVar, filterVar); + builder.add("\n"); + } + + + String sortStageVar = context.localVariable("$sort"); + if(filter.isSorted()) { + + builder.add("$T $L = (_ctx) -> {\n", AggregationOperation.class, sortStageVar); + builder.indent(); + + builder.addStatement("$1T _mappedSort = _ctx.getMappedObject($1T.parse($2S), $3T.class)", Document.class, filter.getSortString(), context.getActualReturnType().getType()); + builder.addStatement("return new $T($S, _mappedSort.append(\"__score__\", -1))", Document.class, "$sort"); + builder.unindent(); + builder.add("};"); + + } else { + builder.addStatement("var $L = $T.sort($T.Direction.DESC, $S)", sortStageVar, Aggregation.class, Sort.class, "__score__"); + } + builder.add("\n"); + + builder.addStatement("$1T $2L = new $1T($3T.of($4L, $5L))", AggregationPipeline.class, searchQueryVariableName, + List.class, vectorSearchVar, sortStageVar); + + String scoringFunctionVar = context.localVariable("scoringFunction"); + builder.add("$1T $2L = ", ScoringFunction.class, scoringFunctionVar); + if (StringUtils.hasText(context.getScoreParameterName())) { + builder.add("$L.getFunction();\n", context.getScoreParameterName()); + } else if (StringUtils.hasText(context.getScoreRangeParameterName())) { + builder.add("scoringFunction($L);\n", context.getScoreRangeParameterName()); + } else { + builder.add("$1T.unspecified();\n", ScoringFunction.class); + } + + builder.addStatement( + "return ($5T) new $1T($2L, $3T.class, $2L.getCollectionName($3T.class), $4T.of($5T.class), $6L, $7L).execute(null)", + VectorSearchExecution.class, context.fieldNameOf(MongoOperations.class), + context.getRepositoryInformation().getDomainType(), TypeInformation.class, + queryMethod.getReturnType().getType(), searchQueryVariableName, scoringFunctionVar); + return builder.build(); + } + + public VectorSearchQueryCodeBlockBuilder withFilter(StringQuery filter) { + this.filter = filter; + return this; + } + } +} diff --git a/spring-data-mongodb/src/test/java/example/aot/User.java b/spring-data-mongodb/src/test/java/example/aot/User.java index 25514a518..dfe3ec355 100644 --- a/spring-data-mongodb/src/test/java/example/aot/User.java +++ b/spring-data-mongodb/src/test/java/example/aot/User.java @@ -17,6 +17,7 @@ package example.aot; import java.time.Instant; +import org.springframework.data.domain.Vector; import org.springframework.data.mongodb.core.mapping.Field; /** @@ -38,6 +39,8 @@ public class User { Instant lastSeen; Long visits; + Vector embedding; + public String getId() { return id; } diff --git a/spring-data-mongodb/src/test/java/example/aot/UserRepository.java b/spring-data-mongodb/src/test/java/example/aot/UserRepository.java index e3d04293c..ee1058fdc 100644 --- a/spring-data-mongodb/src/test/java/example/aot/UserRepository.java +++ b/spring-data-mongodb/src/test/java/example/aot/UserRepository.java @@ -30,9 +30,13 @@ import org.springframework.data.domain.Limit; import org.springframework.data.domain.Page; import org.springframework.data.domain.Pageable; import org.springframework.data.domain.Range; +import org.springframework.data.domain.Score; import org.springframework.data.domain.ScrollPosition; +import org.springframework.data.domain.SearchResults; +import org.springframework.data.domain.Similarity; import org.springframework.data.domain.Slice; import org.springframework.data.domain.Sort; +import org.springframework.data.domain.Vector; import org.springframework.data.domain.Window; import org.springframework.data.geo.Box; import org.springframework.data.geo.Circle; @@ -43,6 +47,7 @@ import org.springframework.data.geo.GeoResults; import org.springframework.data.geo.Point; import org.springframework.data.geo.Polygon; import org.springframework.data.mongodb.core.aggregation.AggregationResults; +import org.springframework.data.mongodb.core.aggregation.VectorSearchOperation; import org.springframework.data.mongodb.core.geo.GeoJson; import org.springframework.data.mongodb.core.geo.GeoJsonPolygon; import org.springframework.data.mongodb.core.geo.Sphere; @@ -51,6 +56,7 @@ import org.springframework.data.mongodb.repository.Hint; import org.springframework.data.mongodb.repository.Query; import org.springframework.data.mongodb.repository.ReadPreference; import org.springframework.data.mongodb.repository.Update; +import org.springframework.data.mongodb.repository.VectorSearch; import org.springframework.data.repository.CrudRepository; import org.springframework.data.repository.query.Param; @@ -291,6 +297,30 @@ public interface UserRepository extends CrudRepository { "{ '$project': { '_id' : '$last_name' } }" }, collation = "no_collation") List findAllLastnamesWithCollation(); + // Vector Search + + @VectorSearch(indexName = "embedding.vector_cos", filter = "{lastname: ?0}", numCandidates = "#{10+10}", + searchType = VectorSearchOperation.SearchType.ANN) + SearchResults annotatedVectorSearch(String lastname, Vector vector, Score distance, Limit limit); + + @VectorSearch(indexName = "embedding.vector_cos") + SearchResults searchCosineByLastnameAndEmbeddingNear(String lastname, Vector vector, Score similarity, + Limit limit); + + @VectorSearch(indexName = "embedding.vector_cos") + List searchAsListByLastnameAndEmbeddingNear(String lastname, Vector vector, Limit limit); + + @VectorSearch(indexName = "embedding.vector_cos", limit = "10") + SearchResults searchByLastnameAndEmbeddingWithin(String lastname, Vector vector, Range distance); + + @VectorSearch(indexName = "embedding.vector_cos", limit = "10") + SearchResults searchByLastnameAndEmbeddingWithinOrderByFirstname(String lastname, Vector vector, + Range distance); + + @VectorSearch(indexName = "embedding.vector_cos") + SearchResults searchTop1ByLastnameAndEmbeddingWithin(String lastname, Vector vector, + Range distance); + class UserAggregate { @Id // diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/index/VectorIndexIntegrationTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/index/VectorIndexIntegrationTests.java index c76f6bc59..118110e1e 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/index/VectorIndexIntegrationTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/index/VectorIndexIntegrationTests.java @@ -77,9 +77,8 @@ class VectorIndexIntegrationTests { @AfterEach void cleanup() { + template.tryToDropSearchIndexes(Movie.class, Duration.ofSeconds(30)); template.flush(Movie.class); - template.searchIndexOps(Movie.class).dropAllIndexes(); - template.awaitNoSearchIndexAvailable(Movie.class, Duration.ofSeconds(30)); } @ParameterizedTest // GH-4706 diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java index 93bc878ab..46846c732 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java @@ -24,12 +24,15 @@ import example.aot.UserProjection; import example.aot.UserRepository; import example.aot.UserRepository.UserAggregate; +import java.time.Duration; import java.time.Instant; import java.util.List; import java.util.Optional; import java.util.regex.Pattern; +import org.bson.BsonString; import org.bson.Document; +import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -42,9 +45,13 @@ import org.springframework.data.domain.OffsetScrollPosition; import org.springframework.data.domain.Page; import org.springframework.data.domain.PageRequest; import org.springframework.data.domain.Range; +import org.springframework.data.domain.Score; import org.springframework.data.domain.ScrollPosition; +import org.springframework.data.domain.SearchResults; +import org.springframework.data.domain.Similarity; import org.springframework.data.domain.Slice; import org.springframework.data.domain.Sort; +import org.springframework.data.domain.Vector; import org.springframework.data.domain.Window; import org.springframework.data.geo.Box; import org.springframework.data.geo.Circle; @@ -60,13 +67,22 @@ import org.springframework.data.mongodb.core.MongoTemplate; import org.springframework.data.mongodb.core.aggregation.AggregationResults; import org.springframework.data.mongodb.core.geo.GeoJsonPoint; import org.springframework.data.mongodb.core.geo.GeoJsonPolygon; -import org.springframework.data.mongodb.test.util.Client; +import org.springframework.data.mongodb.test.util.AtlasContainer; +import org.springframework.data.mongodb.test.util.EnableIfVectorSearchAvailable; import org.springframework.data.mongodb.test.util.MongoTestUtils; import org.springframework.test.context.junit.jupiter.SpringJUnitConfig; import org.springframework.util.StringUtils; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.shaded.org.awaitility.Awaitility; import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoClients; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.MongoCursor; import com.mongodb.client.model.IndexOptions; +import com.mongodb.client.model.SearchIndexModel; +import com.mongodb.client.model.SearchIndexType; /** * Integration tests for the {@link UserRepository} AOT fragment. @@ -74,13 +90,15 @@ import com.mongodb.client.model.IndexOptions; * @author Christoph Strobl * @author Mark Paluch */ - +@Testcontainers(disabledWithoutDocker = true) @SpringJUnitConfig(classes = MongoRepositoryContributorTests.MongoRepositoryContributorConfiguration.class) class MongoRepositoryContributorTests { + private static final @Container AtlasContainer atlasLocal = AtlasContainer.bestMatch(); private static final String DB_NAME = "aot-repo-tests"; + private static final String COLLECTION_NAME = "user"; - @Client static MongoClient client; + static MongoClient client; @Autowired UserRepository fragment; @Configuration @@ -97,16 +115,63 @@ class MongoRepositoryContributorTests { } @BeforeAll - static void beforeAll() { - client.getDatabase(DB_NAME).getCollection("user").createIndex(new Document("location.coordinates", "2d"), - new IndexOptions()); + static void beforeAll() throws InterruptedException { + + client = MongoClients.create(atlasLocal.getConnectionString()); + MongoCollection userCollection = client.getDatabase(DB_NAME).getCollection(COLLECTION_NAME); + userCollection.createIndex(new Document("location.coordinates", "2d"), new IndexOptions()); + userCollection.createIndex(new Document("location.coordinates", "2dsphere"), new IndexOptions()); + + Thread.sleep(250); // just wait a little or the index will be broken + } + + /** + * Create the vector search index and wait till it is queryable and actually serving data. Since this may slow down + * tests quite a bit, better call it only when needed to run certain tests. + */ + private static void initializeVectorIndex() { + + String indexName = "embedding.vector_cos"; + + Document searchIndex = new Document("fields", + List.of(new Document("type", "vector").append("path", "embedding").append("numDimensions", 5) + .append("similarity", "cosine"), new Document("type", "filter").append("path", "last_name"))); + + MongoCollection userCollection = client.getDatabase(DB_NAME).getCollection(COLLECTION_NAME); + userCollection.createSearchIndexes( + List.of(new SearchIndexModel(indexName, searchIndex, SearchIndexType.of(new BsonString("vectorSearch"))))); + + // wait for search index to be queryable + + Awaitility.await().atMost(Duration.ofSeconds(120)).pollInterval(Duration.ofMillis(200)).until(() -> { + return MongoTestUtils.isSearchIndexReady(indexName, client, DB_NAME, COLLECTION_NAME); + }); + + Document $vectorSearch = new Document("$vectorSearch", + new Document("index", indexName).append("limit", 1).append("numCandidates", 20).append("path", "embedding") + .append("queryVector", List.of(1.0, 1.12345, 2.23456, 3.34567, 4.45678))); + + // wait for search index to serve data + Awaitility.await().atLeast(Duration.ofMillis(50)).atMost(Duration.ofSeconds(120)).ignoreExceptions() + .pollInterval(Duration.ofMillis(250)).until(() -> { + try (MongoCursor cursor = userCollection.aggregate(List.of($vectorSearch)).iterator()) { + if (cursor.hasNext()) { + Document next = cursor.next(); + return true; + } + return false; + } + }); } @BeforeEach - void beforeEach() { + void beforeEach() throws InterruptedException { + initUsers(); + } + @AfterEach + void afterEach() { MongoTestUtils.flushCollection(DB_NAME, "user", client); - initUsers(); } @Test @@ -747,10 +812,86 @@ class MongoRepositoryContributorTests { assertThat(page2.hasNext()).isFalse(); } + @Test + @EnableIfVectorSearchAvailable(database = DB_NAME, collection = User.class) + void vectorSearchFromAnnotation() { + + initializeVectorIndex(); + + Vector vector = Vector.of(1.00000d, 1.12345d, 2.23456d, 3.34567d, 4.45678d); + SearchResults results = fragment.annotatedVectorSearch("Skywalker", vector, Score.of(0.99), Limit.of(10)); + + assertThat(results).hasSize(1); + } + + @Test + @EnableIfVectorSearchAvailable(database = DB_NAME, collection = User.class) + void vectorSearchWithDerivedQuery() { + + initializeVectorIndex(); + + Vector vector = Vector.of(1.00000d, 1.12345d, 2.23456d, 3.34567d, 4.45678d); + SearchResults results = fragment.searchCosineByLastnameAndEmbeddingNear("Skywalker", vector, Score.of(0.98), + Limit.of(10)); + + assertThat(results).hasSize(1); + } + + @Test + @EnableIfVectorSearchAvailable(database = DB_NAME, collection = User.class) + void vectorSearchReturningResultsAsList() { + + initializeVectorIndex(); + + Vector vector = Vector.of(1.00000d, 1.12345d, 2.23456d, 3.34567d, 4.45678d); + List results = fragment.searchAsListByLastnameAndEmbeddingNear("Skywalker", vector, Limit.of(10)); + + assertThat(results).hasSize(2); + } + + @Test + @EnableIfVectorSearchAvailable(database = DB_NAME, collection = User.class) + void vectorSearchWithLimitFromAnnotation() { + + initializeVectorIndex(); + + Vector vector = Vector.of(1.00000d, 1.12345d, 2.23456d, 3.34567d, 4.45678d); + SearchResults results = fragment.searchByLastnameAndEmbeddingWithin("Skywalker", vector, + Similarity.between(0.4, 0.99)); + + assertThat(results).hasSize(1); + } + + @Test + @EnableIfVectorSearchAvailable(database = DB_NAME, collection = User.class) + void vectorSearchWithSorting() { + + initializeVectorIndex(); + + Vector vector = Vector.of(1.00000d, 1.12345d, 2.23456d, 3.34567d, 4.45678d); + SearchResults results = fragment.searchByLastnameAndEmbeddingWithinOrderByFirstname("Skywalker", vector, + Similarity.between(0.4, 1.0)); + + assertThat(results).hasSize(2); + } + + @Test + @EnableIfVectorSearchAvailable(database = DB_NAME, collection = User.class) + void vectorSearchWithLimitFromDerivedQuery() { + + initializeVectorIndex(); + + Vector vector = Vector.of(1.00000d, 1.12345d, 2.23456d, 3.34567d, 4.45678d); + SearchResults results = fragment.searchTop1ByLastnameAndEmbeddingWithin("Skywalker", vector, + Similarity.between(0.4, 1.0)); + + assertThat(results).hasSize(1); + } + /** * GeoResults results = repository.findPersonByLocationNear(new Point(-73.99, 40.73), range); */ - private static void initUsers() { + private static void initUsers() throws InterruptedException { Document luke = Document.parse(""" { @@ -770,6 +911,7 @@ class MongoRepositoryContributorTests { } } ], + "embedding" : [1.00000, 1.12345, 2.23456, 3.34567, 4.45678], "_class": "example.springdata.aot.User" }"""); @@ -785,6 +927,7 @@ class MongoRepositoryContributorTests { "x" : -73.99171, "y" : 40.738868 } }, + "embedding" : [1.0001, 2.12345, 3.23456, 4.34567, 5.45678], "_class": "example.springdata.aot.User" }"""); @@ -802,6 +945,7 @@ class MongoRepositoryContributorTests { } } ], + "embedding" : [2.0002, 3.12345, 4.23456, 5.34567, 6.45678], "_class": "example.springdata.aot.User" }"""); @@ -812,6 +956,7 @@ class MongoRepositoryContributorTests { "lastSeen" : { "$date": "2025-01-01T00:00:00.000Z" }, + "embedding" : [3.0003, 4.12345, 5.23456, 6.34567, 7.45678], "_class": "example.springdata.aot.User" }"""); @@ -834,7 +979,8 @@ class MongoRepositoryContributorTests { "$date": "2025-01-15T13:53:33.855Z" } } - ] + ], + "embedding" : [4.0004, 5.12345, 6.23456, 7.34567, 8.45678] }"""); Document vader = Document.parse(""" @@ -857,7 +1003,8 @@ class MongoRepositoryContributorTests { "$date": "2025-01-15T13:46:33.855Z" } } - ] + ], + "embedding" : [5.0005, 6.12345, 7.23456, 8.34567, 9.45678] }"""); Document kylo = Document.parse(""" @@ -865,7 +1012,8 @@ class MongoRepositoryContributorTests { "_id": "id-7", "username": "kylo", "first_name": "Ben", - "last_name": "Solo" + "last_name": "Solo", + "embedding" : [6.0006, 7.12345, 8.23456, 9.34567, 10.45678] } """); diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/QueryMethodContributionUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/QueryMethodContributionUnitTests.java index d3d0f40e4..d8de601d4 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/QueryMethodContributionUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/QueryMethodContributionUnitTests.java @@ -29,8 +29,13 @@ import javax.lang.model.element.Modifier; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; +import org.springframework.data.domain.Limit; import org.springframework.data.domain.Pageable; import org.springframework.data.domain.Range; +import org.springframework.data.domain.Score; +import org.springframework.data.domain.SearchResults; +import org.springframework.data.domain.Similarity; +import org.springframework.data.domain.Vector; import org.springframework.data.geo.Box; import org.springframework.data.geo.Circle; import org.springframework.data.geo.Distance; @@ -43,6 +48,7 @@ import org.springframework.data.mongodb.core.geo.GeoJsonPolygon; import org.springframework.data.mongodb.core.geo.Sphere; import org.springframework.data.mongodb.repository.Hint; import org.springframework.data.mongodb.repository.ReadPreference; +import org.springframework.data.mongodb.repository.VectorSearch; import org.springframework.data.repository.Repository; import org.springframework.data.repository.aot.generate.AotQueryMethodGenerationContext; import org.springframework.data.repository.aot.generate.AotRepositoryFragmentMetadata; @@ -65,7 +71,7 @@ public class QueryMethodContributionUnitTests { assertThat(methodSpec.toString()) // .contains("{'location.coordinates':{'$near':?0}}") // - .contains("Object[]{ location }") // + .contains("arguments(location)") // .contains("return finder.matching(filterQuery).all()"); } @@ -124,7 +130,7 @@ public class QueryMethodContributionUnitTests { assertThat(methodSpec.toString()) // .contains("{'location.coordinates':{'$geoWithin':{'$geometry':?0}}") // - .contains("Object[]{ polygon }") // + .contains("arguments(polygon)") // .contains("return finder.matching(filterQuery).all()"); } @@ -182,7 +188,7 @@ public class QueryMethodContributionUnitTests { assertThat(methodSpec.toString()) // .contains("NearQuery.near(point)") // .contains("nearQuery.maxDistance(maxDistance).in(maxDistance.getMetric())") // - .contains("filterQuery = createQuery(\"{'lastname':?0}\", new java.lang.Object[]{ lastname })") // + .contains("filterQuery = createQuery(\"{'lastname':?0}\", arguments(lastname))") // .contains("nearQuery.query(filterQuery)") // .contains(".near(nearQuery)") // .contains("return nearFinder.all()"); @@ -194,8 +200,7 @@ public class QueryMethodContributionUnitTests { MethodSpec methodSpec = codeOf(UserRepository.class, "findWithExpressionUsingParameterIndex", String.class); assertThat(methodSpec.toString()) // - .contains("createQuery(\"{ firstname : ?#{[0]} }\"") // - .contains("Map.of(\"firstname\", firstname)"); + .contains("createQuery(\"{ firstname : ?#{[0]} }\", argumentMap(\"firstname\", firstname))"); } @Test // GH-5006 @@ -204,8 +209,7 @@ public class QueryMethodContributionUnitTests { MethodSpec methodSpec = codeOf(UserRepository.class, "findWithExpressionUsingParameterName", String.class); assertThat(methodSpec.toString()) // - .contains("createQuery(\"{ firstname : :#{#firstname} }\"") // - .contains("Map.of(\"firstname\", firstname)"); + .contains("createQuery(\"{ firstname : :#{#firstname} }\", argumentMap(\"firstname\", firstname))"); } @Test // GH-4939 @@ -214,8 +218,7 @@ public class QueryMethodContributionUnitTests { MethodSpec methodSpec = codeOf(UserRepository.class, "findByFirstnameRegex", Pattern.class); assertThat(methodSpec.toString()) // - .contains("createQuery(\"{'firstname':{'$regex':?0}}\"") // - .contains("Object[]{ pattern }"); + .contains("createQuery(\"{'firstname':{'$regex':?0}}\", arguments(pattern))"); } @Test // GH-4939 @@ -233,7 +236,7 @@ public class QueryMethodContributionUnitTests { MethodSpec methodSpec = codeOf(UserRepoWithMeta.class, "findByFirstname", String.class); assertThat(methodSpec.toString()) // - .containsPattern(".*\\.collation\\(.*Collation\\.parse\\(\"en_US\"\\)\\)"); + .containsSubsequence(".collation(", "Collation.parse(\"en_US\"))"); } @Test // GH-4939 @@ -243,7 +246,122 @@ public class QueryMethodContributionUnitTests { assertThat(methodSpec.toString()) // .containsIgnoringWhitespaces( - "collationOf(evaluate(\"?#{[1]}\", java.util.Map.of(\"firstname\", firstname, \"locale\", locale)))"); + "collationOf(evaluate(\"?#{[1]}\", argumentMap(\"firstname\", firstname, \"locale\", locale)))"); + } + + @Test + void rendersVectorSearchFilterFromAnnotatedQuery() throws NoSuchMethodException { + + MethodSpec methodSpec = codeOf(UserRepository.class, "annotatedVectorSearch", String.class, Vector.class, + Score.class, Limit.class); + + assertThat(methodSpec.toString()) // + .containsSubsequence("$vectorSearch =", + "Aggregation.vectorSearch(\"embedding.vector_cos\").path(\"embedding\").vector(vector).limit(limit);") + .contains("filter = createQuery(\"{lastname: ?0}\", arguments(lastname, distance))") + .contains("$vectorSearch.filter(filter.getQueryObject())"); + } + + @Test + void rendersVectorSearchNumCandidatesExpression() throws NoSuchMethodException { + + MethodSpec methodSpec = codeOf(UserRepository.class, "annotatedVectorSearch", String.class, Vector.class, + Score.class, Limit.class); + + assertThat(methodSpec.toString()) // + .containsSubsequence("$vectorSearch.numCandidates", + "evaluate(\"#{10+10}\", argumentMap(\"lastname\", lastname, \"distance\", distance)))"); + } + + @Test + void rendersVectorSearchScoringFunctionFromScore() throws NoSuchMethodException { + + MethodSpec methodSpec = codeOf(UserRepository.class, "annotatedVectorSearch", String.class, Vector.class, + Score.class, Limit.class); + + assertThat(methodSpec.toString()) // + .contains("ScoringFunction scoringFunction = distance.getFunction()"); + } + + @Test + void rendersVectorSearchSearchTypeFromAnnotation() throws NoSuchMethodException { + + MethodSpec methodSpec = codeOf(UserRepository.class, "annotatedVectorSearch", String.class, Vector.class, + Score.class, Limit.class); + + assertThat(methodSpec.toString()) // + .containsSubsequence("$vectorSearch.searchType(", "VectorSearchOperation.SearchType.ANN)"); + } + + @Test + void rendersVectorSearchQueryFromMethodName() throws NoSuchMethodException { + + MethodSpec methodSpec = codeOf(UserRepository.class, "searchCosineByLastnameAndEmbeddingNear", String.class, + Vector.class, Score.class, Limit.class); + + assertThat(methodSpec.toString()) // + .contains("filter = createQuery(\"{'lastname':?0}\", arguments(lastname, similarity))"); + } + + @Test + void rendersVectorSearchNumCandidatesFromLimitIfNotExplicitlyDefined() throws NoSuchMethodException { + + MethodSpec methodSpec = codeOf(UserRepository.class, "searchCosineByLastnameAndEmbeddingNear", String.class, + Vector.class, Score.class, Limit.class); + + assertThat(methodSpec.toString()) // + .contains("$vectorSearch.numCandidates(limit.max() * 20)"); + } + + @Test + void rendersVectorSearchLimitFromAnnotation() throws NoSuchMethodException { + + MethodSpec methodSpec = codeOf(UserRepository.class, "searchByLastnameAndEmbeddingWithin", String.class, + Vector.class, Range.class); + + assertThat(methodSpec.toString()) // + .contains("Aggregation.vectorSearch(\"embedding.vector_cos\").path(\"embedding\").vector(vector).limit(10)") + .contains("$vectorSearch.numCandidates(10 * 20)"); + } + + @Test + void rendersVectorSearchLimitFromExpression() throws NoSuchMethodException { + + MethodSpec methodSpec = codeOf(UserRepoWithMeta.class, + "searchWithLimitAsExpressionByLastnameAndEmbeddingWithinOrderByFirstname", String.class, Vector.class, + Range.class); + + assertThat(methodSpec.toString()) // + .containsSubsequence( + "Aggregation.vectorSearch(\"embedding.vector_cos\").path(\"embedding\").vector(vector).limit(", + "evaluate(\"#{5+5}\", argumentMap(\"lastname\", lastname, \"distance\", distance)") + .containsSubsequence("$vectorSearch.numCandidates(", + "evaluate(\"#{5+5}\", argumentMap(\"lastname\", lastname, \"distance\", distance))) * 20)"); + } + + @Test + void rendersVectorSearchOrderByScoreAsDefault() throws NoSuchMethodException { + + MethodSpec methodSpec = codeOf(UserRepository.class, "searchCosineByLastnameAndEmbeddingNear", String.class, + Vector.class, Score.class, Limit.class); + + assertThat(methodSpec.toString()) // + .contains("$vectorSearch.withSearchScore(\"__score__\")") + .containsSubsequence("$sort = ", "Aggregation.sort(", "DESC, \"__score__\")") + .containsSubsequence("AggregationPipeline(", "List.of($vectorSearch, $sort))"); + } + + @Test + void rendersVectorSearchOrderByWithScoreLast() throws NoSuchMethodException { + + MethodSpec methodSpec = codeOf(UserRepository.class, "searchByLastnameAndEmbeddingWithinOrderByFirstname", + String.class, Vector.class, Range.class); + + assertThat(methodSpec.toString()) // + .containsSubsequence("AggregationOperation $sort = (_ctx) -> {", // + "_mappedSort = _ctx.getMappedObject(", // + "Document.parse(\"{'firstname':{'$numberInt':'1'}}\")", // + "Document(\"$sort\", _mappedSort.append(\"__score__\", -1))"); } private static MethodSpec codeOf(Class repository, String methodName, Class... args) @@ -287,5 +405,9 @@ public class QueryMethodContributionUnitTests { @ReadPreference("NEAREST") GeoResults findByLocationCoordinatesNear(Point point, Distance maxDistance); + + @VectorSearch(indexName = "embedding.vector_cos", limit = "#{5+5}") + SearchResults searchWithLimitAsExpressionByLastnameAndEmbeddingWithinOrderByFirstname(String lastname, + Vector vector, Range distance); } } diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/EnableIfVectorSearchAvailable.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/EnableIfVectorSearchAvailable.java index 7570d6fc5..69e1f8da9 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/EnableIfVectorSearchAvailable.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/EnableIfVectorSearchAvailable.java @@ -43,6 +43,8 @@ import org.junit.jupiter.api.extension.ExtendWith; @ExtendWith(MongoServerCondition.class) public @interface EnableIfVectorSearchAvailable { + String database() default ""; + /** * @return the name of the collection used to run the {@literal $listSearchIndexes} aggregation. */ diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoServerCondition.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoServerCondition.java index a1536d01d..d1ae4e09b 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoServerCondition.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoServerCondition.java @@ -117,12 +117,15 @@ class MongoServerCondition implements ExecutionCondition { ? vectorSearchAvailable.collectionName() : MongoCollectionUtils.getPreferredCollectionName(vectorSearchAvailable.collection()); + String databaseName = StringUtils.hasText(vectorSearchAvailable.database()) + ? vectorSearchAvailable.database() : null; + return context.getStore(NAMESPACE).getOrComputeIfAbsent("search-index-%s-available".formatted(collectionName), (key) -> { try { doWithClient(client -> { Awaitility.await().atMost(Duration.ofSeconds(60)).pollInterval(Duration.ofMillis(200)).until(() -> { - return MongoTestUtils.isSearchIndexReady(client, null, collectionName); + return MongoTestUtils.isSearchIndexReady(client, databaseName, collectionName); }); return "done waiting for search index"; }); diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoTestTemplate.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoTestTemplate.java index 4e619c609..69e8ae5e8 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoTestTemplate.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoTestTemplate.java @@ -233,4 +233,43 @@ public class MongoTestTemplate extends MongoTemplate { public void awaitNoSearchIndexAvailable(Class type, Duration timeout) { awaitNoSearchIndexAvailable(getCollectionName(type), timeout); } + + public void tryToDropSearchIndexes(Class type, Duration duration) { + tryToDropSearchIndexes(getCollectionName(type), duration); + } + + public void tryToDropSearchIndexes(String collectionName, Duration timeout) { + + Awaitility.await().atMost(timeout).pollInterval(Duration.ofMillis(200)).until(() -> { + + try { + this.execute(collectionName, coll -> { + + ArrayList indexDocuments = coll.aggregate(List.of(Document.parse("{'$listSearchIndexes': { } }"))) + .into(new ArrayList<>()); + + indexDocuments.forEach(indexDocument -> { + if (indexDocument.containsKey("name")) { + boolean toBeDeleted = true; + if (indexDocument.containsKey("status")) { + String status = indexDocument.getString("status"); + if (status.equals("DELETING") || status.equals("DOES_NOT_EXIST")) { + toBeDeleted = false; + } + } + if (toBeDeleted) { + coll.dropSearchIndex(indexDocument.getString("name")); + } + } + }); + return "done with that"; + }); + } catch (Exception e) { + return false; + } + return true; + }); + + awaitNoSearchIndexAvailable(collectionName, timeout); + } } diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoTestUtils.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoTestUtils.java index 742fd5b44..81634b475 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoTestUtils.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoTestUtils.java @@ -15,8 +15,6 @@ */ package org.springframework.data.mongodb.test.util; -import org.jspecify.annotations.Nullable; -import org.springframework.util.StringUtils; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; import reactor.util.retry.Retry; @@ -26,18 +24,22 @@ import java.util.List; import java.util.concurrent.TimeUnit; import org.bson.Document; +import org.jspecify.annotations.Nullable; import org.springframework.core.env.Environment; import org.springframework.core.env.StandardEnvironment; import org.springframework.data.mongodb.SpringDataMongoDB; import org.springframework.data.util.Version; import org.springframework.util.ObjectUtils; +import org.springframework.util.StringUtils; import com.mongodb.ConnectionString; import com.mongodb.MongoClientSettings; import com.mongodb.ReadPreference; import com.mongodb.WriteConcern; +import com.mongodb.client.AggregateIterable; import com.mongodb.client.MongoClient; import com.mongodb.client.MongoCollection; +import com.mongodb.client.MongoCursor; import com.mongodb.client.MongoDatabase; import com.mongodb.reactivestreams.client.MongoClients; @@ -292,15 +294,37 @@ public class MongoTestUtils { } public static boolean isSearchIndexReady(MongoClient client, @Nullable String database, String collectionName) { + return isSearchIndexReady(null, client, database, collectionName); + } + + public static boolean isSearchIndexReady(@Nullable String indexName, MongoClient client, @Nullable String database, + String collectionName) { try { - MongoCollection collection = client.getDatabase(StringUtils.hasText(database) ? database : "test").getCollection(collectionName); - collection.aggregate(List.of(new Document("$listSearchIndexes", new Document()))); + MongoCollection collection = client.getDatabase(StringUtils.hasText(database) ? database : "test") + .getCollection(collectionName); + + Document filter = StringUtils.hasText(indexName) ? new Document("name", indexName) : new Document(); + AggregateIterable aggregate = collection.aggregate(List.of(new Document("$listSearchIndexes", filter))); + + try (MongoCursor cursor = aggregate.cursor()) { + + if (filter.isEmpty()) { + return true; + } + + while (cursor.hasNext()) { + Document doc = cursor.next(); + if (doc.getString("name").equals(indexName)) { + return doc.getString("status").equals("READY"); + } + } + } + + return false; } catch (Exception e) { return false; } - return true; - } public static Duration getTimeout() {