Browse Source

Polishing.

Original Pull Request: #4960
pull/4976/head
Christoph Strobl 8 months ago committed by Mark Paluch
parent
commit
21568c84eb
No known key found for this signature in database
GPG Key ID: 55BC6374BAA9D973
  1. 27
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java
  2. 1
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/ReactiveMongoTemplate.java
  3. 14
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationPipeline.java
  4. 1
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ArrayOperators.java
  5. 6
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ConvertingParameterAccessor.java
  6. 9
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoParametersParameterAccessor.java
  7. 49
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoQueryCreator.java
  8. 67
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoQueryExecution.java
  9. 20
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ReactiveMongoQueryExecution.java
  10. 8
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ReactiveVectorSearchAggregation.java
  11. 10
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/VectorSearchAggregation.java
  12. 252
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/VectorSearchDelegate.java
  13. 3
      spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/VectorSearchTests.java
  14. 3
      spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/VectorSearchAggregationUnitTests.java
  15. 156
      spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/VectorSearchDelegateUnitTests.java
  16. 2
      src/main/antora/modules/ROOT/pages/mongodb/mongo-search-indexes.adoc
  17. 8
      src/main/antora/modules/ROOT/partials/vector-search-method-annotated-include.adoc
  18. 12
      src/main/antora/modules/ROOT/partials/vector-search-method-derived-include.adoc
  19. 12
      src/main/antora/modules/ROOT/partials/vector-search-repository-include.adoc
  20. 6
      src/main/antora/modules/ROOT/partials/vector-search-scoring-include.adoc

27
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java

