Browse Source

Polishing.

Original Pull Request: #4960
pull/4976/head
Christoph Strobl 8 months ago committed by Mark Paluch
parent
commit
21568c84eb
No known key found for this signature in database
GPG Key ID: 55BC6374BAA9D973
  1. 27
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java
  2. 1
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/ReactiveMongoTemplate.java
  3. 14
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationPipeline.java
  4. 1
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ArrayOperators.java
  5. 6
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ConvertingParameterAccessor.java
  6. 9
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoParametersParameterAccessor.java
  7. 41
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoQueryCreator.java
  8. 63
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoQueryExecution.java
  9. 20
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ReactiveMongoQueryExecution.java
  10. 8
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ReactiveVectorSearchAggregation.java
  11. 10
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/VectorSearchAggregation.java
  12. 250
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/VectorSearchDelegate.java
  13. 3
      spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/VectorSearchTests.java
  14. 3
      spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/VectorSearchAggregationUnitTests.java
  15. 156
      spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/VectorSearchDelegateUnitTests.java
  16. 2
      src/main/antora/modules/ROOT/pages/mongodb/mongo-search-indexes.adoc
  17. 8
      src/main/antora/modules/ROOT/partials/vector-search-method-annotated-include.adoc
  18. 12
      src/main/antora/modules/ROOT/partials/vector-search-method-derived-include.adoc
  19. 12
      src/main/antora/modules/ROOT/partials/vector-search-repository-include.adoc
  20. 6
      src/main/antora/modules/ROOT/partials/vector-search-scoring-include.adoc

27
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java