@ -486,7 +486,7 @@ public class MongoTemplate implements MongoOperations, ApplicationContextAware,
return doStream(query, entityType, collectionName, returnType, QueryResultConverter.entity()); return doStream(query, entityType, collectionName, returnType, QueryResultConverter.entity());
} }
@SuppressWarnings("ConstantConditions") @SuppressWarnings({"ConstantConditions", "NullAway"})
<T, R> Stream<R> doStream(Query query, Class<?> entityType, String collectionName, Class<T> returnType, <T, R> Stream<R> doStream(Query query, Class<?> entityType, String collectionName, Class<T> returnType,
QueryResultConverter<? super T, ? extends R> resultConverter) { QueryResultConverter<? super T, ? extends R> resultConverter) {
@ -1086,34 +1086,29 @@ public class MongoTemplate implements MongoOperations, ApplicationContextAware,
return new GeoResults<>(result, avgDistance); return new GeoResults<>(result, avgDistance);
} }
@Nullable public <T> @Nullable T findAndModify(Query query, UpdateDefinition update, Class<T> entityClass) {
@Override
public <T> T findAndModify(Query query, UpdateDefinition update, Class<T> entityClass) {
return findAndModify(query, update, new FindAndModifyOptions(), entityClass, getCollectionName(entityClass)); return findAndModify(query, update, new FindAndModifyOptions(), entityClass, getCollectionName(entityClass));
} }
@Nullable
@Override @Override
public <T> T findAndModify(Query query, UpdateDefinition update, Class<T> entityClass, public <T> @Nullable T findAndModify(Query query, UpdateDefinition update, Class<T> entityClass,
String collectionName) { String collectionName) {
return findAndModify(query, update, new FindAndModifyOptions(), entityClass, collectionName); return findAndModify(query, update, new FindAndModifyOptions(), entityClass, collectionName);
} }
@Nullable
@Override @Override
public <T> T findAndModify(Query query, UpdateDefinition update, FindAndModifyOptions options, public <T> @Nullable T findAndModify(Query query, UpdateDefinition update, FindAndModifyOptions options,
Class<T> entityClass) { Class<T> entityClass) {
return findAndModify(query, update, options, entityClass, getCollectionName(entityClass)); return findAndModify(query, update, options, entityClass, getCollectionName(entityClass));
} }
@Nullable
@Override @Override
public <T> T findAndModify(Query query, UpdateDefinition update, FindAndModifyOptions options, public <T> @Nullable T findAndModify(Query query, UpdateDefinition update, FindAndModifyOptions options,
Class<T> entityClass, String collectionName) { Class<T> entityClass, String collectionName) {
return findAndModify(query, update, options, entityClass, collectionName, QueryResultConverter.entity()); return findAndModify(query, update, options, entityClass, collectionName, QueryResultConverter.entity());
} }
<S, T> T findAndModify(Query query, UpdateDefinition update, FindAndModifyOptions options, <S, T> @Nullable T findAndModify(Query query, UpdateDefinition update, FindAndModifyOptions options,
Class<S> entityClass, String collectionName, QueryResultConverter<? super S, ? extends T> resultConverter) { Class<S> entityClass, String collectionName, QueryResultConverter<? super S, ? extends T> resultConverter) {
Assert.notNull(query, "Query must not be null"); Assert.notNull(query, "Query must not be null");
@ -1185,15 +1180,13 @@ public class MongoTemplate implements MongoOperations, ApplicationContextAware,
// Find methods that take a Query to express the query and that return a single object that is also removed from the // Find methods that take a Query to express the query and that return a single object that is also removed from the
// collection in the database. // collection in the database.
@Nullable
@Override @Override
public <T> T findAndRemove(Query query, Class<T> entityClass) { public <T> @Nullable T findAndRemove(Query query, Class<T> entityClass) {
return findAndRemove(query, entityClass, getCollectionName(entityClass)); return findAndRemove(query, entityClass, getCollectionName(entityClass));
} }
@Nullable
@Override @Override
public <T> T findAndRemove(Query query, Class<T> entityClass, String collectionName) { public <T> @Nullable T findAndRemove(Query query, Class<T> entityClass, String collectionName) {
Assert.notNull(query, "Query must not be null"); Assert.notNull(query, "Query must not be null");
Assert.notNull(entityClass, "EntityClass must not be null"); Assert.notNull(entityClass, "EntityClass must not be null");
@ -2161,11 +2154,11 @@ public class MongoTemplate implements MongoOperations, ApplicationContextAware,
* @param entityClass * @param entityClass
* @return * @return
*/ */
@SuppressWarnings("NullAway")
protected <T> List<T> doFindAndDelete(String collectionName, Query query, Class<T> entityClass) { protected <T> List<T> doFindAndDelete(String collectionName, Query query, Class<T> entityClass) {
return doFindAndDelete(collectionName, query, entityClass, QueryResultConverter.entity()); return doFindAndDelete(collectionName, query, entityClass, QueryResultConverter.entity());
} }
@SuppressWarnings("NullAway")
<S, T> List<T> doFindAndDelete(String collectionName, Query query, Class<S> entityClass, <S, T> List<T> doFindAndDelete(String collectionName, Query query, Class<S> entityClass,
QueryResultConverter<? super S, ? extends T> resultConverter) { QueryResultConverter<? super S, ? extends T> resultConverter) {
@ -2229,7 +2222,7 @@ public class MongoTemplate implements MongoOperations, ApplicationContextAware,
return doAggregate(aggregation, collectionName, outputType, QueryResultConverter.entity(), context); return doAggregate(aggregation, collectionName, outputType, QueryResultConverter.entity(), context);
} }
@SuppressWarnings("ConstantConditions") @SuppressWarnings({"ConstantConditions", "NullAway"})
<T, O> AggregationResults<O> doAggregate(Aggregation aggregation, String collectionName, Class<T> outputType, <T, O> AggregationResults<O> doAggregate(Aggregation aggregation, String collectionName, Class<T> outputType,
QueryResultConverter<? super T, ? extends O> resultConverter, AggregationOperationContext context) { QueryResultConverter<? super T, ? extends O> resultConverter, AggregationOperationContext context) {

1
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/ReactiveMongoTemplate.java

@ -2293,6 +2293,7 @@ public class ReactiveMongoTemplate implements ReactiveMongoOperations, Applicati
.flatMapSequential(deleteResult -> Flux.fromIterable(list))); .flatMapSequential(deleteResult -> Flux.fromIterable(list)));
} }
@SuppressWarnings({"rawtypes", "unchecked", "NullAway"})
<S, T> Flux<T> doFindAndDelete(String collectionName, Query query, Class<S> entityClass, <S, T> Flux<T> doFindAndDelete(String collectionName, Query query, Class<S> entityClass,
QueryResultConverter<? super S, ? extends T> resultConverter) { QueryResultConverter<? super S, ? extends T> resultConverter) {

14
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationPipeline.java

@ -22,8 +22,10 @@ import java.util.List;
import java.util.function.Predicate; import java.util.function.Predicate;
import org.bson.Document; import org.bson.Document;
import org.jspecify.annotations.Nullable;
import org.springframework.lang.Contract; import org.springframework.lang.Contract;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
/** /**
* The {@link AggregationPipeline} holds the collection of {@link AggregationOperation aggregation stages}. * The {@link AggregationPipeline} holds the collection of {@link AggregationOperation aggregation stages}.
@ -82,6 +84,14 @@ public class AggregationPipeline {
return Collections.unmodifiableList(pipeline); return Collections.unmodifiableList(pipeline);
} }
public @Nullable AggregationOperation firstOperation() {
return CollectionUtils.firstElement(pipeline);
}
public @Nullable AggregationOperation lastOperation() {
return CollectionUtils.lastElement(pipeline);
}
List<Document> toDocuments(AggregationOperationContext context) { List<Document> toDocuments(AggregationOperationContext context) {
verify(); verify();
@ -97,8 +107,8 @@ public class AggregationPipeline {
return false; return false;
} }
AggregationOperation operation = pipeline.get(pipeline.size() - 1); AggregationOperation operation = lastOperation();
return isOut(operation) || isMerge(operation); return operation != null && (isOut(operation) || isMerge(operation));
} }
void verify() { void verify() {

1
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ArrayOperators.java

@ -356,6 +356,7 @@ public class ArrayOperators {
* @return new instance of {@link SortArray}. * @return new instance of {@link SortArray}.
* @since 4.5 * @since 4.5
*/ */
@SuppressWarnings("NullAway")
public SortArray sort(Direction direction) { public SortArray sort(Direction direction) {
if (usesFieldRef()) { if (usesFieldRef()) {

6
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ConvertingParameterAccessor.java

@ -77,7 +77,7 @@ public class ConvertingParameterAccessor implements MongoParameterAccessor {
} }
@Override @Override
public Vector getVector() { public @Nullable Vector getVector() {
return delegate.getVector(); return delegate.getVector();
} }
@ -104,12 +104,12 @@ public class ConvertingParameterAccessor implements MongoParameterAccessor {
} }
@Override @Override
public @org.jspecify.annotations.Nullable Score getScore() { public @Nullable Score getScore() {
return delegate.getScore(); return delegate.getScore();
} }
@Override @Override
public @org.jspecify.annotations.Nullable Range<Score> getScoreRange() { public @Nullable Range<Score> getScoreRange() {
return delegate.getScoreRange(); return delegate.getScoreRange();
} }

9
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoParametersParameterAccessor.java

@ -61,14 +61,13 @@ public class MongoParametersParameterAccessor extends ParametersParameterAccesso
public Range<Score> getScoreRange() { public Range<Score> getScoreRange() {
MongoParameters mongoParameters = method.getParameters(); MongoParameters mongoParameters = method.getParameters();
int rangeIndex = mongoParameters.getScoreRangeIndex();
if (rangeIndex != -1) { if (mongoParameters.hasScoreRangeParameter()) {
return getValue(rangeIndex); return getValue(mongoParameters.getScoreRangeIndex());
} }
int scoreIndex = mongoParameters.getScoreIndex(); Score score = getScore();
Bound<Score> maxDistance = scoreIndex == -1 ? Bound.unbounded() : Bound.inclusive((Score) getScore()); Bound<Score> maxDistance = score != null ? Bound.inclusive(score) : Bound.unbounded();
return Range.of(Bound.unbounded(), maxDistance); return Range.of(Bound.unbounded(), maxDistance);
} }

49
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoQueryCreator.java

@ -15,7 +15,8 @@
*/ */
package org.springframework.data.mongodb.repository.query; package org.springframework.data.mongodb.repository.query;
import static org.springframework.data.mongodb.core.query.Criteria.*; import static org.springframework.data.mongodb.core.query.Criteria.Placeholder;
import static org.springframework.data.mongodb.core.query.Criteria.where;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collection; import java.util.Collection;
@ -27,7 +28,6 @@ import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
import org.bson.BsonRegularExpression; import org.bson.BsonRegularExpression;
import org.jspecify.annotations.Nullable; import org.jspecify.annotations.Nullable;
import org.springframework.data.domain.Range; import org.springframework.data.domain.Range;
import org.springframework.data.domain.Range.Bound; import org.springframework.data.domain.Range.Bound;
import org.springframework.data.domain.Sort; import org.springframework.data.domain.Sort;
@ -118,8 +118,9 @@ public class MongoQueryCreator extends AbstractQueryCreator<Query, Criteria> {
return new Criteria(); return new Criteria();
} }
if (isSearchQuery && (part.getType().equals(Type.NEAR) || part.getType().equals(Type.WITHIN))) { if (isPartOfSearchQuery(part)) {
return null; skip(part, iterator);
return new Criteria();
} }
PersistentPropertyPath<MongoPersistentProperty> path = context.getPersistentPropertyPath(part.getProperty()); PersistentPropertyPath<MongoPersistentProperty> path = context.getPersistentPropertyPath(part.getProperty());
@ -135,7 +136,8 @@ public class MongoQueryCreator extends AbstractQueryCreator<Query, Criteria> {
return create(part, iterator); return create(part, iterator);
} }
if (isSearchQuery && (part.getType().equals(Type.NEAR) || part.getType().equals(Type.WITHIN))) { if (isPartOfSearchQuery(part)) {
skip(part, iterator);
return base; return base;
} }
@ -176,15 +178,6 @@ public class MongoQueryCreator extends AbstractQueryCreator<Query, Criteria> {
@SuppressWarnings("NullAway") @SuppressWarnings("NullAway")
private Criteria from(Part part, MongoPersistentProperty property, Criteria criteria, Iterator<Object> parameters) { private Criteria from(Part part, MongoPersistentProperty property, Criteria criteria, Iterator<Object> parameters) {
if (isSearchQuery && (part.getType().equals(Type.NEAR) || part.getType().equals(Type.WITHIN))) {
int numberOfArguments = part.getType().getNumberOfArguments();
for (int i = 0; i < numberOfArguments; i++) {
parameters.next();
}
return null;
}
Type type = part.getType(); Type type = part.getType();
switch (type) { switch (type) {
@ -206,13 +199,13 @@ public class MongoQueryCreator extends AbstractQueryCreator<Query, Criteria> {
return criteria.is(null); return criteria.is(null);
case NOT_IN: case NOT_IN:
Object ninValue = parameters.next(); Object ninValue = parameters.next();
if(ninValue instanceof Placeholder) { if (ninValue instanceof Placeholder) {
return criteria.raw("$nin", ninValue); return criteria.raw("$nin", ninValue);
} }
return criteria.nin(valueAsList(ninValue, part)); return criteria.nin(valueAsList(ninValue, part));
case IN: case IN:
Object inValue = parameters.next(); Object inValue = parameters.next();
if(inValue instanceof Placeholder) { if (inValue instanceof Placeholder) {
return criteria.raw("$in", inValue); return criteria.raw("$in", inValue);
} }
return criteria.in(valueAsList(inValue, part)); return criteria.in(valueAsList(inValue, part));
@ -231,7 +224,7 @@ public class MongoQueryCreator extends AbstractQueryCreator<Query, Criteria> {
return param instanceof Pattern pattern ? criteria.regex(pattern) : criteria.regex(param.toString()); return param instanceof Pattern pattern ? criteria.regex(pattern) : criteria.regex(param.toString());
case EXISTS: case EXISTS:
Object next = parameters.next(); Object next = parameters.next();
if(next instanceof Placeholder placeholder) { if (next instanceof Placeholder placeholder) {
return criteria.raw("$exists", placeholder); return criteria.raw("$exists", placeholder);
} else { } else {
return criteria.exists((Boolean) next); return criteria.exists((Boolean) next);
@ -355,7 +348,7 @@ public class MongoQueryCreator extends AbstractQueryCreator<Query, Criteria> {
if (property.isCollectionLike()) { if (property.isCollectionLike()) {
Object next = parameters.next(); Object next = parameters.next();
if(next instanceof Placeholder) { if (next instanceof Placeholder) {
return criteria.raw("$in", next); return criteria.raw("$in", next);
} }
return criteria.in(valueAsList(next, part)); return criteria.in(valueAsList(next, part));
@ -433,8 +426,7 @@ public class MongoQueryCreator extends AbstractQueryCreator<Query, Criteria> {
streamable = streamable.map(it -> { streamable = streamable.map(it -> {
if (it instanceof String sv) { if (it instanceof String sv) {
return new BsonRegularExpression(MongoRegexCreator.INSTANCE.toRegularExpression(sv, matchMode), return new BsonRegularExpression(MongoRegexCreator.INSTANCE.toRegularExpression(sv, matchMode), regexOptions);
regexOptions);
} }
return it; return it;
}); });
@ -468,10 +460,23 @@ public class MongoQueryCreator extends AbstractQueryCreator<Query, Criteria> {
return false; return false;
} }
private boolean isPartOfSearchQuery(Part part) {
return isSearchQuery && (part.getType().equals(Type.NEAR) || part.getType().equals(Type.WITHIN));
}
private static void skip(Part part, Iterator<?> parameters) {
int total = part.getNumberOfArguments();
int i = 0;
while (parameters.hasNext() && i < total) {
parameters.next();
i++;
}
}
/** /**
* Compute a {@link Type#BETWEEN} typed {@link Part} using {@link Criteria#gt(Object) $gt}, * Compute a {@link Type#BETWEEN} typed {@link Part} using {@link Criteria#gt(Object) $gt},
* {@link Criteria#gte(Object) $gte}, {@link Criteria#lt(Object) $lt} and {@link Criteria#lte(Object) $lte}. * {@link Criteria#gte(Object) $gte}, {@link Criteria#lt(Object) $lt} and {@link Criteria#lte(Object) $lte}. <br />
* <br />
* In case the first {@literal value} is actually a {@link Range} the lower and upper bounds of the {@link Range} are * In case the first {@literal value} is actually a {@link Range} the lower and upper bounds of the {@link Range} are
* used according to their {@link Bound#isInclusive() inclusion} definition. Otherwise the {@literal value} is used * used according to their {@link Bound#isInclusive() inclusion} definition. Otherwise the {@literal value} is used
* for {@literal $gt} and {@link Iterator#next() parameters.next()} as {@literal $lt}. * for {@literal $gt} and {@link Iterator#next() parameters.next()} as {@literal $lt}.

67
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoQueryExecution.java

@ -15,18 +15,16 @@
*/ */
package org.springframework.data.mongodb.repository.query; package org.springframework.data.mongodb.repository.query;
import java.util.ArrayList;
import java.util.Collection; import java.util.Collection;
import java.util.Iterator; import java.util.Iterator;
import java.util.List; import java.util.List;
import java.util.function.Supplier; import java.util.function.Supplier;
import org.bson.Document;
import org.jspecify.annotations.Nullable; import org.jspecify.annotations.Nullable;
import org.springframework.data.domain.Page; import org.springframework.data.domain.Page;
import org.springframework.data.domain.Pageable; import org.springframework.data.domain.Pageable;
import org.springframework.data.domain.Range; import org.springframework.data.domain.Range;
import org.springframework.data.domain.ScoringFunction;
import org.springframework.data.domain.SearchResult; import org.springframework.data.domain.SearchResult;
import org.springframework.data.domain.SearchResults; import org.springframework.data.domain.SearchResults;
import org.springframework.data.domain.Similarity; import org.springframework.data.domain.Similarity;
@ -37,6 +35,7 @@ import org.springframework.data.geo.GeoPage;
import org.springframework.data.geo.GeoResult; import org.springframework.data.geo.GeoResult;
import org.springframework.data.geo.GeoResults; import org.springframework.data.geo.GeoResults;
import org.springframework.data.geo.Point; import org.springframework.data.geo.Point;
import org.springframework.data.mongodb.core.ExecutableAggregationOperation.TerminatingAggregation;
import org.springframework.data.mongodb.core.ExecutableFindOperation; import org.springframework.data.mongodb.core.ExecutableFindOperation;
import org.springframework.data.mongodb.core.ExecutableFindOperation.FindWithQuery; import org.springframework.data.mongodb.core.ExecutableFindOperation.FindWithQuery;
import org.springframework.data.mongodb.core.ExecutableFindOperation.TerminatingFind; import org.springframework.data.mongodb.core.ExecutableFindOperation.TerminatingFind;
@ -45,12 +44,13 @@ import org.springframework.data.mongodb.core.ExecutableRemoveOperation.Executabl
import org.springframework.data.mongodb.core.ExecutableRemoveOperation.TerminatingRemove; import org.springframework.data.mongodb.core.ExecutableRemoveOperation.TerminatingRemove;
import org.springframework.data.mongodb.core.ExecutableUpdateOperation.ExecutableUpdate; import org.springframework.data.mongodb.core.ExecutableUpdateOperation.ExecutableUpdate;
import org.springframework.data.mongodb.core.MongoOperations; import org.springframework.data.mongodb.core.MongoOperations;
import org.springframework.data.mongodb.core.aggregation.AggregationOperation; import org.springframework.data.mongodb.core.aggregation.AggregationPipeline;
import org.springframework.data.mongodb.core.aggregation.AggregationResults; import org.springframework.data.mongodb.core.aggregation.AggregationResults;
import org.springframework.data.mongodb.core.aggregation.TypedAggregation; import org.springframework.data.mongodb.core.aggregation.TypedAggregation;
import org.springframework.data.mongodb.core.query.NearQuery; import org.springframework.data.mongodb.core.query.NearQuery;
import org.springframework.data.mongodb.core.query.Query; import org.springframework.data.mongodb.core.query.Query;
import org.springframework.data.mongodb.core.query.UpdateDefinition; import org.springframework.data.mongodb.core.query.UpdateDefinition;
import org.springframework.data.mongodb.repository.query.VectorSearchDelegate.QueryContainer;
import org.springframework.data.mongodb.repository.util.SliceUtils; import org.springframework.data.mongodb.repository.util.SliceUtils;
import org.springframework.data.repository.query.QueryMethod; import org.springframework.data.repository.query.QueryMethod;
import org.springframework.data.support.PageableExecutionUtils; import org.springframework.data.support.PageableExecutionUtils;
@ -186,7 +186,7 @@ public interface MongoQueryExecution {
return isListOfGeoResult(method.getReturnType()) ? results.getContent() : results; return isListOfGeoResult(method.getReturnType()) ? results.getContent() : results;
} }
@SuppressWarnings({"unchecked","NullAway"}) @SuppressWarnings({ "unchecked", "NullAway" })
GeoResults<Object> doExecuteQuery(Query query) { GeoResults<Object> doExecuteQuery(Query query) {
Point nearLocation = accessor.getGeoNearLocation(); Point nearLocation = accessor.getGeoNearLocation();
@ -225,52 +225,53 @@ public interface MongoQueryExecution {
* {@link MongoQueryExecution} to execute vector search. * {@link MongoQueryExecution} to execute vector search.
* *
* @author Mark Paluch * @author Mark Paluch
* @author Chistoph Strobl
* @since 5.0 * @since 5.0
*/ */
class VectorSearchExecution implements MongoQueryExecution { class VectorSearchExecution implements MongoQueryExecution {
private final MongoOperations operations; private final MongoOperations operations;
private final MongoQueryMethod method; private final TypeInformation<?> returnType;
private final String collectionName; private final String collectionName;
private final VectorSearchDelegate.QueryMetadata queryMetadata; private final Class<?> targetType;
private final List<AggregationOperation> pipeline; private final ScoringFunction scoringFunction;
private final AggregationPipeline pipeline;
VectorSearchExecution(MongoOperations operations, MongoQueryMethod method, String collectionName,
QueryContainer queryContainer) {
this(operations, queryContainer.outputType(), collectionName, method.getReturnType(), queryContainer.pipeline(),
queryContainer.scoringFunction());
}
public VectorSearchExecution(MongoOperations operations, MongoQueryMethod method, String collectionName, public VectorSearchExecution(MongoOperations operations, Class<?> targetType, String collectionName,
VectorSearchDelegate.QueryMetadata queryMetadata, MongoParameterAccessor accessor) { TypeInformation<?> returnType, AggregationPipeline pipeline, ScoringFunction scoringFunction) {
this.operations = operations; this.operations = operations;
this.returnType = returnType;
this.collectionName = collectionName; this.collectionName = collectionName;
this.queryMetadata = queryMetadata; this.targetType = targetType;
this.method = method; this.scoringFunction = scoringFunction;
this.pipeline = queryMetadata.getAggregationPipeline(method, accessor); this.pipeline = pipeline;
} }
@Override @Override
@SuppressWarnings({ "unchecked", "rawtypes" })
public Object execute(Query query) { public Object execute(Query query) {
AggregationResults<?> aggregated = operations.aggregate( TerminatingAggregation<?> executableAggregation = operations.aggregateAndReturn(targetType)
TypedAggregation.newAggregation(queryMetadata.outputType(), pipeline), collectionName, .inCollection(collectionName).by(TypedAggregation.newAggregation(targetType, pipeline.getOperations()));
queryMetadata.outputType());
List<?> mappedResults = aggregated.getMappedResults();
if (isSearchResult(method.getReturnType())) { if (!isSearchResult(returnType)) {
return executableAggregation.all().getMappedResults();
List<org.bson.Document> rawResults = aggregated.getRawResults().getList("results", org.bson.Document.class);
List<SearchResult<Object>> result = new ArrayList<>(mappedResults.size());
for (int i = 0; i < mappedResults.size(); i++) {
Document document = rawResults.get(i);
SearchResult<Object> searchResult = new SearchResult<>(mappedResults.get(i),
Similarity.raw(document.getDouble("__score__"), queryMetadata.scoringFunction()));
result.add(searchResult);
}
return isListOfSearchResult(method.getReturnType()) ? result : new SearchResults<>(result);
} }
return mappedResults; AggregationResults<? extends SearchResult<?>> result = executableAggregation
.map((raw, container) -> new SearchResult<>(container.get(),
Similarity.raw(raw.getDouble("__score__"), scoringFunction)))
.all();
return isListOfSearchResult(returnType) ? result.getMappedResults()
: new SearchResults(result.getMappedResults());
} }
private static boolean isListOfSearchResult(TypeInformation<?> returnType) { private static boolean isListOfSearchResult(TypeInformation<?> returnType) {

20
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ReactiveMongoQueryExecution.java

@ -18,12 +18,9 @@ package org.springframework.data.mongodb.repository.query;
import reactor.core.publisher.Flux; import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
import java.util.List;
import org.bson.Document; import org.bson.Document;
import org.jspecify.annotations.Nullable; import org.jspecify.annotations.Nullable;
import org.reactivestreams.Publisher; import org.reactivestreams.Publisher;
import org.springframework.core.convert.converter.Converter; import org.springframework.core.convert.converter.Converter;
import org.springframework.data.convert.DtoInstantiatingConverter; import org.springframework.data.convert.DtoInstantiatingConverter;
import org.springframework.data.domain.Pageable; import org.springframework.data.domain.Pageable;
@ -36,11 +33,12 @@ import org.springframework.data.geo.Point;
import org.springframework.data.mapping.model.EntityInstantiators; import org.springframework.data.mapping.model.EntityInstantiators;
import org.springframework.data.mongodb.core.ReactiveMongoOperations; import org.springframework.data.mongodb.core.ReactiveMongoOperations;
import org.springframework.data.mongodb.core.ReactiveUpdateOperation.ReactiveUpdate; import org.springframework.data.mongodb.core.ReactiveUpdateOperation.ReactiveUpdate;
import org.springframework.data.mongodb.core.aggregation.AggregationOperation; import org.springframework.data.mongodb.core.aggregation.AggregationPipeline;
import org.springframework.data.mongodb.core.aggregation.TypedAggregation; import org.springframework.data.mongodb.core.aggregation.TypedAggregation;
import org.springframework.data.mongodb.core.query.NearQuery; import org.springframework.data.mongodb.core.query.NearQuery;
import org.springframework.data.mongodb.core.query.Query; import org.springframework.data.mongodb.core.query.Query;
import org.springframework.data.mongodb.core.query.UpdateDefinition; import org.springframework.data.mongodb.core.query.UpdateDefinition;
import org.springframework.data.mongodb.repository.query.VectorSearchDelegate.QueryContainer;
import org.springframework.data.repository.query.ResultProcessor; import org.springframework.data.repository.query.ResultProcessor;
import org.springframework.data.repository.query.ReturnedType; import org.springframework.data.repository.query.ReturnedType;
import org.springframework.data.util.ReactiveWrappers; import org.springframework.data.util.ReactiveWrappers;
@ -134,24 +132,24 @@ interface ReactiveMongoQueryExecution {
class VectorSearchExecution implements ReactiveMongoQueryExecution { class VectorSearchExecution implements ReactiveMongoQueryExecution {
private final ReactiveMongoOperations operations; private final ReactiveMongoOperations operations;
private final VectorSearchDelegate.QueryMetadata queryMetadata; private final QueryContainer queryMetadata;
private final List<AggregationOperation> pipeline; private final AggregationPipeline pipeline;
private final boolean returnSearchResult; private final boolean returnSearchResult;
public VectorSearchExecution(ReactiveMongoOperations operations, MongoQueryMethod method, VectorSearchExecution(ReactiveMongoOperations operations, MongoQueryMethod method, QueryContainer queryMetadata) {
VectorSearchDelegate.QueryMetadata queryMetadata, MongoParameterAccessor accessor) {
this.operations = operations; this.operations = operations;
this.queryMetadata = queryMetadata; this.queryMetadata = queryMetadata;
this.pipeline = queryMetadata.getAggregationPipeline(method, accessor); this.pipeline = queryMetadata.pipeline();
this.returnSearchResult = isSearchResult(method.getReturnType()); this.returnSearchResult = isSearchResult(method.getReturnType());
} }
@Override @Override
public Publisher<? extends Object> execute(Query query, Class<?> type, String collection) { public Publisher<? extends Object> execute(Query query, Class<?> type, String collection) {
Flux<Document> aggregate = operations Flux<Document> aggregate = operations.aggregate(
.aggregate(TypedAggregation.newAggregation(queryMetadata.outputType(), pipeline), collection, Document.class); TypedAggregation.newAggregation(queryMetadata.outputType(), pipeline.getOperations()), collection,
Document.class);
return aggregate.map(document -> { return aggregate.map(document -> {

8
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ReactiveVectorSearchAggregation.java

@ -19,13 +19,13 @@ import reactor.core.publisher.Mono;
import org.bson.Document; import org.bson.Document;
import org.reactivestreams.Publisher; import org.reactivestreams.Publisher;
import org.springframework.data.mongodb.InvalidMongoDbApiUsageException; import org.springframework.data.mongodb.InvalidMongoDbApiUsageException;
import org.springframework.data.mongodb.core.MongoOperations; import org.springframework.data.mongodb.core.MongoOperations;
import org.springframework.data.mongodb.core.ReactiveMongoOperations; import org.springframework.data.mongodb.core.ReactiveMongoOperations;
import org.springframework.data.mongodb.core.mapping.MongoPersistentEntity; import org.springframework.data.mongodb.core.mapping.MongoPersistentEntity;
import org.springframework.data.mongodb.core.query.Query; import org.springframework.data.mongodb.core.query.Query;
import org.springframework.data.mongodb.repository.VectorSearch; import org.springframework.data.mongodb.repository.VectorSearch;
import org.springframework.data.mongodb.repository.query.VectorSearchDelegate.QueryContainer;
import org.springframework.data.mongodb.util.json.ParameterBindingContext; import org.springframework.data.mongodb.util.json.ParameterBindingContext;
import org.springframework.data.repository.query.ResultProcessor; import org.springframework.data.repository.query.ResultProcessor;
import org.springframework.data.repository.query.ValueExpressionDelegate; import org.springframework.data.repository.query.ValueExpressionDelegate;
@ -84,11 +84,11 @@ public class ReactiveVectorSearchAggregation extends AbstractReactiveMongoQuery
ParameterBindingContext bindingContext = new ParameterBindingContext(accessor::getBindableValue, ParameterBindingContext bindingContext = new ParameterBindingContext(accessor::getBindableValue,
expressionEvaluator); expressionEvaluator);
VectorSearchDelegate.QueryMetadata query = delegate.createQuery(expressionEvaluator, processor, accessor, QueryContainer query = delegate.createQuery(expressionEvaluator, processor, accessor, typeToRead, codec,
typeToRead, codec, bindingContext); bindingContext);
ReactiveMongoQueryExecution.VectorSearchExecution execution = new ReactiveMongoQueryExecution.VectorSearchExecution( ReactiveMongoQueryExecution.VectorSearchExecution execution = new ReactiveMongoQueryExecution.VectorSearchExecution(
mongoOperations, method, query, accessor); mongoOperations, method, query);
return execution.execute(query.query(), Document.class, collectionEntity.getCollection()); return execution.execute(query.query(), Document.class, collectionEntity.getCollection());
}); });

10
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/VectorSearchAggregation.java

@ -15,16 +15,17 @@
*/ */
package org.springframework.data.mongodb.repository.query; package org.springframework.data.mongodb.repository.query;
import org.jspecify.annotations.Nullable;
import org.springframework.data.mapping.model.ValueExpressionEvaluator; import org.springframework.data.mapping.model.ValueExpressionEvaluator;
import org.springframework.data.mongodb.InvalidMongoDbApiUsageException; import org.springframework.data.mongodb.InvalidMongoDbApiUsageException;
import org.springframework.data.mongodb.core.MongoOperations; import org.springframework.data.mongodb.core.MongoOperations;
import org.springframework.data.mongodb.core.mapping.MongoPersistentEntity; import org.springframework.data.mongodb.core.mapping.MongoPersistentEntity;
import org.springframework.data.mongodb.core.query.Query; import org.springframework.data.mongodb.core.query.Query;
import org.springframework.data.mongodb.repository.VectorSearch; import org.springframework.data.mongodb.repository.VectorSearch;
import org.springframework.data.mongodb.repository.query.VectorSearchDelegate.QueryContainer;
import org.springframework.data.mongodb.util.json.ParameterBindingContext; import org.springframework.data.mongodb.util.json.ParameterBindingContext;
import org.springframework.data.repository.query.ResultProcessor; import org.springframework.data.repository.query.ResultProcessor;
import org.springframework.data.repository.query.ValueExpressionDelegate; import org.springframework.data.repository.query.ValueExpressionDelegate;
import org.springframework.lang.Nullable;
/** /**
* {@link AbstractMongoQuery} implementation to run a {@link VectorSearchAggregation}. The pre-filter is either derived * {@link AbstractMongoQuery} implementation to run a {@link VectorSearchAggregation}. The pre-filter is either derived
@ -62,20 +63,19 @@ public class VectorSearchAggregation extends AbstractMongoQuery {
this.delegate = new VectorSearchDelegate(method, mongoOperations.getConverter(), delegate); this.delegate = new VectorSearchDelegate(method, mongoOperations.getConverter(), delegate);
} }
@SuppressWarnings("unchecked")
@Override @Override
protected Object doExecute(MongoQueryMethod method, ResultProcessor processor, ConvertingParameterAccessor accessor, protected Object doExecute(MongoQueryMethod method, ResultProcessor processor, ConvertingParameterAccessor accessor,
@Nullable Class<?> typeToRead) { @Nullable Class<?> typeToRead) {
VectorSearchDelegate.QueryMetadata query = createVectorSearchQuery(processor, accessor, typeToRead); QueryContainer query = createVectorSearchQuery(processor, accessor, typeToRead);
MongoQueryExecution.VectorSearchExecution execution = new MongoQueryExecution.VectorSearchExecution(mongoOperations, MongoQueryExecution.VectorSearchExecution execution = new MongoQueryExecution.VectorSearchExecution(mongoOperations,
method, collectionEntity.getCollection(), query, accessor); method, collectionEntity.getCollection(), query);
return execution.execute(query.query()); return execution.execute(query.query());
} }
VectorSearchDelegate.QueryMetadata createVectorSearchQuery(ResultProcessor processor, MongoParameterAccessor accessor, QueryContainer createVectorSearchQuery(ResultProcessor processor, MongoParameterAccessor accessor,
@Nullable Class<?> typeToRead) { @Nullable Class<?> typeToRead) {
ValueExpressionEvaluator evaluator = getExpressionEvaluatorFor(accessor); ValueExpressionEvaluator evaluator = getExpressionEvaluatorFor(accessor);

252
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/VectorSearchDelegate.java

@ -20,7 +20,6 @@ import java.util.List;
import org.bson.Document; import org.bson.Document;
import org.jspecify.annotations.Nullable; import org.jspecify.annotations.Nullable;
import org.springframework.data.domain.Limit; import org.springframework.data.domain.Limit;
import org.springframework.data.domain.Range; import org.springframework.data.domain.Range;
import org.springframework.data.domain.Score; import org.springframework.data.domain.Score;
@ -34,6 +33,7 @@ import org.springframework.data.mapping.model.ValueExpressionEvaluator;
import org.springframework.data.mongodb.InvalidMongoDbApiUsageException; import org.springframework.data.mongodb.InvalidMongoDbApiUsageException;
import org.springframework.data.mongodb.core.aggregation.Aggregation; import org.springframework.data.mongodb.core.aggregation.Aggregation;
import org.springframework.data.mongodb.core.aggregation.AggregationOperation; import org.springframework.data.mongodb.core.aggregation.AggregationOperation;
import org.springframework.data.mongodb.core.aggregation.AggregationPipeline;
import org.springframework.data.mongodb.core.aggregation.VectorSearchOperation; import org.springframework.data.mongodb.core.aggregation.VectorSearchOperation;
import org.springframework.data.mongodb.core.convert.MongoConverter; import org.springframework.data.mongodb.core.convert.MongoConverter;
import org.springframework.data.mongodb.core.mapping.MongoPersistentEntity; import org.springframework.data.mongodb.core.mapping.MongoPersistentEntity;
@ -47,6 +47,7 @@ import org.springframework.data.repository.query.ResultProcessor;
import org.springframework.data.repository.query.ValueExpressionDelegate; import org.springframework.data.repository.query.ValueExpressionDelegate;
import org.springframework.data.repository.query.parser.Part; import org.springframework.data.repository.query.parser.Part;
import org.springframework.data.repository.query.parser.PartTree; import org.springframework.data.repository.query.parser.PartTree;
import org.springframework.util.NumberUtils;
import org.springframework.util.StringUtils; import org.springframework.util.StringUtils;
/** /**
@ -58,32 +59,35 @@ class VectorSearchDelegate {
private final VectorSearchQueryFactory queryFactory; private final VectorSearchQueryFactory queryFactory;
private final VectorSearchOperation.SearchType searchType; private final VectorSearchOperation.SearchType searchType;
private final String indexName;
private final @Nullable Integer numCandidates; private final @Nullable Integer numCandidates;
private final @Nullable String numCandidatesExpression; private final @Nullable String numCandidatesExpression;
private final Limit limit; private final Limit limit;
private final @Nullable String limitExpression; private final @Nullable String limitExpression;
private final MongoConverter converter; private final MongoConverter converter;
public VectorSearchDelegate(MongoQueryMethod method, MongoConverter converter, ValueExpressionDelegate delegate) { VectorSearchDelegate(MongoQueryMethod method, MongoConverter converter, ValueExpressionDelegate delegate) {
VectorSearch vectorSearch = method.findAnnotatedVectorSearch().orElseThrow(); VectorSearch vectorSearch = method.findAnnotatedVectorSearch().orElseThrow();
this.searchType = vectorSearch.searchType(); this.searchType = vectorSearch.searchType();
this.indexName = method.getAnnotatedHint();
if (StringUtils.hasText(vectorSearch.numCandidates())) { if (StringUtils.hasText(vectorSearch.numCandidates())) {
ValueExpression expression = delegate.getValueExpressionParser().parse(vectorSearch.numCandidates()); ValueExpression expression = delegate.getValueExpressionParser().parse(vectorSearch.numCandidates());
if (expression.isLiteral()) { if (expression.isLiteral()) {
numCandidates = Integer.parseInt(vectorSearch.numCandidates()); this.numCandidates = Integer.parseInt(vectorSearch.numCandidates());
numCandidatesExpression = null; this.numCandidatesExpression = null;
} else { } else {
numCandidates = null; this.numCandidates = null;
numCandidatesExpression = vectorSearch.numCandidates(); this.numCandidatesExpression = vectorSearch.numCandidates();
} }
} else { } else {
numCandidates = null; this.numCandidates = null;
numCandidatesExpression = null; this.numCandidatesExpression = null;
} }
if (StringUtils.hasText(vectorSearch.limit())) { if (StringUtils.hasText(vectorSearch.limit())) {
@ -91,26 +95,26 @@ class VectorSearchDelegate {
ValueExpression expression = delegate.getValueExpressionParser().parse(vectorSearch.limit()); ValueExpression expression = delegate.getValueExpressionParser().parse(vectorSearch.limit());
if (expression.isLiteral()) { if (expression.isLiteral()) {
limit = Limit.of(Integer.parseInt(vectorSearch.limit())); this.limit = Limit.of(Integer.parseInt(vectorSearch.limit()));
limitExpression = null; this.limitExpression = null;
} else { } else {
limit = Limit.unlimited(); this.limit = Limit.unlimited();
limitExpression = vectorSearch.limit(); this.limitExpression = vectorSearch.limit();
} }
} else { } else {
limit = Limit.unlimited(); this.limit = Limit.unlimited();
limitExpression = null; this.limitExpression = null;
} }
this.converter = converter; this.converter = converter;
if (StringUtils.hasText(vectorSearch.filter())) { if (StringUtils.hasText(vectorSearch.filter())) {
queryFactory = StringUtils.hasText(vectorSearch.path()) this.queryFactory = StringUtils.hasText(vectorSearch.path())
? new AnnotatedQueryFactory(vectorSearch.filter(), vectorSearch.path()) ? new AnnotatedQueryFactory(vectorSearch.filter(), vectorSearch.path())
: new AnnotatedQueryFactory(vectorSearch.filter(), method.getEntityInformation().getCollectionEntity()); : new AnnotatedQueryFactory(vectorSearch.filter(), method.getEntityInformation().getCollectionEntity());
} else { } else {
queryFactory = new PartTreeQueryFactory( this.queryFactory = new PartTreeQueryFactory(
new PartTree(method.getName(), method.getResultProcessor().getReturnedType().getDomainType()), new PartTree(method.getName(), method.getResultProcessor().getReturnedType().getDomainType()),
converter.getMappingContext()); converter.getMappingContext());
} }
@ -119,43 +123,136 @@ class VectorSearchDelegate {
/** /**
* Create Query Metadata for {@code $vectorSearch}. * Create Query Metadata for {@code $vectorSearch}.
*/ */
public QueryMetadata createQuery(ValueExpressionEvaluator evaluator, ResultProcessor processor, QueryContainer createQuery(ValueExpressionEvaluator evaluator, ResultProcessor processor,
MongoParameterAccessor accessor, @Nullable Class<?> typeToRead, ParameterBindingDocumentCodec codec, MongoParameterAccessor accessor, @Nullable Class<?> typeToRead, ParameterBindingDocumentCodec codec,
ParameterBindingContext context) { ParameterBindingContext context) {
Integer numCandidates = null; String scoreField = "__score__";
Limit limit;
Class<?> outputType = typeToRead != null ? typeToRead : processor.getReturnedType().getReturnedType(); Class<?> outputType = typeToRead != null ? typeToRead : processor.getReturnedType().getReturnedType();
VectorSearchInput query = queryFactory.createQuery(accessor, codec, context); VectorSearchInput vectorSearchInput = createSearchInput(evaluator, accessor, codec, context);
AggregationPipeline pipeline = createVectorSearchPipeline(vectorSearchInput, scoreField, outputType, accessor,
evaluator);
if (this.limitExpression != null) { return new QueryContainer(vectorSearchInput.path, scoreField, vectorSearchInput.query, pipeline, searchType,
Object value = evaluator.evaluate(this.limitExpression); outputType, getSimilarityFunction(accessor), indexName);
limit = value instanceof Limit l ? l : Limit.of(((Number) value).intValue()); }
} else if (this.limit.isLimited()) {
limit = this.limit;
} else {
limit = accessor.getLimit();
}
if (limit.isLimited()) { @SuppressWarnings("NullAway")
query.query().limit(limit); AggregationPipeline createVectorSearchPipeline(VectorSearchInput input, String scoreField, Class<?> outputType,
} MongoParameterAccessor accessor, ValueExpressionEvaluator evaluator) {
Vector vector = accessor.getVector();
Score score = accessor.getScore();
Range<Score> distance = accessor.getScoreRange();
Limit limit = Limit.of(input.query().getLimit());
List<AggregationOperation> stages = new ArrayList<>();
VectorSearchOperation $vectorSearch = Aggregation.vectorSearch(indexName).path(input.path()).vector(vector)
.limit(limit);
Integer candidates = null;
if (this.numCandidatesExpression != null) { if (this.numCandidatesExpression != null) {
numCandidates = ((Number) evaluator.evaluate(this.numCandidatesExpression)).intValue(); candidates = ((Number) evaluator.evaluate(this.numCandidatesExpression)).intValue();
} else if (this.numCandidates != null) { } else if (this.numCandidates != null) {
numCandidates = this.numCandidates; candidates = this.numCandidates;
} else if (query.query().isLimited() && (searchType == VectorSearchOperation.SearchType.ANN } else if (input.query().isLimited() && (searchType == VectorSearchOperation.SearchType.ANN
|| searchType == VectorSearchOperation.SearchType.DEFAULT)) { || searchType == VectorSearchOperation.SearchType.DEFAULT)) {
/* /*
MongoDB: We recommend that you specify a number at least 20 times higher than the number of documents to return (limit) to increase accuracy. MongoDB: We recommend that you specify a number at least 20 times higher than the number of documents to return (limit) to increase accuracy.
*/ */
numCandidates = query.query().getLimit() * 20; candidates = input.query().getLimit() * 20;
} }
return new QueryMetadata(query.path, "__score__", query.query, searchType, outputType, numCandidates, if (candidates != null) {
getSimilarityFunction(accessor)); $vectorSearch = $vectorSearch.numCandidates(candidates);
}
//
$vectorSearch = $vectorSearch.filter(input.query.getQueryObject());
$vectorSearch = $vectorSearch.searchType(this.searchType);
$vectorSearch = $vectorSearch.withSearchScore(scoreField);
if (score != null) {
$vectorSearch = $vectorSearch.withFilterBySore(c -> {
c.gt(score.getValue());
});
} else if (distance.getLowerBound().isBounded() || distance.getUpperBound().isBounded()) {
$vectorSearch = $vectorSearch.withFilterBySore(c -> {
Range.Bound<Score> lower = distance.getLowerBound();
if (lower.isBounded()) {
double value = lower.getValue().get().getValue();
if (lower.isInclusive()) {
c.gte(value);
} else {
c.gt(value);
}
}
Range.Bound<Score> upper = distance.getUpperBound();
if (upper.isBounded()) {
double value = upper.getValue().get().getValue();
if (upper.isInclusive()) {
c.lte(value);
} else {
c.lt(value);
}
}
});
}
stages.add($vectorSearch);
if (input.query().isSorted()) {
stages.add(ctx -> {
Document mappedSort = ctx.getMappedObject(input.query().getSortObject(), outputType);
mappedSort.append(scoreField, -1);
return ctx.getMappedObject(new Document("$sort", mappedSort));
});
} else {
stages.add(Aggregation.sort(Sort.Direction.DESC, scoreField));
}
return new AggregationPipeline(stages);
}
private VectorSearchInput createSearchInput(ValueExpressionEvaluator evaluator, MongoParameterAccessor accessor,
ParameterBindingDocumentCodec codec, ParameterBindingContext context) {
VectorSearchInput input = queryFactory.createQuery(accessor, codec, context);
Limit limit = getLimit(evaluator, accessor);
if(!input.query.isLimited() || (input.query.isLimited() && !limit.isUnlimited())) {
input.query().limit(limit);
}
return input;
}
private Limit getLimit(ValueExpressionEvaluator evaluator, MongoParameterAccessor accessor) {
if (this.limitExpression != null) {
Object value = evaluator.evaluate(this.limitExpression);
if (value != null) {
if (value instanceof Limit l) {
return l;
}
if (value instanceof Number n) {
return Limit.of(n.intValue());
}
if (value instanceof String s) {
return Limit.of(NumberUtils.parseNumber(s, Integer.class));
}
throw new IllegalArgumentException("Invalid type for Limit. Found [%s], expected Limit or Number");
}
}
if (this.limit.isLimited()) {
return this.limit;
}
return accessor.getLimit();
} }
public String getQueryString() { public String getQueryString() {
@ -192,82 +289,10 @@ class VectorSearchDelegate {
* @param query * @param query
* @param searchType * @param searchType
* @param outputType * @param outputType
* @param numCandidates
* @param scoringFunction * @param scoringFunction
*/ */
public record QueryMetadata(String path, String scoreField, Query query, VectorSearchOperation.SearchType searchType, record QueryContainer(String path, String scoreField, Query query, AggregationPipeline pipeline,
Class<?> outputType, @Nullable Integer numCandidates, ScoringFunction scoringFunction) { VectorSearchOperation.SearchType searchType, Class<?> outputType, ScoringFunction scoringFunction, String index) {
/**
* Create the Aggregation Pipeline.
*
* @param queryMethod
* @param accessor
* @return
*/
public List<AggregationOperation> getAggregationPipeline(MongoQueryMethod queryMethod,
MongoParameterAccessor accessor) {
Vector vector = accessor.getVector();
Score score = accessor.getScore();
Range<Score> distance = accessor.getScoreRange();
Limit limit = Limit.unlimited();
if (query.isLimited()) {
limit = Limit.of(query.getLimit());
}
List<AggregationOperation> stages = new ArrayList<>();
VectorSearchOperation $vectorSearch = Aggregation.vectorSearch(queryMethod.getAnnotatedHint()).path(path())
.vector(vector).limit(limit);
if (numCandidates() != null) {
$vectorSearch = $vectorSearch.numCandidates(numCandidates());
}
$vectorSearch = $vectorSearch.filter(query.getQueryObject());
$vectorSearch = $vectorSearch.searchType(searchType());
$vectorSearch = $vectorSearch.withSearchScore(scoreField());
if (score != null) {
$vectorSearch = $vectorSearch.withFilterBySore(c -> {
c.gt(score.getValue());
});
} else if (distance.getLowerBound().isBounded() || distance.getUpperBound().isBounded()) {
$vectorSearch = $vectorSearch.withFilterBySore(c -> {
Range.Bound<Score> lower = distance.getLowerBound();
if (lower.isBounded()) {
double value = lower.getValue().get().getValue();
if (lower.isInclusive()) {
c.gte(value);
} else {
c.gt(value);
}
}
Range.Bound<Score> upper = distance.getUpperBound();
if (upper.isBounded()) {
double value = upper.getValue().get().getValue();
if (upper.isInclusive()) {
c.lte(value);
} else {
c.lt(value);
}
}
});
}
stages.add($vectorSearch);
if (query.isSorted()) {
// TODO stages.add(Aggregation.sort(query.with()));
} else {
stages.add(Aggregation.sort(Sort.Direction.DESC, "__score__"));
}
return stages;
}
} }
@ -368,11 +393,12 @@ class VectorSearchDelegate {
this.tree = tree; this.tree = tree;
} }
@SuppressWarnings("NullAway")
public VectorSearchInput createQuery(MongoParameterAccessor parameterAccessor, ParameterBindingDocumentCodec codec, public VectorSearchInput createQuery(MongoParameterAccessor parameterAccessor, ParameterBindingDocumentCodec codec,
ParameterBindingContext context) { ParameterBindingContext context) {
MongoQueryCreator creator = new MongoQueryCreator(tree, parameterAccessor, converter.getMappingContext(), MongoQueryCreator creator = new MongoQueryCreator(tree, parameterAccessor, converter.getMappingContext(), false,
false, true); true);
Query query = creator.createQuery(parameterAccessor.getSort()); Query query = creator.createQuery(parameterAccessor.getSort());

3
spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/VectorSearchTests.java

@ -81,16 +81,15 @@ public class VectorSearchTests {
@Override @Override
public MongoClient mongoClient() { public MongoClient mongoClient() {
atlasLocal.start();
return MongoClients.create(atlasLocal.getConnectionString()); return MongoClients.create(atlasLocal.getConnectionString());
} }
} }
@BeforeAll @BeforeAll
static void beforeAll() throws InterruptedException { static void beforeAll() throws InterruptedException {
atlasLocal.start(); atlasLocal.start();
System.out.println(atlasLocal.getConnectionString());
client = MongoClients.create(atlasLocal.getConnectionString()); client = MongoClients.create(atlasLocal.getConnectionString());
template = new MongoTestTemplate(client, "vector-search-tests"); template = new MongoTestTemplate(client, "vector-search-tests");

3
spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/VectorSearchAggregationUnitTests.java

@ -34,6 +34,7 @@ import org.springframework.data.mongodb.core.convert.MappingMongoConverter;
import org.springframework.data.mongodb.core.convert.NoOpDbRefResolver; import org.springframework.data.mongodb.core.convert.NoOpDbRefResolver;
import org.springframework.data.mongodb.core.mapping.MongoMappingContext; import org.springframework.data.mongodb.core.mapping.MongoMappingContext;
import org.springframework.data.mongodb.repository.VectorSearch; import org.springframework.data.mongodb.repository.VectorSearch;
import org.springframework.data.mongodb.repository.query.VectorSearchDelegate.QueryContainer;
import org.springframework.data.projection.ProjectionFactory; import org.springframework.data.projection.ProjectionFactory;
import org.springframework.data.projection.SpelAwareProxyProjectionFactory; import org.springframework.data.projection.SpelAwareProxyProjectionFactory;
import org.springframework.data.repository.CrudRepository; import org.springframework.data.repository.CrudRepository;
@ -68,7 +69,7 @@ class VectorSearchAggregationUnitTests {
VectorSearchAggregation aggregation = aggregation(SampleRepository.class, "searchByCountryAndEmbeddingNear", VectorSearchAggregation aggregation = aggregation(SampleRepository.class, "searchByCountryAndEmbeddingNear",
String.class, Vector.class, Score.class, Limit.class); String.class, Vector.class, Score.class, Limit.class);
VectorSearchDelegate.QueryMetadata query = aggregation.createVectorSearchQuery( QueryContainer query = aggregation.createVectorSearchQuery(
aggregation.getQueryMethod().getResultProcessor(), aggregation.getQueryMethod().getResultProcessor(),
new MongoParametersParameterAccessor(aggregation.getQueryMethod(), new MongoParametersParameterAccessor(aggregation.getQueryMethod(),
new Object[] { "de", Vector.of(1f), Score.of(1), Limit.unlimited() }), new Object[] { "de", Vector.of(1f), Score.of(1), Limit.unlimited() }),

156
spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/VectorSearchDelegateUnitTests.java

@ -15,23 +15,30 @@
*/ */
package org.springframework.data.mongodb.repository.query; package org.springframework.data.mongodb.repository.query;
import static org.assertj.core.api.Assertions.*; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.*; import static org.springframework.data.mongodb.test.util.Assertions.assertThat;
import java.lang.reflect.Method; import java.lang.reflect.Method;
import java.util.List;
import org.bson.Document;
import org.jspecify.annotations.Nullable;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.springframework.data.domain.Limit; import org.springframework.data.domain.Limit;
import org.springframework.data.domain.Score; import org.springframework.data.domain.Score;
import org.springframework.data.domain.SearchResults; import org.springframework.data.domain.SearchResults;
import org.springframework.data.domain.Vector; import org.springframework.data.domain.Vector;
import org.springframework.data.mapping.model.ValueExpressionEvaluator; import org.springframework.data.mapping.model.ValueExpressionEvaluator;
import org.springframework.data.mongodb.core.aggregation.Aggregation;
import org.springframework.data.mongodb.core.aggregation.AggregationPipeline;
import org.springframework.data.mongodb.core.aggregation.VectorSearchOperation; import org.springframework.data.mongodb.core.aggregation.VectorSearchOperation;
import org.springframework.data.mongodb.core.convert.MappingMongoConverter; import org.springframework.data.mongodb.core.convert.MappingMongoConverter;
import org.springframework.data.mongodb.core.convert.NoOpDbRefResolver; import org.springframework.data.mongodb.core.convert.NoOpDbRefResolver;
import org.springframework.data.mongodb.core.mapping.Field;
import org.springframework.data.mongodb.core.mapping.MongoMappingContext; import org.springframework.data.mongodb.core.mapping.MongoMappingContext;
import org.springframework.data.mongodb.repository.VectorSearch; import org.springframework.data.mongodb.repository.VectorSearch;
import org.springframework.data.mongodb.repository.query.VectorSearchDelegate.QueryContainer;
import org.springframework.data.mongodb.util.aggregation.TestAggregationContext;
import org.springframework.data.mongodb.util.json.ParameterBindingContext; import org.springframework.data.mongodb.util.json.ParameterBindingContext;
import org.springframework.data.mongodb.util.json.ParameterBindingDocumentCodec; import org.springframework.data.mongodb.util.json.ParameterBindingDocumentCodec;
import org.springframework.data.projection.SpelAwareProxyProjectionFactory; import org.springframework.data.projection.SpelAwareProxyProjectionFactory;
@ -44,6 +51,7 @@ import org.springframework.data.repository.query.ValueExpressionDelegate;
* Unit tests for {@link VectorSearchDelegate}. * Unit tests for {@link VectorSearchDelegate}.
* *
* @author Mark Paluch * @author Mark Paluch
* @author Christoph Strobl
*/ */
class VectorSearchDelegateUnitTests { class VectorSearchDelegateUnitTests {
@ -57,10 +65,10 @@ class VectorSearchDelegateUnitTests {
MongoQueryMethod queryMethod = getMongoQueryMethod(method); MongoQueryMethod queryMethod = getMongoQueryMethod(method);
MongoParametersParameterAccessor accessor = getAccessor(queryMethod, Vector.of(1, 2), Score.of(1)); MongoParametersParameterAccessor accessor = getAccessor(queryMethod, Vector.of(1, 2), Score.of(1));
VectorSearchDelegate.QueryMetadata query = createQueryMetadata(queryMethod, accessor); QueryContainer container = createQueryContainer(queryMethod, accessor);
assertThat(query.query().getLimit()).isEqualTo(10); assertThat(container.query().getLimit()).isEqualTo(10);
assertThat(query.numCandidates()).isEqualTo(10 * 20); assertThat(numCandidates(container.pipeline())).isEqualTo(10 * 20);
} }
@Test @Test
@ -71,10 +79,10 @@ class VectorSearchDelegateUnitTests {
MongoQueryMethod queryMethod = getMongoQueryMethod(method); MongoQueryMethod queryMethod = getMongoQueryMethod(method);
MongoParametersParameterAccessor accessor = getAccessor(queryMethod, Vector.of(1, 2), Score.of(1)); MongoParametersParameterAccessor accessor = getAccessor(queryMethod, Vector.of(1, 2), Score.of(1));
VectorSearchDelegate.QueryMetadata query = createQueryMetadata(queryMethod, accessor); QueryContainer container = createQueryContainer(queryMethod, accessor);
assertThat(query.query().getLimit()).isEqualTo(10); assertThat(container.query().getLimit()).isEqualTo(10);
assertThat(query.numCandidates()).isNull(); assertThat(numCandidates(container.pipeline())).isNull();
} }
@Test @Test
@ -86,19 +94,87 @@ class VectorSearchDelegateUnitTests {
MongoQueryMethod queryMethod = getMongoQueryMethod(method); MongoQueryMethod queryMethod = getMongoQueryMethod(method);
MongoParametersParameterAccessor accessor = getAccessor(queryMethod, Vector.of(1, 2), Score.of(1), Limit.of(11)); MongoParametersParameterAccessor accessor = getAccessor(queryMethod, Vector.of(1, 2), Score.of(1), Limit.of(11));
VectorSearchDelegate.QueryMetadata query = createQueryMetadata(queryMethod, accessor); QueryContainer container = createQueryContainer(queryMethod, accessor);
assertThat(query.query().getLimit()).isEqualTo(11); assertThat(container.query().getLimit()).isEqualTo(11);
assertThat(query.numCandidates()).isEqualTo(11 * 20); assertThat(numCandidates(container.pipeline())).isEqualTo(11 * 20);
} }
private VectorSearchDelegate.QueryMetadata createQueryMetadata(MongoQueryMethod queryMethod, @Test
MongoParametersParameterAccessor accessor) { void considersDerivedQueryPart() throws ReflectiveOperationException {
Method method = VectorSearchRepository.class.getMethod("searchTop10ByFirstNameAndEmbeddingNear", String.class,
Vector.class, Score.class);
MongoQueryMethod queryMethod = getMongoQueryMethod(method);
MongoParametersParameterAccessor accessor = getAccessor(queryMethod, "spring", Vector.of(1, 2), Score.of(1));
QueryContainer container = createQueryContainer(queryMethod, accessor);
assertThat(vectorSearchStageOf(container.pipeline())).containsEntry("$vectorSearch.filter",
new Document("first_name", "spring"));
}
@Test
void considersDerivedQueryPartInDifferentOrder() throws ReflectiveOperationException {
Method method = VectorSearchRepository.class.getMethod("searchTop10ByEmbeddingNearAndFirstName", Vector.class,
Score.class, String.class);
MongoQueryMethod queryMethod = getMongoQueryMethod(method);
MongoParametersParameterAccessor accessor = getAccessor(queryMethod, Vector.of(1, 2), Score.of(1), "spring");
QueryContainer container = createQueryContainer(queryMethod, accessor);
assertThat(vectorSearchStageOf(container.pipeline())).containsEntry("$vectorSearch.filter",
new Document("first_name", "spring"));
}
@Test
void defaultSortsByScore() throws NoSuchMethodException {
Method method = VectorSearchRepository.class.getMethod("searchTop10ByEmbeddingNear", Vector.class, Score.class,
Limit.class);
MongoQueryMethod queryMethod = getMongoQueryMethod(method);
MongoParametersParameterAccessor accessor = getAccessor(queryMethod, Vector.of(1, 2), Score.of(1), Limit.of(10));
QueryContainer container = createQueryContainer(queryMethod, accessor);
List<Document> stages = container.pipeline().lastOperation()
.toPipelineStages(TestAggregationContext.contextFor(WithVector.class));
assertThat(stages).containsExactly(new Document("$sort", new Document("__score__", -1)));
}
@Test
void usesDerivedSort() throws NoSuchMethodException {
Method method = VectorSearchRepository.class.getMethod("searchByEmbeddingNearOrderByFirstName", Vector.class,
Score.class, Limit.class);
MongoQueryMethod queryMethod = getMongoQueryMethod(method);
MongoParametersParameterAccessor accessor = getAccessor(queryMethod, Vector.of(1, 2), Score.of(1), Limit.of(11));
QueryContainer container = createQueryContainer(queryMethod, accessor);
AggregationPipeline aggregationPipeline = container.pipeline();
List<Document> stages = aggregationPipeline.lastOperation()
.toPipelineStages(TestAggregationContext.contextFor(WithVector.class));
assertThat(stages).containsExactly(new Document("$sort", new Document("first_name", 1).append("__score__", -1)));
}
Document vectorSearchStageOf(AggregationPipeline pipeline) {
return pipeline.firstOperation().toPipelineStages(TestAggregationContext.contextFor(WithVector.class)).get(0);
}
private QueryContainer createQueryContainer(MongoQueryMethod queryMethod, MongoParametersParameterAccessor accessor) {
VectorSearchDelegate delegate = new VectorSearchDelegate(queryMethod, converter, ValueExpressionDelegate.create()); VectorSearchDelegate delegate = new VectorSearchDelegate(queryMethod, converter, ValueExpressionDelegate.create());
return delegate.createQuery(mock(ValueExpressionEvaluator.class), queryMethod.getResultProcessor(), accessor, return delegate.createQuery(mock(ValueExpressionEvaluator.class), queryMethod.getResultProcessor(), accessor, null,
Object.class, new ParameterBindingDocumentCodec(), mock(ParameterBindingContext.class)); new ParameterBindingDocumentCodec(), mock(ParameterBindingContext.class));
} }
private MongoQueryMethod getMongoQueryMethod(Method method) { private MongoQueryMethod getMongoQueryMethod(Method method) {
@ -110,21 +186,69 @@ class VectorSearchDelegateUnitTests {
return new MongoParametersParameterAccessor(queryMethod, values); return new MongoParametersParameterAccessor(queryMethod, values);
} }
@Nullable
private static Integer numCandidates(AggregationPipeline pipeline) {
Document $vectorSearch = pipeline.firstOperation().toPipelineStages(Aggregation.DEFAULT_CONTEXT).get(0);
if ($vectorSearch.containsKey("$vectorSearch")) {
Object value = $vectorSearch.get("$vectorSearch", Document.class).get("numCandidates");
return value instanceof Number i ? i.intValue() : null;
}
return null;
}
interface VectorSearchRepository extends Repository<WithVector, String> { interface VectorSearchRepository extends Repository<WithVector, String> {
@VectorSearch(indexName = "cos-index", searchType = VectorSearchOperation.SearchType.ANN) @VectorSearch(indexName = "cos-index", searchType = VectorSearchOperation.SearchType.ANN)
SearchResults<WithVector> searchTop10ByEmbeddingNear(Vector vector, Score similarity); SearchResults<WithVector> searchTop10ByEmbeddingNear(Vector vector, Score similarity);
@VectorSearch(indexName = "cos-index", searchType = VectorSearchOperation.SearchType.ANN)
SearchResults<WithVector> searchTop10ByFirstNameAndEmbeddingNear(String firstName, Vector vector, Score similarity);
@VectorSearch(indexName = "cos-index", searchType = VectorSearchOperation.SearchType.ANN)
SearchResults<WithVector> searchTop10ByEmbeddingNearAndFirstName(Vector vector, Score similarity, String firstname);
@VectorSearch(indexName = "cos-index", searchType = VectorSearchOperation.SearchType.ENN) @VectorSearch(indexName = "cos-index", searchType = VectorSearchOperation.SearchType.ENN)
SearchResults<WithVector> searchTop10EnnByEmbeddingNear(Vector vector, Score similarity); SearchResults<WithVector> searchTop10EnnByEmbeddingNear(Vector vector, Score similarity);
@VectorSearch(indexName = "cos-index", searchType = VectorSearchOperation.SearchType.ANN) @VectorSearch(indexName = "cos-index", searchType = VectorSearchOperation.SearchType.ANN)
SearchResults<WithVector> searchTop10ByEmbeddingNear(Vector vector, Score similarity, Limit limit); SearchResults<WithVector> searchTop10ByEmbeddingNear(Vector vector, Score similarity, Limit limit);
@VectorSearch(indexName = "cos-index", searchType = VectorSearchOperation.SearchType.ANN)
SearchResults<WithVector> searchByEmbeddingNearOrderByFirstName(Vector vector, Score similarity, Limit limit);
} }
static class WithVector { static class WithVector {
Vector embedding; Vector embedding;
String lastName;
@Field("first_name") String firstName;
public Vector getEmbedding() {
return embedding;
}
public void setEmbedding(Vector embedding) {
this.embedding = embedding;
}
public String getLastName() {
return lastName;
}
public void setLastName(String lastName) {
this.lastName = lastName;
}
public String getFirstName() {
return firstName;
}
public void setFirstName(String firstName) {
this.firstName = firstName;
}
} }
} }

2
src/main/antora/modules/ROOT/pages/mongodb/mongo-search-indexes.adoc

@ -25,7 +25,7 @@ Java::
[source,java,indent=0,subs="verbatim,quotes",role="primary"] [source,java,indent=0,subs="verbatim,quotes",role="primary"]
---- ----
VectorIndex index = new VectorIndex("vector_index") VectorIndex index = new VectorIndex("vector_index")
.addVector("plotEmbedding"), vector -> vector.dimensions(1536).similarity(COSINE)) <1> .addVector("plotEmbedding", vector -> vector.dimensions(1536).similarity(COSINE)) <1>
.addFilter("year"); <2> .addFilter("year"); <2>
mongoTemplate.searchIndexOps(Movie.class) <3> mongoTemplate.searchIndexOps(Movie.class) <3>

8
src/main/antora/modules/ROOT/partials/vector-search-method-annotated-include.adoc

@ -6,13 +6,13 @@ Annotated search methods use the `@VectorSearch` annotation to define parameters
---- ----
interface CommentRepository extends Repository<Comment, String> { interface CommentRepository extends Repository<Comment, String> {
@VectorSearch(indexName = "cos-index", filter = "{country: ?0}") @VectorSearch(indexName = "cos-index", filter = "{country: ?0}", limit="100", numCandidates="2000")
SearchResults<WithVector> searchAnnotatedByCountryAndEmbeddingWithin(String country, Vector embedding, SearchResults<Comment> searchAnnotatedByCountryAndEmbeddingWithin(String country, Vector embedding,
Score distance); Score distance);
@VectorSearch(indexName = "my-index", filter = "{country: ?0}", numCandidates = "#{#limit * 20}", @VectorSearch(indexName = "my-index", filter = "{country: ?0}", limit="?3", numCandidates = "#{#limit * 20}",
searchType = VectorSearchOperation.SearchType.ANN) searchType = VectorSearchOperation.SearchType.ANN)
List<WithVector> findAnnotatedByCountryAndEmbeddingWithin(String country, Vector embedding, Score distance, int limit); List<Comment> findAnnotatedByCountryAndEmbeddingWithin(String country, Vector embedding, Score distance, int limit);
} }
---- ----
==== ====

12
src/main/antora/modules/ROOT/partials/vector-search-method-derived-include.adoc

@ -6,14 +6,14 @@ MongoDB Search methods must use the `@VectorSearch` annotation to define the ind
---- ----
interface CommentRepository extends Repository<Comment, String> { interface CommentRepository extends Repository<Comment, String> {
@VectorSearch(indexName = "my-index") @VectorSearch(indexName = "my-index", numCandidates="200")
SearchResults<Comment> searchByEmbeddingNear(Vector vector, Score score); SearchResults<Comment> searchTop10ByEmbeddingNear(Vector vector, Score score);
@VectorSearch(indexName = "my-index") @VectorSearch(indexName = "my-index", numCandidates="200")
SearchResults<Comment> searchByEmbeddingWithin(Vector vector, Range<Similarity> range); SearchResults<Comment> searchTop10ByEmbeddingWithin(Vector vector, Range<Similarity> range);
@VectorSearch(indexName = "my-index") @VectorSearch(indexName = "my-index", numCandidates="200")
SearchResults<Comment> searchByCountryAndEmbeddingWithin(String country, Vector vector, Range<Similarity> range); SearchResults<Comment> searchTop10ByCountryAndEmbeddingWithin(String country, Vector vector, Range<Similarity> range);
} }
---- ----
==== ====

12
src/main/antora/modules/ROOT/partials/vector-search-repository-include.adoc

@ -4,12 +4,12 @@
---- ----
interface CommentRepository extends Repository<Comment, String> { interface CommentRepository extends Repository<Comment, String> {
@VectorSearch(indexName = "my-index") @VectorSearch(indexName = "my-index", numCandidates="#{#limit.max() * 20}")
SearchResults<Comment> searchByCountryAndEmbeddingNear(String country, Vector vector, Score score, SearchResults<Comment> searchByCountryAndEmbeddingNear(String country, Vector vector, Score score,
Limit limit); Limit limit);
@VectorSearch(indexName = "my-index") @VectorSearch(indexName = "my-index", limit="10", numCandidates="200")
SearchResults<WithVector> searchAnnotatedByCountryAndEmbeddingWithin(String country, Vector embedding, SearchResults<Comment> searchByCountryAndEmbeddingWithin(String country, Vector embedding,
Score score); Score score);
} }
@ -17,3 +17,9 @@ interface CommentRepository extends Repository<Comment, String> {
SearchResults<Comment> results = repository.searchByCountryAndEmbeddingNear("en", Vector.of(…), Score.of(0.9), Limit.of(10)); SearchResults<Comment> results = repository.searchByCountryAndEmbeddingNear("en", Vector.of(…), Score.of(0.9), Limit.of(10));
---- ----
==== ====
[TIP]
====
The MongoDB https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/[vector search aggregation] stage defines a set of required arguments and restrictions.
Please make sure to follow the guidelines and make sure to provide required arguments like `limit`.
====

6
src/main/antora/modules/ROOT/partials/vector-search-scoring-include.adoc

@ -9,13 +9,13 @@ The scoring function defaults to `ScoringFunction.unspecified()` as there is no
interface CommentRepository extends Repository<Comment, String> { interface CommentRepository extends Repository<Comment, String> {
@VectorSearch(…) @VectorSearch(…)
SearchResults<Comment> searchByEmbeddingNear(Vector vector, Score similarity); SearchResults<Comment> searchTop10ByEmbeddingNear(Vector vector, Score similarity);
@VectorSearch(…) @VectorSearch(…)
SearchResults<Comment> searchByEmbeddingNear(Vector vector, Similarity similarity); SearchResults<Comment> searchTop10ByEmbeddingNear(Vector vector, Similarity similarity);
@VectorSearch(…) @VectorSearch(…)
SearchResults<Comment> searchByEmbeddingNear(Vector vector, Range<Similarity> range); SearchResults<Comment> searchTop10ByEmbeddingNear(Vector vector, Range<Similarity> range);
} }
repository.searchByEmbeddingNear(Vector.of(…), Score.of(0.9)); <1> repository.searchByEmbeddingNear(Vector.of(…), Score.of(0.9)); <1>

Loading…
Cancel
Save