@ -486,7 +486,7 @@ public class MongoTemplate implements MongoOperations, ApplicationContextAware, @@ -486,7 +486,7 @@ public class MongoTemplate implements MongoOperations, ApplicationContextAware,
return doStream(query, entityType, collectionName, returnType, QueryResultConverter.entity());
}
@SuppressWarnings("ConstantConditions")
@SuppressWarnings({"ConstantConditions", "NullAway"})
<T, R> Stream<R> doStream(Query query, Class<?> entityType, String collectionName, Class<T> returnType,
QueryResultConverter<? super T, ? extends R> resultConverter) {
@ -1086,34 +1086,29 @@ public class MongoTemplate implements MongoOperations, ApplicationContextAware, @@ -1086,34 +1086,29 @@ public class MongoTemplate implements MongoOperations, ApplicationContextAware,
return new GeoResults<>(result, avgDistance);
}
@Nullable
@Override
public <T> T findAndModify(Query query, UpdateDefinition update, Class<T> entityClass) {
public <T> @Nullable T findAndModify(Query query, UpdateDefinition update, Class<T> entityClass) {
return findAndModify(query, update, new FindAndModifyOptions(), entityClass, getCollectionName(entityClass));
}
@Nullable
@Override
public <T> T findAndModify(Query query, UpdateDefinition update, Class<T> entityClass,
public <T> @Nullable T findAndModify(Query query, UpdateDefinition update, Class<T> entityClass,
String collectionName) {
return findAndModify(query, update, new FindAndModifyOptions(), entityClass, collectionName);
}
@Nullable
@Override
public <T> T findAndModify(Query query, UpdateDefinition update, FindAndModifyOptions options,
public <T> @Nullable T findAndModify(Query query, UpdateDefinition update, FindAndModifyOptions options,
Class<T> entityClass) {
return findAndModify(query, update, options, entityClass, getCollectionName(entityClass));
}
@Nullable
@Override
public <T> T findAndModify(Query query, UpdateDefinition update, FindAndModifyOptions options,
public <T> @Nullable T findAndModify(Query query, UpdateDefinition update, FindAndModifyOptions options,
Class<T> entityClass, String collectionName) {
return findAndModify(query, update, options, entityClass, collectionName, QueryResultConverter.entity());
}
<S, T> T findAndModify(Query query, UpdateDefinition update, FindAndModifyOptions options,
<S, T> @Nullable T findAndModify(Query query, UpdateDefinition update, FindAndModifyOptions options,
Class<S> entityClass, String collectionName, QueryResultConverter<? super S, ? extends T> resultConverter) {
Assert.notNull(query, "Query must not be null");
@ -1185,15 +1180,13 @@ public class MongoTemplate implements MongoOperations, ApplicationContextAware, @@ -1185,15 +1180,13 @@ public class MongoTemplate implements MongoOperations, ApplicationContextAware,
// Find methods that take a Query to express the query and that return a single object that is also removed from the
// collection in the database.
@Nullable
@Override
public <T> T findAndRemove(Query query, Class<T> entityClass) {
public <T> @Nullable T findAndRemove(Query query, Class<T> entityClass) {
return findAndRemove(query, entityClass, getCollectionName(entityClass));
}
@Nullable
@Override
public <T> T findAndRemove(Query query, Class<T> entityClass, String collectionName) {
public <T> @Nullable T findAndRemove(Query query, Class<T> entityClass, String collectionName) {
Assert.notNull(query, "Query must not be null");
Assert.notNull(entityClass, "EntityClass must not be null");
@ -2161,11 +2154,11 @@ public class MongoTemplate implements MongoOperations, ApplicationContextAware, @@ -2161,11 +2154,11 @@ public class MongoTemplate implements MongoOperations, ApplicationContextAware,
* @param entityClass
* @return
*/
@SuppressWarnings("NullAway")
protected <T> List<T> doFindAndDelete(String collectionName, Query query, Class<T> entityClass) {
return doFindAndDelete(collectionName, query, entityClass, QueryResultConverter.entity());
}
@SuppressWarnings("NullAway")
<S, T> List<T> doFindAndDelete(String collectionName, Query query, Class<S> entityClass,
QueryResultConverter<? super S, ? extends T> resultConverter) {
@ -2229,7 +2222,7 @@ public class MongoTemplate implements MongoOperations, ApplicationContextAware, @@ -2229,7 +2222,7 @@ public class MongoTemplate implements MongoOperations, ApplicationContextAware,
return doAggregate(aggregation, collectionName, outputType, QueryResultConverter.entity(), context);
}
@SuppressWarnings("ConstantConditions")
@SuppressWarnings({"ConstantConditions", "NullAway"})
<T, O> AggregationResults<O> doAggregate(Aggregation aggregation, String collectionName, Class<T> outputType,
QueryResultConverter<? super T, ? extends O> resultConverter, AggregationOperationContext context) {

1
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/ReactiveMongoTemplate.java

@ -2293,6 +2293,7 @@ public class ReactiveMongoTemplate implements ReactiveMongoOperations, Applicati @@ -2293,6 +2293,7 @@ public class ReactiveMongoTemplate implements ReactiveMongoOperations, Applicati
.flatMapSequential(deleteResult -> Flux.fromIterable(list)));
}
@SuppressWarnings({"rawtypes", "unchecked", "NullAway"})
<S, T> Flux<T> doFindAndDelete(String collectionName, Query query, Class<S> entityClass,
QueryResultConverter<? super S, ? extends T> resultConverter) {

14
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationPipeline.java

@ -22,8 +22,10 @@ import java.util.List; @@ -22,8 +22,10 @@ import java.util.List;
import java.util.function.Predicate;
import org.bson.Document;
import org.jspecify.annotations.Nullable;
import org.springframework.lang.Contract;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
/**
* The {@link AggregationPipeline} holds the collection of {@link AggregationOperation aggregation stages}.
@ -82,6 +84,14 @@ public class AggregationPipeline { @@ -82,6 +84,14 @@ public class AggregationPipeline {
return Collections.unmodifiableList(pipeline);
}
/**
 * Obtain the first {@link AggregationOperation stage} of this pipeline.
 *
 * @return the first element of the pipeline as determined by {@code CollectionUtils.firstElement}; can be
 *         {@literal null}.
 */
public @Nullable AggregationOperation firstOperation() {
	@Nullable
	AggregationOperation first = CollectionUtils.firstElement(pipeline);
	return first;
}
/**
 * Obtain the last {@link AggregationOperation stage} of this pipeline.
 *
 * @return the last element of the pipeline as determined by {@code CollectionUtils.lastElement}; can be
 *         {@literal null}.
 */
public @Nullable AggregationOperation lastOperation() {
	@Nullable
	AggregationOperation last = CollectionUtils.lastElement(pipeline);
	return last;
}
List<Document> toDocuments(AggregationOperationContext context) {
verify();
@ -97,8 +107,8 @@ public class AggregationPipeline { @@ -97,8 +107,8 @@ public class AggregationPipeline {
return false;
}
AggregationOperation operation = pipeline.get(pipeline.size() - 1);
return isOut(operation) || isMerge(operation);
AggregationOperation operation = lastOperation();
return operation != null && (isOut(operation) || isMerge(operation));
}
void verify() {

1
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ArrayOperators.java

@ -356,6 +356,7 @@ public class ArrayOperators { @@ -356,6 +356,7 @@ public class ArrayOperators {
* @return new instance of {@link SortArray}.
* @since 4.5
*/
@SuppressWarnings("NullAway")
public SortArray sort(Direction direction) {
if (usesFieldRef()) {

6
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ConvertingParameterAccessor.java

@ -77,7 +77,7 @@ public class ConvertingParameterAccessor implements MongoParameterAccessor { @@ -77,7 +77,7 @@ public class ConvertingParameterAccessor implements MongoParameterAccessor {
}
@Override
public Vector getVector() {
public @Nullable Vector getVector() {
return delegate.getVector();
}
@ -104,12 +104,12 @@ public class ConvertingParameterAccessor implements MongoParameterAccessor { @@ -104,12 +104,12 @@ public class ConvertingParameterAccessor implements MongoParameterAccessor {
}
@Override
public @org.jspecify.annotations.Nullable Score getScore() {
public @Nullable Score getScore() {
return delegate.getScore();
}
@Override
public @org.jspecify.annotations.Nullable Range<Score> getScoreRange() {
public @Nullable Range<Score> getScoreRange() {
return delegate.getScoreRange();
}

9
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoParametersParameterAccessor.java

@ -61,14 +61,13 @@ public class MongoParametersParameterAccessor extends ParametersParameterAccesso @@ -61,14 +61,13 @@ public class MongoParametersParameterAccessor extends ParametersParameterAccesso
public Range<Score> getScoreRange() {
MongoParameters mongoParameters = method.getParameters();
int rangeIndex = mongoParameters.getScoreRangeIndex();
if (rangeIndex != -1) {
return getValue(rangeIndex);
if (mongoParameters.hasScoreRangeParameter()) {
return getValue(mongoParameters.getScoreRangeIndex());
}
int scoreIndex = mongoParameters.getScoreIndex();
Bound<Score> maxDistance = scoreIndex == -1 ? Bound.unbounded() : Bound.inclusive((Score) getScore());
Score score = getScore();
Bound<Score> maxDistance = score != null ? Bound.inclusive(score) : Bound.unbounded();
return Range.of(Bound.unbounded(), maxDistance);
}

41
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoQueryCreator.java

@ -15,7 +15,8 @@ @@ -15,7 +15,8 @@
*/
package org.springframework.data.mongodb.repository.query;
import static org.springframework.data.mongodb.core.query.Criteria.*;
import static org.springframework.data.mongodb.core.query.Criteria.Placeholder;
import static org.springframework.data.mongodb.core.query.Criteria.where;
import java.util.Arrays;
import java.util.Collection;
@ -27,7 +28,6 @@ import org.apache.commons.logging.Log; @@ -27,7 +28,6 @@ import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.bson.BsonRegularExpression;
import org.jspecify.annotations.Nullable;
import org.springframework.data.domain.Range;
import org.springframework.data.domain.Range.Bound;
import org.springframework.data.domain.Sort;
@ -118,8 +118,9 @@ public class MongoQueryCreator extends AbstractQueryCreator<Query, Criteria> { @@ -118,8 +118,9 @@ public class MongoQueryCreator extends AbstractQueryCreator<Query, Criteria> {
return new Criteria();
}
if (isSearchQuery && (part.getType().equals(Type.NEAR) || part.getType().equals(Type.WITHIN))) {
return null;
if (isPartOfSearchQuery(part)) {
skip(part, iterator);
return new Criteria();
}
PersistentPropertyPath<MongoPersistentProperty> path = context.getPersistentPropertyPath(part.getProperty());
@ -135,7 +136,8 @@ public class MongoQueryCreator extends AbstractQueryCreator<Query, Criteria> { @@ -135,7 +136,8 @@ public class MongoQueryCreator extends AbstractQueryCreator<Query, Criteria> {
return create(part, iterator);
}
if (isSearchQuery && (part.getType().equals(Type.NEAR) || part.getType().equals(Type.WITHIN))) {
if (isPartOfSearchQuery(part)) {
skip(part, iterator);
return base;
}
@ -176,15 +178,6 @@ public class MongoQueryCreator extends AbstractQueryCreator<Query, Criteria> { @@ -176,15 +178,6 @@ public class MongoQueryCreator extends AbstractQueryCreator<Query, Criteria> {
@SuppressWarnings("NullAway")
private Criteria from(Part part, MongoPersistentProperty property, Criteria criteria, Iterator<Object> parameters) {
if (isSearchQuery && (part.getType().equals(Type.NEAR) || part.getType().equals(Type.WITHIN))) {
int numberOfArguments = part.getType().getNumberOfArguments();
for (int i = 0; i < numberOfArguments; i++) {
parameters.next();
}
return null;
}
Type type = part.getType();
switch (type) {
@ -433,8 +426,7 @@ public class MongoQueryCreator extends AbstractQueryCreator<Query, Criteria> { @@ -433,8 +426,7 @@ public class MongoQueryCreator extends AbstractQueryCreator<Query, Criteria> {
streamable = streamable.map(it -> {
if (it instanceof String sv) {
return new BsonRegularExpression(MongoRegexCreator.INSTANCE.toRegularExpression(sv, matchMode),
regexOptions);
return new BsonRegularExpression(MongoRegexCreator.INSTANCE.toRegularExpression(sv, matchMode), regexOptions);
}
return it;
});
@ -468,10 +460,23 @@ public class MongoQueryCreator extends AbstractQueryCreator<Query, Criteria> { @@ -468,10 +460,23 @@ public class MongoQueryCreator extends AbstractQueryCreator<Query, Criteria> {
return false;
}
/**
 * Check whether the given {@link Part} is a {@link Type#NEAR} or {@link Type#WITHIN} part while this creator is in
 * search query mode.
 *
 * @param part the part to inspect.
 * @return {@literal true} only if {@code isSearchQuery} is set and the part type is either {@link Type#NEAR} or
 *         {@link Type#WITHIN}.
 */
private boolean isPartOfSearchQuery(Part part) {

	// Outside of search queries no part is treated specially.
	if (!isSearchQuery) {
		return false;
	}

	Type type = part.getType();
	return type.equals(Type.NEAR) || type.equals(Type.WITHIN);
}
/**
 * Advance the given iterator past the arguments belonging to the given {@link Part} without using their values.
 * Stops early if the iterator has no more elements.
 *
 * @param part the part whose {@link Part#getNumberOfArguments() number of arguments} defines how many values to
 *          drop.
 * @param parameters the iterator to advance.
 */
private static void skip(Part part, Iterator<?> parameters) {

	int argumentCount = part.getNumberOfArguments();

	// hasNext() is checked first so the iterator is never advanced beyond its end.
	for (int consumed = 0; parameters.hasNext() && consumed < argumentCount; consumed++) {
		parameters.next();
	}
}
/**
* Compute a {@link Type#BETWEEN} typed {@link Part} using {@link Criteria#gt(Object) $gt},
* {@link Criteria#gte(Object) $gte}, {@link Criteria#lt(Object) $lt} and {@link Criteria#lte(Object) $lte}.
* <br />
* {@link Criteria#gte(Object) $gte}, {@link Criteria#lt(Object) $lt} and {@link Criteria#lte(Object) $lte}. <br />
* In case the first {@literal value} is actually a {@link Range} the lower and upper bounds of the {@link Range} are
* used according to their {@link Bound#isInclusive() inclusion} definition. Otherwise the {@literal value} is used
* for {@literal $gt} and {@link Iterator#next() parameters.next()} as {@literal $lt}.

63
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoQueryExecution.java

@ -15,18 +15,16 @@ @@ -15,18 +15,16 @@
*/
package org.springframework.data.mongodb.repository.query;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.function.Supplier;
import org.bson.Document;
import org.jspecify.annotations.Nullable;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.Pageable;
import org.springframework.data.domain.Range;
import org.springframework.data.domain.ScoringFunction;
import org.springframework.data.domain.SearchResult;
import org.springframework.data.domain.SearchResults;
import org.springframework.data.domain.Similarity;
@ -37,6 +35,7 @@ import org.springframework.data.geo.GeoPage; @@ -37,6 +35,7 @@ import org.springframework.data.geo.GeoPage;
import org.springframework.data.geo.GeoResult;
import org.springframework.data.geo.GeoResults;
import org.springframework.data.geo.Point;
import org.springframework.data.mongodb.core.ExecutableAggregationOperation.TerminatingAggregation;
import org.springframework.data.mongodb.core.ExecutableFindOperation;
import org.springframework.data.mongodb.core.ExecutableFindOperation.FindWithQuery;
import org.springframework.data.mongodb.core.ExecutableFindOperation.TerminatingFind;
@ -45,12 +44,13 @@ import org.springframework.data.mongodb.core.ExecutableRemoveOperation.Executabl @@ -45,12 +44,13 @@ import org.springframework.data.mongodb.core.ExecutableRemoveOperation.Executabl
import org.springframework.data.mongodb.core.ExecutableRemoveOperation.TerminatingRemove;
import org.springframework.data.mongodb.core.ExecutableUpdateOperation.ExecutableUpdate;
import org.springframework.data.mongodb.core.MongoOperations;
import org.springframework.data.mongodb.core.aggregation.AggregationOperation;
import org.springframework.data.mongodb.core.aggregation.AggregationPipeline;
import org.springframework.data.mongodb.core.aggregation.AggregationResults;
import org.springframework.data.mongodb.core.aggregation.TypedAggregation;
import org.springframework.data.mongodb.core.query.NearQuery;
import org.springframework.data.mongodb.core.query.Query;
import org.springframework.data.mongodb.core.query.UpdateDefinition;
import org.springframework.data.mongodb.repository.query.VectorSearchDelegate.QueryContainer;
import org.springframework.data.mongodb.repository.util.SliceUtils;
import org.springframework.data.repository.query.QueryMethod;
import org.springframework.data.support.PageableExecutionUtils;
@ -225,52 +225,53 @@ public interface MongoQueryExecution { @@ -225,52 +225,53 @@ public interface MongoQueryExecution {
* {@link MongoQueryExecution} to execute vector search.
*
* @author Mark Paluch
* @author Christoph Strobl
* @since 5.0
*/
class VectorSearchExecution implements MongoQueryExecution {
private final MongoOperations operations;
private final MongoQueryMethod method;
private final TypeInformation<?> returnType;
private final String collectionName;
private final VectorSearchDelegate.QueryMetadata queryMetadata;
private final List<AggregationOperation> pipeline;
private final Class<?> targetType;
private final ScoringFunction scoringFunction;
private final AggregationPipeline pipeline;
VectorSearchExecution(MongoOperations operations, MongoQueryMethod method, String collectionName,
QueryContainer queryContainer) {
this(operations, queryContainer.outputType(), collectionName, method.getReturnType(), queryContainer.pipeline(),
queryContainer.scoringFunction());
}
public VectorSearchExecution(MongoOperations operations, MongoQueryMethod method, String collectionName,
VectorSearchDelegate.QueryMetadata queryMetadata, MongoParameterAccessor accessor) {
public VectorSearchExecution(MongoOperations operations, Class<?> targetType, String collectionName,
TypeInformation<?> returnType, AggregationPipeline pipeline, ScoringFunction scoringFunction) {
this.operations = operations;
this.returnType = returnType;
this.collectionName = collectionName;
this.queryMetadata = queryMetadata;
this.method = method;
this.pipeline = queryMetadata.getAggregationPipeline(method, accessor);
this.targetType = targetType;
this.scoringFunction = scoringFunction;
this.pipeline = pipeline;
}
@Override
@SuppressWarnings({ "unchecked", "rawtypes" })
public Object execute(Query query) {
AggregationResults<?> aggregated = operations.aggregate(
TypedAggregation.newAggregation(queryMetadata.outputType(), pipeline), collectionName,
queryMetadata.outputType());
TerminatingAggregation<?> executableAggregation = operations.aggregateAndReturn(targetType)
.inCollection(collectionName).by(TypedAggregation.newAggregation(targetType, pipeline.getOperations()));
List<?> mappedResults = aggregated.getMappedResults();
if (isSearchResult(method.getReturnType())) {
List<org.bson.Document> rawResults = aggregated.getRawResults().getList("results", org.bson.Document.class);
List<SearchResult<Object>> result = new ArrayList<>(mappedResults.size());
for (int i = 0; i < mappedResults.size(); i++) {
Document document = rawResults.get(i);
SearchResult<Object> searchResult = new SearchResult<>(mappedResults.get(i),
Similarity.raw(document.getDouble("__score__"), queryMetadata.scoringFunction()));
result.add(searchResult);
if (!isSearchResult(returnType)) {
return executableAggregation.all().getMappedResults();
}
return isListOfSearchResult(method.getReturnType()) ? result : new SearchResults<>(result);
}
AggregationResults<? extends SearchResult<?>> result = executableAggregation
.map((raw, container) -> new SearchResult<>(container.get(),
Similarity.raw(raw.getDouble("__score__"), scoringFunction)))
.all();
return mappedResults;
return isListOfSearchResult(returnType) ? result.getMappedResults()
: new SearchResults(result.getMappedResults());
}
private static boolean isListOfSearchResult(TypeInformation<?> returnType) {

20
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ReactiveMongoQueryExecution.java

@ -18,12 +18,9 @@ package org.springframework.data.mongodb.repository.query; @@ -18,12 +18,9 @@ package org.springframework.data.mongodb.repository.query;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import java.util.List;
import org.bson.Document;
import org.jspecify.annotations.Nullable;
import org.reactivestreams.Publisher;
import org.springframework.core.convert.converter.Converter;
import org.springframework.data.convert.DtoInstantiatingConverter;
import org.springframework.data.domain.Pageable;
@ -36,11 +33,12 @@ import org.springframework.data.geo.Point; @@ -36,11 +33,12 @@ import org.springframework.data.geo.Point;
import org.springframework.data.mapping.model.EntityInstantiators;
import org.springframework.data.mongodb.core.ReactiveMongoOperations;
import org.springframework.data.mongodb.core.ReactiveUpdateOperation.ReactiveUpdate;
import org.springframework.data.mongodb.core.aggregation.AggregationOperation;
import org.springframework.data.mongodb.core.aggregation.AggregationPipeline;
import org.springframework.data.mongodb.core.aggregation.TypedAggregation;
import org.springframework.data.mongodb.core.query.NearQuery;
import org.springframework.data.mongodb.core.query.Query;
import org.springframework.data.mongodb.core.query.UpdateDefinition;
import org.springframework.data.mongodb.repository.query.VectorSearchDelegate.QueryContainer;
import org.springframework.data.repository.query.ResultProcessor;
import org.springframework.data.repository.query.ReturnedType;
import org.springframework.data.util.ReactiveWrappers;
@ -134,24 +132,24 @@ interface ReactiveMongoQueryExecution { @@ -134,24 +132,24 @@ interface ReactiveMongoQueryExecution {
class VectorSearchExecution implements ReactiveMongoQueryExecution {
private final ReactiveMongoOperations operations;
private final VectorSearchDelegate.QueryMetadata queryMetadata;
private final List<AggregationOperation> pipeline;
private final QueryContainer queryMetadata;
private final AggregationPipeline pipeline;
private final boolean returnSearchResult;
public VectorSearchExecution(ReactiveMongoOperations operations, MongoQueryMethod method,
VectorSearchDelegate.QueryMetadata queryMetadata, MongoParameterAccessor accessor) {
VectorSearchExecution(ReactiveMongoOperations operations, MongoQueryMethod method, QueryContainer queryMetadata) {
this.operations = operations;
this.queryMetadata = queryMetadata;
this.pipeline = queryMetadata.getAggregationPipeline(method, accessor);
this.pipeline = queryMetadata.pipeline();
this.returnSearchResult = isSearchResult(method.getReturnType());
}
@Override
public Publisher<? extends Object> execute(Query query, Class<?> type, String collection) {
Flux<Document> aggregate = operations
.aggregate(TypedAggregation.newAggregation(queryMetadata.outputType(), pipeline), collection, Document.class);
Flux<Document> aggregate = operations.aggregate(
TypedAggregation.newAggregation(queryMetadata.outputType(), pipeline.getOperations()), collection,
Document.class);
return aggregate.map(document -> {

8
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ReactiveVectorSearchAggregation.java

@ -19,13 +19,13 @@ import reactor.core.publisher.Mono; @@ -19,13 +19,13 @@ import reactor.core.publisher.Mono;
import org.bson.Document;
import org.reactivestreams.Publisher;
import org.springframework.data.mongodb.InvalidMongoDbApiUsageException;
import org.springframework.data.mongodb.core.MongoOperations;
import org.springframework.data.mongodb.core.ReactiveMongoOperations;
import org.springframework.data.mongodb.core.mapping.MongoPersistentEntity;
import org.springframework.data.mongodb.core.query.Query;
import org.springframework.data.mongodb.repository.VectorSearch;
import org.springframework.data.mongodb.repository.query.VectorSearchDelegate.QueryContainer;
import org.springframework.data.mongodb.util.json.ParameterBindingContext;
import org.springframework.data.repository.query.ResultProcessor;
import org.springframework.data.repository.query.ValueExpressionDelegate;
@ -84,11 +84,11 @@ public class ReactiveVectorSearchAggregation extends AbstractReactiveMongoQuery @@ -84,11 +84,11 @@ public class ReactiveVectorSearchAggregation extends AbstractReactiveMongoQuery
ParameterBindingContext bindingContext = new ParameterBindingContext(accessor::getBindableValue,
expressionEvaluator);
VectorSearchDelegate.QueryMetadata query = delegate.createQuery(expressionEvaluator, processor, accessor,
typeToRead, codec, bindingContext);
QueryContainer query = delegate.createQuery(expressionEvaluator, processor, accessor, typeToRead, codec,
bindingContext);
ReactiveMongoQueryExecution.VectorSearchExecution execution = new ReactiveMongoQueryExecution.VectorSearchExecution(
mongoOperations, method, query, accessor);
mongoOperations, method, query);
return execution.execute(query.query(), Document.class, collectionEntity.getCollection());
});

10
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/VectorSearchAggregation.java

@ -15,16 +15,17 @@ @@ -15,16 +15,17 @@
*/
package org.springframework.data.mongodb.repository.query;
import org.jspecify.annotations.Nullable;
import org.springframework.data.mapping.model.ValueExpressionEvaluator;
import org.springframework.data.mongodb.InvalidMongoDbApiUsageException;
import org.springframework.data.mongodb.core.MongoOperations;
import org.springframework.data.mongodb.core.mapping.MongoPersistentEntity;
import org.springframework.data.mongodb.core.query.Query;
import org.springframework.data.mongodb.repository.VectorSearch;
import org.springframework.data.mongodb.repository.query.VectorSearchDelegate.QueryContainer;
import org.springframework.data.mongodb.util.json.ParameterBindingContext;
import org.springframework.data.repository.query.ResultProcessor;
import org.springframework.data.repository.query.ValueExpressionDelegate;
import org.springframework.lang.Nullable;
/**
* {@link AbstractMongoQuery} implementation to run a {@link VectorSearchAggregation}. The pre-filter is either derived
@ -62,20 +63,19 @@ public class VectorSearchAggregation extends AbstractMongoQuery { @@ -62,20 +63,19 @@ public class VectorSearchAggregation extends AbstractMongoQuery {
this.delegate = new VectorSearchDelegate(method, mongoOperations.getConverter(), delegate);
}
@SuppressWarnings("unchecked")
@Override
protected Object doExecute(MongoQueryMethod method, ResultProcessor processor, ConvertingParameterAccessor accessor,
@Nullable Class<?> typeToRead) {
VectorSearchDelegate.QueryMetadata query = createVectorSearchQuery(processor, accessor, typeToRead);
QueryContainer query = createVectorSearchQuery(processor, accessor, typeToRead);
MongoQueryExecution.VectorSearchExecution execution = new MongoQueryExecution.VectorSearchExecution(mongoOperations,
method, collectionEntity.getCollection(), query, accessor);
method, collectionEntity.getCollection(), query);
return execution.execute(query.query());
}
VectorSearchDelegate.QueryMetadata createVectorSearchQuery(ResultProcessor processor, MongoParameterAccessor accessor,
QueryContainer createVectorSearchQuery(ResultProcessor processor, MongoParameterAccessor accessor,
@Nullable Class<?> typeToRead) {
ValueExpressionEvaluator evaluator = getExpressionEvaluatorFor(accessor);

250
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/VectorSearchDelegate.java

@ -20,7 +20,6 @@ import java.util.List; @@ -20,7 +20,6 @@ import java.util.List;
import org.bson.Document;
import org.jspecify.annotations.Nullable;
import org.springframework.data.domain.Limit;
import org.springframework.data.domain.Range;
import org.springframework.data.domain.Score;
@ -34,6 +33,7 @@ import org.springframework.data.mapping.model.ValueExpressionEvaluator; @@ -34,6 +33,7 @@ import org.springframework.data.mapping.model.ValueExpressionEvaluator;
import org.springframework.data.mongodb.InvalidMongoDbApiUsageException;
import org.springframework.data.mongodb.core.aggregation.Aggregation;
import org.springframework.data.mongodb.core.aggregation.AggregationOperation;
import org.springframework.data.mongodb.core.aggregation.AggregationPipeline;
import org.springframework.data.mongodb.core.aggregation.VectorSearchOperation;
import org.springframework.data.mongodb.core.convert.MongoConverter;
import org.springframework.data.mongodb.core.mapping.MongoPersistentEntity;
@ -47,6 +47,7 @@ import org.springframework.data.repository.query.ResultProcessor; @@ -47,6 +47,7 @@ import org.springframework.data.repository.query.ResultProcessor;
import org.springframework.data.repository.query.ValueExpressionDelegate;
import org.springframework.data.repository.query.parser.Part;
import org.springframework.data.repository.query.parser.PartTree;
import org.springframework.util.NumberUtils;
import org.springframework.util.StringUtils;
/**
@ -58,32 +59,35 @@ class VectorSearchDelegate { @@ -58,32 +59,35 @@ class VectorSearchDelegate {
private final VectorSearchQueryFactory queryFactory;
private final VectorSearchOperation.SearchType searchType;
private final String indexName;
private final @Nullable Integer numCandidates;
private final @Nullable String numCandidatesExpression;
private final Limit limit;
private final @Nullable String limitExpression;
private final MongoConverter converter;
public VectorSearchDelegate(MongoQueryMethod method, MongoConverter converter, ValueExpressionDelegate delegate) {
VectorSearchDelegate(MongoQueryMethod method, MongoConverter converter, ValueExpressionDelegate delegate) {
VectorSearch vectorSearch = method.findAnnotatedVectorSearch().orElseThrow();
this.searchType = vectorSearch.searchType();
this.indexName = method.getAnnotatedHint();
if (StringUtils.hasText(vectorSearch.numCandidates())) {
ValueExpression expression = delegate.getValueExpressionParser().parse(vectorSearch.numCandidates());
if (expression.isLiteral()) {
numCandidates = Integer.parseInt(vectorSearch.numCandidates());
numCandidatesExpression = null;
this.numCandidates = Integer.parseInt(vectorSearch.numCandidates());
this.numCandidatesExpression = null;
} else {
numCandidates = null;
numCandidatesExpression = vectorSearch.numCandidates();
this.numCandidates = null;
this.numCandidatesExpression = vectorSearch.numCandidates();
}
} else {
numCandidates = null;
numCandidatesExpression = null;
this.numCandidates = null;
this.numCandidatesExpression = null;
}
if (StringUtils.hasText(vectorSearch.limit())) {
@ -91,26 +95,26 @@ class VectorSearchDelegate { @@ -91,26 +95,26 @@ class VectorSearchDelegate {
ValueExpression expression = delegate.getValueExpressionParser().parse(vectorSearch.limit());
if (expression.isLiteral()) {
limit = Limit.of(Integer.parseInt(vectorSearch.limit()));
limitExpression = null;
this.limit = Limit.of(Integer.parseInt(vectorSearch.limit()));
this.limitExpression = null;
} else {
limit = Limit.unlimited();
limitExpression = vectorSearch.limit();
this.limit = Limit.unlimited();
this.limitExpression = vectorSearch.limit();
}
} else {
limit = Limit.unlimited();
limitExpression = null;
this.limit = Limit.unlimited();
this.limitExpression = null;
}
this.converter = converter;
if (StringUtils.hasText(vectorSearch.filter())) {
queryFactory = StringUtils.hasText(vectorSearch.path())
this.queryFactory = StringUtils.hasText(vectorSearch.path())
? new AnnotatedQueryFactory(vectorSearch.filter(), vectorSearch.path())
: new AnnotatedQueryFactory(vectorSearch.filter(), method.getEntityInformation().getCollectionEntity());
} else {
queryFactory = new PartTreeQueryFactory(
this.queryFactory = new PartTreeQueryFactory(
new PartTree(method.getName(), method.getResultProcessor().getReturnedType().getDomainType()),
converter.getMappingContext());
}
@ -119,43 +123,136 @@ class VectorSearchDelegate { @@ -119,43 +123,136 @@ class VectorSearchDelegate {
/**
* Create Query Metadata for {@code $vectorSearch}.
*/
public QueryMetadata createQuery(ValueExpressionEvaluator evaluator, ResultProcessor processor,
QueryContainer createQuery(ValueExpressionEvaluator evaluator, ResultProcessor processor,
MongoParameterAccessor accessor, @Nullable Class<?> typeToRead, ParameterBindingDocumentCodec codec,
ParameterBindingContext context) {
Integer numCandidates = null;
Limit limit;
String scoreField = "__score__";
Class<?> outputType = typeToRead != null ? typeToRead : processor.getReturnedType().getReturnedType();
VectorSearchInput query = queryFactory.createQuery(accessor, codec, context);
VectorSearchInput vectorSearchInput = createSearchInput(evaluator, accessor, codec, context);
AggregationPipeline pipeline = createVectorSearchPipeline(vectorSearchInput, scoreField, outputType, accessor,
evaluator);
if (this.limitExpression != null) {
Object value = evaluator.evaluate(this.limitExpression);
limit = value instanceof Limit l ? l : Limit.of(((Number) value).intValue());
} else if (this.limit.isLimited()) {
limit = this.limit;
} else {
limit = accessor.getLimit();
return new QueryContainer(vectorSearchInput.path, scoreField, vectorSearchInput.query, pipeline, searchType,
outputType, getSimilarityFunction(accessor), indexName);
}
if (limit.isLimited()) {
query.query().limit(limit);
}
@SuppressWarnings("NullAway")
AggregationPipeline createVectorSearchPipeline(VectorSearchInput input, String scoreField, Class<?> outputType,
MongoParameterAccessor accessor, ValueExpressionEvaluator evaluator) {
Vector vector = accessor.getVector();
Score score = accessor.getScore();
Range<Score> distance = accessor.getScoreRange();
Limit limit = Limit.of(input.query().getLimit());
List<AggregationOperation> stages = new ArrayList<>();
VectorSearchOperation $vectorSearch = Aggregation.vectorSearch(indexName).path(input.path()).vector(vector)
.limit(limit);
Integer candidates = null;
if (this.numCandidatesExpression != null) {
numCandidates = ((Number) evaluator.evaluate(this.numCandidatesExpression)).intValue();
candidates = ((Number) evaluator.evaluate(this.numCandidatesExpression)).intValue();
} else if (this.numCandidates != null) {
numCandidates = this.numCandidates;
} else if (query.query().isLimited() && (searchType == VectorSearchOperation.SearchType.ANN
candidates = this.numCandidates;
} else if (input.query().isLimited() && (searchType == VectorSearchOperation.SearchType.ANN
|| searchType == VectorSearchOperation.SearchType.DEFAULT)) {
/*
MongoDB: We recommend that you specify a number at least 20 times higher than the number of documents to return (limit) to increase accuracy.
*/
numCandidates = query.query().getLimit() * 20;
candidates = input.query().getLimit() * 20;
}
if (candidates != null) {
$vectorSearch = $vectorSearch.numCandidates(candidates);
}
//
$vectorSearch = $vectorSearch.filter(input.query.getQueryObject());
$vectorSearch = $vectorSearch.searchType(this.searchType);
$vectorSearch = $vectorSearch.withSearchScore(scoreField);
if (score != null) {
$vectorSearch = $vectorSearch.withFilterBySore(c -> {
c.gt(score.getValue());
});
} else if (distance.getLowerBound().isBounded() || distance.getUpperBound().isBounded()) {
$vectorSearch = $vectorSearch.withFilterBySore(c -> {
Range.Bound<Score> lower = distance.getLowerBound();
if (lower.isBounded()) {
double value = lower.getValue().get().getValue();
if (lower.isInclusive()) {
c.gte(value);
} else {
c.gt(value);
}
}
Range.Bound<Score> upper = distance.getUpperBound();
if (upper.isBounded()) {
double value = upper.getValue().get().getValue();
if (upper.isInclusive()) {
c.lte(value);
} else {
c.lt(value);
}
}
});
}
stages.add($vectorSearch);
if (input.query().isSorted()) {
stages.add(ctx -> {
Document mappedSort = ctx.getMappedObject(input.query().getSortObject(), outputType);
mappedSort.append(scoreField, -1);
return ctx.getMappedObject(new Document("$sort", mappedSort));
});
} else {
stages.add(Aggregation.sort(Sort.Direction.DESC, scoreField));
}
return new AggregationPipeline(stages);
}
/**
 * Create the {@link VectorSearchInput} via the configured query factory and apply the limit resolved by
 * {@code getLimit} to its query. A limit the query already carries is kept when the resolved limit is unlimited.
 *
 * @param evaluator evaluator handed to the limit resolution.
 * @param accessor accessor for the actual method arguments.
 * @param codec codec passed on to the query factory.
 * @param context binding context passed on to the query factory.
 * @return the search input created by the query factory, with the limit applied to its query where applicable.
 */
private VectorSearchInput createSearchInput(ValueExpressionEvaluator evaluator, MongoParameterAccessor accessor,
		ParameterBindingDocumentCodec codec, ParameterBindingContext context) {

	VectorSearchInput searchInput = queryFactory.createQuery(accessor, codec, context);
	Limit resolvedLimit = getLimit(evaluator, accessor);

	// do not wipe out an existing query limit with an unlimited one
	if (searchInput.query().isLimited() && resolvedLimit.isUnlimited()) {
		return searchInput;
	}

	searchInput.query().limit(resolvedLimit);
	return searchInput;
}
/**
 * Resolve the {@link Limit} to use for the vector search.
 * <p>
 * Resolution order: a non-{@literal null} result of the limit expression (accepted as {@link Limit},
 * {@link Number} or numeric {@link String}), then the limit captured in {@code this.limit} if it is limited, and
 * finally the limit provided by the {@link MongoParameterAccessor}.
 *
 * @param evaluator used to evaluate the limit expression if one is present.
 * @param accessor fallback source for the limit.
 * @return the resolved limit.
 * @throws IllegalArgumentException if the limit expression evaluates to a value of an unsupported type.
 */
private Limit getLimit(ValueExpressionEvaluator evaluator, MongoParameterAccessor accessor) {

	if (this.limitExpression != null) {

		Object value = evaluator.evaluate(this.limitExpression);

		// a null result is not an error: fall through to the static/accessor limit below
		if (value != null) {

			if (value instanceof Limit l) {
				return l;
			}

			if (value instanceof Number n) {
				return Limit.of(n.intValue());
			}

			if (value instanceof String s) {
				return Limit.of(NumberUtils.parseNumber(s, Integer.class));
			}

			// the placeholder was previously emitted verbatim; render the offending type into the message
			throw new IllegalArgumentException("Invalid type for Limit. Found [%s], expected Limit or Number"
					.formatted(value.getClass().getName()));
		}
	}

	if (this.limit.isLimited()) {
		return this.limit;
	}

	return accessor.getLimit();
}
public String getQueryString() {
@ -192,82 +289,10 @@ class VectorSearchDelegate { @@ -192,82 +289,10 @@ class VectorSearchDelegate {
* @param query
* @param searchType
* @param outputType
* @param numCandidates
* @param scoringFunction
*/
public record QueryMetadata(String path, String scoreField, Query query, VectorSearchOperation.SearchType searchType,
Class<?> outputType, @Nullable Integer numCandidates, ScoringFunction scoringFunction) {
/**
* Create the Aggregation Pipeline.
*
* @param queryMethod
* @param accessor
* @return
*/
public List<AggregationOperation> getAggregationPipeline(MongoQueryMethod queryMethod,
MongoParameterAccessor accessor) {
Vector vector = accessor.getVector();
Score score = accessor.getScore();
Range<Score> distance = accessor.getScoreRange();
Limit limit = Limit.unlimited();
if (query.isLimited()) {
limit = Limit.of(query.getLimit());
}
List<AggregationOperation> stages = new ArrayList<>();
VectorSearchOperation $vectorSearch = Aggregation.vectorSearch(queryMethod.getAnnotatedHint()).path(path())
.vector(vector).limit(limit);
if (numCandidates() != null) {
$vectorSearch = $vectorSearch.numCandidates(numCandidates());
}
$vectorSearch = $vectorSearch.filter(query.getQueryObject());
$vectorSearch = $vectorSearch.searchType(searchType());
$vectorSearch = $vectorSearch.withSearchScore(scoreField());
if (score != null) {
$vectorSearch = $vectorSearch.withFilterBySore(c -> {
c.gt(score.getValue());
});
} else if (distance.getLowerBound().isBounded() || distance.getUpperBound().isBounded()) {
$vectorSearch = $vectorSearch.withFilterBySore(c -> {
Range.Bound<Score> lower = distance.getLowerBound();
if (lower.isBounded()) {
double value = lower.getValue().get().getValue();
if (lower.isInclusive()) {
c.gte(value);
} else {
c.gt(value);
}
}
Range.Bound<Score> upper = distance.getUpperBound();
if (upper.isBounded()) {
double value = upper.getValue().get().getValue();
if (upper.isInclusive()) {
c.lte(value);
} else {
c.lt(value);
}
}
});
}
stages.add($vectorSearch);
if (query.isSorted()) {
// TODO stages.add(Aggregation.sort(query.with()));
} else {
stages.add(Aggregation.sort(Sort.Direction.DESC, "__score__"));
}
return stages;
}
record QueryContainer(String path, String scoreField, Query query, AggregationPipeline pipeline,
VectorSearchOperation.SearchType searchType, Class<?> outputType, ScoringFunction scoringFunction, String index) {
}
@ -368,11 +393,12 @@ class VectorSearchDelegate { @@ -368,11 +393,12 @@ class VectorSearchDelegate {
this.tree = tree;
}
@SuppressWarnings("NullAway")
public VectorSearchInput createQuery(MongoParameterAccessor parameterAccessor, ParameterBindingDocumentCodec codec,
ParameterBindingContext context) {
MongoQueryCreator creator = new MongoQueryCreator(tree, parameterAccessor, converter.getMappingContext(),
false, true);
MongoQueryCreator creator = new MongoQueryCreator(tree, parameterAccessor, converter.getMappingContext(), false,
true);
Query query = creator.createQuery(parameterAccessor.getSort());

3
spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/VectorSearchTests.java

@ -81,16 +81,15 @@ public class VectorSearchTests { @@ -81,16 +81,15 @@ public class VectorSearchTests {
@Override
public MongoClient mongoClient() {
atlasLocal.start();
return MongoClients.create(atlasLocal.getConnectionString());
}
}
@BeforeAll
static void beforeAll() throws InterruptedException {
atlasLocal.start();
System.out.println(atlasLocal.getConnectionString());
client = MongoClients.create(atlasLocal.getConnectionString());
template = new MongoTestTemplate(client, "vector-search-tests");

3
spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/VectorSearchAggregationUnitTests.java

@ -34,6 +34,7 @@ import org.springframework.data.mongodb.core.convert.MappingMongoConverter; @@ -34,6 +34,7 @@ import org.springframework.data.mongodb.core.convert.MappingMongoConverter;
import org.springframework.data.mongodb.core.convert.NoOpDbRefResolver;
import org.springframework.data.mongodb.core.mapping.MongoMappingContext;
import org.springframework.data.mongodb.repository.VectorSearch;
import org.springframework.data.mongodb.repository.query.VectorSearchDelegate.QueryContainer;
import org.springframework.data.projection.ProjectionFactory;
import org.springframework.data.projection.SpelAwareProxyProjectionFactory;
import org.springframework.data.repository.CrudRepository;
@ -68,7 +69,7 @@ class VectorSearchAggregationUnitTests { @@ -68,7 +69,7 @@ class VectorSearchAggregationUnitTests {
VectorSearchAggregation aggregation = aggregation(SampleRepository.class, "searchByCountryAndEmbeddingNear",
String.class, Vector.class, Score.class, Limit.class);
VectorSearchDelegate.QueryMetadata query = aggregation.createVectorSearchQuery(
QueryContainer query = aggregation.createVectorSearchQuery(
aggregation.getQueryMethod().getResultProcessor(),
new MongoParametersParameterAccessor(aggregation.getQueryMethod(),
new Object[] { "de", Vector.of(1f), Score.of(1), Limit.unlimited() }),

156
spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/VectorSearchDelegateUnitTests.java

@ -15,23 +15,30 @@ @@ -15,23 +15,30 @@
*/
package org.springframework.data.mongodb.repository.query;
import static org.assertj.core.api.Assertions.*;
import static org.mockito.Mockito.*;
import static org.mockito.Mockito.mock;
import static org.springframework.data.mongodb.test.util.Assertions.assertThat;
import java.lang.reflect.Method;
import java.util.List;
import org.bson.Document;
import org.jspecify.annotations.Nullable;
import org.junit.jupiter.api.Test;
import org.springframework.data.domain.Limit;
import org.springframework.data.domain.Score;
import org.springframework.data.domain.SearchResults;
import org.springframework.data.domain.Vector;
import org.springframework.data.mapping.model.ValueExpressionEvaluator;
import org.springframework.data.mongodb.core.aggregation.Aggregation;
import org.springframework.data.mongodb.core.aggregation.AggregationPipeline;
import org.springframework.data.mongodb.core.aggregation.VectorSearchOperation;
import org.springframework.data.mongodb.core.convert.MappingMongoConverter;
import org.springframework.data.mongodb.core.convert.NoOpDbRefResolver;
import org.springframework.data.mongodb.core.mapping.Field;
import org.springframework.data.mongodb.core.mapping.MongoMappingContext;
import org.springframework.data.mongodb.repository.VectorSearch;
import org.springframework.data.mongodb.repository.query.VectorSearchDelegate.QueryContainer;
import org.springframework.data.mongodb.util.aggregation.TestAggregationContext;
import org.springframework.data.mongodb.util.json.ParameterBindingContext;
import org.springframework.data.mongodb.util.json.ParameterBindingDocumentCodec;
import org.springframework.data.projection.SpelAwareProxyProjectionFactory;
@ -44,6 +51,7 @@ import org.springframework.data.repository.query.ValueExpressionDelegate; @@ -44,6 +51,7 @@ import org.springframework.data.repository.query.ValueExpressionDelegate;
* Unit tests for {@link VectorSearchDelegate}.
*
* @author Mark Paluch
* @author Christoph Strobl
*/
class VectorSearchDelegateUnitTests {
@ -57,10 +65,10 @@ class VectorSearchDelegateUnitTests { @@ -57,10 +65,10 @@ class VectorSearchDelegateUnitTests {
MongoQueryMethod queryMethod = getMongoQueryMethod(method);
MongoParametersParameterAccessor accessor = getAccessor(queryMethod, Vector.of(1, 2), Score.of(1));
VectorSearchDelegate.QueryMetadata query = createQueryMetadata(queryMethod, accessor);
QueryContainer container = createQueryContainer(queryMethod, accessor);
assertThat(query.query().getLimit()).isEqualTo(10);
assertThat(query.numCandidates()).isEqualTo(10 * 20);
assertThat(container.query().getLimit()).isEqualTo(10);
assertThat(numCandidates(container.pipeline())).isEqualTo(10 * 20);
}
@Test
@ -71,10 +79,10 @@ class VectorSearchDelegateUnitTests { @@ -71,10 +79,10 @@ class VectorSearchDelegateUnitTests {
MongoQueryMethod queryMethod = getMongoQueryMethod(method);
MongoParametersParameterAccessor accessor = getAccessor(queryMethod, Vector.of(1, 2), Score.of(1));
VectorSearchDelegate.QueryMetadata query = createQueryMetadata(queryMethod, accessor);
QueryContainer container = createQueryContainer(queryMethod, accessor);
assertThat(query.query().getLimit()).isEqualTo(10);
assertThat(query.numCandidates()).isNull();
assertThat(container.query().getLimit()).isEqualTo(10);
assertThat(numCandidates(container.pipeline())).isNull();
}
@Test
@ -86,19 +94,87 @@ class VectorSearchDelegateUnitTests { @@ -86,19 +94,87 @@ class VectorSearchDelegateUnitTests {
MongoQueryMethod queryMethod = getMongoQueryMethod(method);
MongoParametersParameterAccessor accessor = getAccessor(queryMethod, Vector.of(1, 2), Score.of(1), Limit.of(11));
VectorSearchDelegate.QueryMetadata query = createQueryMetadata(queryMethod, accessor);
QueryContainer container = createQueryContainer(queryMethod, accessor);
assertThat(container.query().getLimit()).isEqualTo(11);
assertThat(numCandidates(container.pipeline())).isEqualTo(11 * 20);
}
@Test
void considersDerivedQueryPart() throws ReflectiveOperationException {

	// the String filter argument is declared ahead of the vector argument
	Method searchMethod = VectorSearchRepository.class.getMethod("searchTop10ByFirstNameAndEmbeddingNear",
			String.class, Vector.class, Score.class);
	MongoQueryMethod mongoQueryMethod = getMongoQueryMethod(searchMethod);
	MongoParametersParameterAccessor arguments = getAccessor(mongoQueryMethod, "spring", Vector.of(1, 2),
			Score.of(1));

	QueryContainer result = createQueryContainer(mongoQueryMethod, arguments);
	Document vectorSearchStage = vectorSearchStageOf(result.pipeline());
	Document expectedFilter = new Document("first_name", "spring");

	assertThat(vectorSearchStage).containsEntry("$vectorSearch.filter", expectedFilter);
}
@Test
void considersDerivedQueryPartInDifferentOrder() throws ReflectiveOperationException {

	// the String filter argument is declared after the vector and score arguments
	Method searchMethod = VectorSearchRepository.class.getMethod("searchTop10ByEmbeddingNearAndFirstName",
			Vector.class, Score.class, String.class);
	MongoQueryMethod mongoQueryMethod = getMongoQueryMethod(searchMethod);
	MongoParametersParameterAccessor arguments = getAccessor(mongoQueryMethod, Vector.of(1, 2), Score.of(1),
			"spring");

	QueryContainer result = createQueryContainer(mongoQueryMethod, arguments);
	Document vectorSearchStage = vectorSearchStageOf(result.pipeline());
	Document expectedFilter = new Document("first_name", "spring");

	assertThat(vectorSearchStage).containsEntry("$vectorSearch.filter", expectedFilter);
}
@Test
void defaultSortsByScore() throws NoSuchMethodException {

	Method searchMethod = VectorSearchRepository.class.getMethod("searchTop10ByEmbeddingNear", Vector.class,
			Score.class, Limit.class);
	MongoQueryMethod mongoQueryMethod = getMongoQueryMethod(searchMethod);
	MongoParametersParameterAccessor arguments = getAccessor(mongoQueryMethod, Vector.of(1, 2), Score.of(1),
			Limit.of(10));

	AggregationPipeline pipeline = createQueryContainer(mongoQueryMethod, arguments).pipeline();

	// the last pipeline operation is expected to render the sort stage
	List<Document> renderedSort = pipeline.lastOperation()
			.toPipelineStages(TestAggregationContext.contextFor(WithVector.class));
	Document expectedSort = new Document("$sort", new Document("__score__", -1));

	assertThat(renderedSort).containsExactly(expectedSort);
}
@Test
void usesDerivedSort() throws NoSuchMethodException {

	Method searchMethod = VectorSearchRepository.class.getMethod("searchByEmbeddingNearOrderByFirstName",
			Vector.class, Score.class, Limit.class);
	MongoQueryMethod mongoQueryMethod = getMongoQueryMethod(searchMethod);
	MongoParametersParameterAccessor arguments = getAccessor(mongoQueryMethod, Vector.of(1, 2), Score.of(1),
			Limit.of(11));

	AggregationPipeline pipeline = createQueryContainer(mongoQueryMethod, arguments).pipeline();

	// the last pipeline operation is expected to render the sort stage: derived sort first, score appended
	List<Document> renderedSort = pipeline.lastOperation()
			.toPipelineStages(TestAggregationContext.contextFor(WithVector.class));
	Document expectedSort = new Document("$sort", new Document("first_name", 1).append("__score__", -1));

	assertThat(renderedSort).containsExactly(expectedSort);
}
assertThat(query.query().getLimit()).isEqualTo(11);
assertThat(query.numCandidates()).isEqualTo(11 * 20);
/**
 * Render the first operation of the given pipeline against a {@link WithVector} context and return its first
 * stage document.
 */
Document vectorSearchStageOf(AggregationPipeline pipeline) {

	List<Document> renderedStages = pipeline.firstOperation()
			.toPipelineStages(TestAggregationContext.contextFor(WithVector.class));
	return renderedStages.get(0);
}
private VectorSearchDelegate.QueryMetadata createQueryMetadata(MongoQueryMethod queryMethod,
MongoParametersParameterAccessor accessor) {
private QueryContainer createQueryContainer(MongoQueryMethod queryMethod, MongoParametersParameterAccessor accessor) {
VectorSearchDelegate delegate = new VectorSearchDelegate(queryMethod, converter, ValueExpressionDelegate.create());
return delegate.createQuery(mock(ValueExpressionEvaluator.class), queryMethod.getResultProcessor(), accessor,
Object.class, new ParameterBindingDocumentCodec(), mock(ParameterBindingContext.class));
return delegate.createQuery(mock(ValueExpressionEvaluator.class), queryMethod.getResultProcessor(), accessor, null,
new ParameterBindingDocumentCodec(), mock(ParameterBindingContext.class));
}
private MongoQueryMethod getMongoQueryMethod(Method method) {
@ -110,21 +186,69 @@ class VectorSearchDelegateUnitTests { @@ -110,21 +186,69 @@ class VectorSearchDelegateUnitTests {
return new MongoParametersParameterAccessor(queryMethod, values);
}
/**
 * Read {@code numCandidates} from the {@code $vectorSearch} document rendered by the first pipeline operation.
 *
 * @return the value as {@link Integer}, or {@literal null} if the first stage is not a {@code $vectorSearch}
 *         stage or holds no numeric {@code numCandidates}.
 */
@Nullable
private static Integer numCandidates(AggregationPipeline pipeline) {

	Document firstStage = pipeline.firstOperation().toPipelineStages(Aggregation.DEFAULT_CONTEXT).get(0);

	if (!firstStage.containsKey("$vectorSearch")) {
		return null;
	}

	Object candidates = firstStage.get("$vectorSearch", Document.class).get("numCandidates");
	if (candidates instanceof Number number) {
		return number.intValue();
	}

	return null;
}
// Repository fixture: the tests in this class look these methods up reflectively by name and parameter types,
// so names and signatures must stay in sync with the getMethod(...) calls.
interface VectorSearchRepository extends Repository<WithVector, String> {
// ANN search taking only vector and score.
@VectorSearch(indexName = "cos-index", searchType = VectorSearchOperation.SearchType.ANN)
SearchResults<WithVector> searchTop10ByEmbeddingNear(Vector vector, Score similarity);
// Adds a firstName argument declared before the vector.
@VectorSearch(indexName = "cos-index", searchType = VectorSearchOperation.SearchType.ANN)
SearchResults<WithVector> searchTop10ByFirstNameAndEmbeddingNear(String firstName, Vector vector, Score similarity);
// Same properties as above, with the firstName argument declared after vector and score.
@VectorSearch(indexName = "cos-index", searchType = VectorSearchOperation.SearchType.ANN)
SearchResults<WithVector> searchTop10ByEmbeddingNearAndFirstName(Vector vector, Score similarity, String firstname);
// The only method declared with SearchType.ENN.
@VectorSearch(indexName = "cos-index", searchType = VectorSearchOperation.SearchType.ENN)
SearchResults<WithVector> searchTop10EnnByEmbeddingNear(Vector vector, Score similarity);
// Overload accepting an explicit Limit argument.
@VectorSearch(indexName = "cos-index", searchType = VectorSearchOperation.SearchType.ANN)
SearchResults<WithVector> searchTop10ByEmbeddingNear(Vector vector, Score similarity, Limit limit);
// Declares an OrderBy clause on firstName and takes an explicit Limit argument.
@VectorSearch(indexName = "cos-index", searchType = VectorSearchOperation.SearchType.ANN)
SearchResults<WithVector> searchByEmbeddingNearOrderByFirstName(Vector vector, Score similarity, Limit limit);
}
// Domain type used by the repository fixture; a plain bean with getters and setters.
static class WithVector {
// Vector property referenced by the "EmbeddingNear" part of the repository method names.
Vector embedding;
String lastName;
// Stored under the custom field name "first_name"; the tests in this class expect that name in the rendered
// filter and sort documents.
@Field("first_name") String firstName;
public Vector getEmbedding() {
return embedding;
}
public void setEmbedding(Vector embedding) {
this.embedding = embedding;
}
public String getLastName() {
return lastName;
}
public void setLastName(String lastName) {
this.lastName = lastName;
}
public String getFirstName() {
return firstName;
}
public void setFirstName(String firstName) {
this.firstName = firstName;
}
}
}

2
src/main/antora/modules/ROOT/pages/mongodb/mongo-search-indexes.adoc

@ -25,7 +25,7 @@ Java:: @@ -25,7 +25,7 @@ Java::
[source,java,indent=0,subs="verbatim,quotes",role="primary"]
----
VectorIndex index = new VectorIndex("vector_index")
.addVector("plotEmbedding"), vector -> vector.dimensions(1536).similarity(COSINE)) <1>
.addVector("plotEmbedding", vector -> vector.dimensions(1536).similarity(COSINE)) <1>
.addFilter("year"); <2>
mongoTemplate.searchIndexOps(Movie.class) <3>

8
src/main/antora/modules/ROOT/partials/vector-search-method-annotated-include.adoc

@ -6,13 +6,13 @@ Annotated search methods use the `@VectorSearch` annotation to define parameters @@ -6,13 +6,13 @@ Annotated search methods use the `@VectorSearch` annotation to define parameters
----
interface CommentRepository extends Repository<Comment, String> {
@VectorSearch(indexName = "cos-index", filter = "{country: ?0}")
SearchResults<WithVector> searchAnnotatedByCountryAndEmbeddingWithin(String country, Vector embedding,
@VectorSearch(indexName = "cos-index", filter = "{country: ?0}", limit="100", numCandidates="2000")
SearchResults<Comment> searchAnnotatedByCountryAndEmbeddingWithin(String country, Vector embedding,
Score distance);
@VectorSearch(indexName = "my-index", filter = "{country: ?0}", numCandidates = "#{#limit * 20}",
@VectorSearch(indexName = "my-index", filter = "{country: ?0}", limit="?3", numCandidates = "#{#limit * 20}",
searchType = VectorSearchOperation.SearchType.ANN)
List<WithVector> findAnnotatedByCountryAndEmbeddingWithin(String country, Vector embedding, Score distance, int limit);
List<Comment> findAnnotatedByCountryAndEmbeddingWithin(String country, Vector embedding, Score distance, int limit);
}
----
====

12
src/main/antora/modules/ROOT/partials/vector-search-method-derived-include.adoc

@ -6,14 +6,14 @@ MongoDB Search methods must use the `@VectorSearch` annotation to define the ind @@ -6,14 +6,14 @@ MongoDB Search methods must use the `@VectorSearch` annotation to define the ind
----
interface CommentRepository extends Repository<Comment, String> {
@VectorSearch(indexName = "my-index")
SearchResults<Comment> searchByEmbeddingNear(Vector vector, Score score);
@VectorSearch(indexName = "my-index", numCandidates="200")
SearchResults<Comment> searchTop10ByEmbeddingNear(Vector vector, Score score);
@VectorSearch(indexName = "my-index")
SearchResults<Comment> searchByEmbeddingWithin(Vector vector, Range<Similarity> range);
@VectorSearch(indexName = "my-index", numCandidates="200")
SearchResults<Comment> searchTop10ByEmbeddingWithin(Vector vector, Range<Similarity> range);
@VectorSearch(indexName = "my-index")
SearchResults<Comment> searchByCountryAndEmbeddingWithin(String country, Vector vector, Range<Similarity> range);
@VectorSearch(indexName = "my-index", numCandidates="200")
SearchResults<Comment> searchTop10ByCountryAndEmbeddingWithin(String country, Vector vector, Range<Similarity> range);
}
----
====

12
src/main/antora/modules/ROOT/partials/vector-search-repository-include.adoc

@ -4,12 +4,12 @@ @@ -4,12 +4,12 @@
----
interface CommentRepository extends Repository<Comment, String> {
@VectorSearch(indexName = "my-index")
@VectorSearch(indexName = "my-index", numCandidates="#{#limit.max() * 20}")
SearchResults<Comment> searchByCountryAndEmbeddingNear(String country, Vector vector, Score score,
Limit limit);
@VectorSearch(indexName = "my-index")
SearchResults<WithVector> searchAnnotatedByCountryAndEmbeddingWithin(String country, Vector embedding,
@VectorSearch(indexName = "my-index", limit="10", numCandidates="200")
SearchResults<Comment> searchByCountryAndEmbeddingWithin(String country, Vector embedding,
Score score);
}
@ -17,3 +17,9 @@ interface CommentRepository extends Repository<Comment, String> { @@ -17,3 +17,9 @@ interface CommentRepository extends Repository<Comment, String> {
SearchResults<Comment> results = repository.searchByCountryAndEmbeddingNear("en", Vector.of(…), Score.of(0.9), Limit.of(10));
----
====
[TIP]
====
The MongoDB https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/[vector search aggregation] stage defines a set of required arguments and restrictions.
Please follow these guidelines and provide all required arguments, such as `limit`.
====

6
src/main/antora/modules/ROOT/partials/vector-search-scoring-include.adoc

@ -9,13 +9,13 @@ The scoring function defaults to `ScoringFunction.unspecified()` as there is no @@ -9,13 +9,13 @@ The scoring function defaults to `ScoringFunction.unspecified()` as there is no
interface CommentRepository extends Repository<Comment, String> {
@VectorSearch(…)
SearchResults<Comment> searchByEmbeddingNear(Vector vector, Score similarity);
SearchResults<Comment> searchTop10ByEmbeddingNear(Vector vector, Score similarity);
@VectorSearch(…)
SearchResults<Comment> searchByEmbeddingNear(Vector vector, Similarity similarity);
SearchResults<Comment> searchTop10ByEmbeddingNear(Vector vector, Similarity similarity);
@VectorSearch(…)
SearchResults<Comment> searchByEmbeddingNear(Vector vector, Range<Similarity> range);
SearchResults<Comment> searchTop10ByEmbeddingNear(Vector vector, Range<Similarity> range);
}
repository.searchTop10ByEmbeddingNear(Vector.of(…), Score.of(0.9)); <1>

Loading…
Cancel
Save