@ -20,7 +20,6 @@ import java.util.List;
import org.bson.Document ;
import org.bson.Document ;
import org.jspecify.annotations.Nullable ;
import org.jspecify.annotations.Nullable ;
import org.springframework.data.domain.Limit ;
import org.springframework.data.domain.Limit ;
import org.springframework.data.domain.Range ;
import org.springframework.data.domain.Range ;
import org.springframework.data.domain.Score ;
import org.springframework.data.domain.Score ;
@ -34,6 +33,7 @@ import org.springframework.data.mapping.model.ValueExpressionEvaluator;
import org.springframework.data.mongodb.InvalidMongoDbApiUsageException ;
import org.springframework.data.mongodb.InvalidMongoDbApiUsageException ;
import org.springframework.data.mongodb.core.aggregation.Aggregation ;
import org.springframework.data.mongodb.core.aggregation.Aggregation ;
import org.springframework.data.mongodb.core.aggregation.AggregationOperation ;
import org.springframework.data.mongodb.core.aggregation.AggregationOperation ;
import org.springframework.data.mongodb.core.aggregation.AggregationPipeline ;
import org.springframework.data.mongodb.core.aggregation.VectorSearchOperation ;
import org.springframework.data.mongodb.core.aggregation.VectorSearchOperation ;
import org.springframework.data.mongodb.core.convert.MongoConverter ;
import org.springframework.data.mongodb.core.convert.MongoConverter ;
import org.springframework.data.mongodb.core.mapping.MongoPersistentEntity ;
import org.springframework.data.mongodb.core.mapping.MongoPersistentEntity ;
@ -47,6 +47,7 @@ import org.springframework.data.repository.query.ResultProcessor;
import org.springframework.data.repository.query.ValueExpressionDelegate ;
import org.springframework.data.repository.query.ValueExpressionDelegate ;
import org.springframework.data.repository.query.parser.Part ;
import org.springframework.data.repository.query.parser.Part ;
import org.springframework.data.repository.query.parser.PartTree ;
import org.springframework.data.repository.query.parser.PartTree ;
import org.springframework.util.NumberUtils ;
import org.springframework.util.StringUtils ;
import org.springframework.util.StringUtils ;
/ * *
/ * *
@ -58,32 +59,35 @@ class VectorSearchDelegate {
private final VectorSearchQueryFactory queryFactory ;
private final VectorSearchQueryFactory queryFactory ;
private final VectorSearchOperation . SearchType searchType ;
private final VectorSearchOperation . SearchType searchType ;
private final String indexName ;
private final @Nullable Integer numCandidates ;
private final @Nullable Integer numCandidates ;
private final @Nullable String numCandidatesExpression ;
private final @Nullable String numCandidatesExpression ;
private final Limit limit ;
private final Limit limit ;
private final @Nullable String limitExpression ;
private final @Nullable String limitExpression ;
private final MongoConverter converter ;
private final MongoConverter converter ;
public VectorSearchDelegate ( MongoQueryMethod method , MongoConverter converter , ValueExpressionDelegate delegate ) {
VectorSearchDelegate ( MongoQueryMethod method , MongoConverter converter , ValueExpressionDelegate delegate ) {
VectorSearch vectorSearch = method . findAnnotatedVectorSearch ( ) . orElseThrow ( ) ;
VectorSearch vectorSearch = method . findAnnotatedVectorSearch ( ) . orElseThrow ( ) ;
this . searchType = vectorSearch . searchType ( ) ;
this . searchType = vectorSearch . searchType ( ) ;
this . indexName = method . getAnnotatedHint ( ) ;
if ( StringUtils . hasText ( vectorSearch . numCandidates ( ) ) ) {
if ( StringUtils . hasText ( vectorSearch . numCandidates ( ) ) ) {
ValueExpression expression = delegate . getValueExpressionParser ( ) . parse ( vectorSearch . numCandidates ( ) ) ;
ValueExpression expression = delegate . getValueExpressionParser ( ) . parse ( vectorSearch . numCandidates ( ) ) ;
if ( expression . isLiteral ( ) ) {
if ( expression . isLiteral ( ) ) {
numCandidates = Integer . parseInt ( vectorSearch . numCandidates ( ) ) ;
this . numCandidates = Integer . parseInt ( vectorSearch . numCandidates ( ) ) ;
numCandidatesExpression = null ;
this . numCandidatesExpression = null ;
} else {
} else {
numCandidates = null ;
this . numCandidates = null ;
numCandidatesExpression = vectorSearch . numCandidates ( ) ;
this . numCandidatesExpression = vectorSearch . numCandidates ( ) ;
}
}
} else {
} else {
numCandidates = null ;
this . numCandidates = null ;
numCandidatesExpression = null ;
this . numCandidatesExpression = null ;
}
}
if ( StringUtils . hasText ( vectorSearch . limit ( ) ) ) {
if ( StringUtils . hasText ( vectorSearch . limit ( ) ) ) {
@ -91,26 +95,26 @@ class VectorSearchDelegate {
ValueExpression expression = delegate . getValueExpressionParser ( ) . parse ( vectorSearch . limit ( ) ) ;
ValueExpression expression = delegate . getValueExpressionParser ( ) . parse ( vectorSearch . limit ( ) ) ;
if ( expression . isLiteral ( ) ) {
if ( expression . isLiteral ( ) ) {
limit = Limit . of ( Integer . parseInt ( vectorSearch . limit ( ) ) ) ;
this . limit = Limit . of ( Integer . parseInt ( vectorSearch . limit ( ) ) ) ;
limitExpression = null ;
this . limitExpression = null ;
} else {
} else {
limit = Limit . unlimited ( ) ;
this . limit = Limit . unlimited ( ) ;
limitExpression = vectorSearch . limit ( ) ;
this . limitExpression = vectorSearch . limit ( ) ;
}
}
} else {
} else {
limit = Limit . unlimited ( ) ;
this . limit = Limit . unlimited ( ) ;
limitExpression = null ;
this . limitExpression = null ;
}
}
this . converter = converter ;
this . converter = converter ;
if ( StringUtils . hasText ( vectorSearch . filter ( ) ) ) {
if ( StringUtils . hasText ( vectorSearch . filter ( ) ) ) {
queryFactory = StringUtils . hasText ( vectorSearch . path ( ) )
this . queryFactory = StringUtils . hasText ( vectorSearch . path ( ) )
? new AnnotatedQueryFactory ( vectorSearch . filter ( ) , vectorSearch . path ( ) )
? new AnnotatedQueryFactory ( vectorSearch . filter ( ) , vectorSearch . path ( ) )
: new AnnotatedQueryFactory ( vectorSearch . filter ( ) , method . getEntityInformation ( ) . getCollectionEntity ( ) ) ;
: new AnnotatedQueryFactory ( vectorSearch . filter ( ) , method . getEntityInformation ( ) . getCollectionEntity ( ) ) ;
} else {
} else {
queryFactory = new PartTreeQueryFactory (
this . queryFactory = new PartTreeQueryFactory (
new PartTree ( method . getName ( ) , method . getResultProcessor ( ) . getReturnedType ( ) . getDomainType ( ) ) ,
new PartTree ( method . getName ( ) , method . getResultProcessor ( ) . getReturnedType ( ) . getDomainType ( ) ) ,
converter . getMappingContext ( ) ) ;
converter . getMappingContext ( ) ) ;
}
}
@ -119,43 +123,136 @@ class VectorSearchDelegate {
/ * *
/ * *
* Create Query Metadata for { @code $vectorSearch } .
* Create Query Metadata for { @code $vectorSearch } .
* /
* /
public QueryMetadata createQuery ( ValueExpressionEvaluator evaluator , ResultProcessor processor ,
QueryContainer createQuery ( ValueExpressionEvaluator evaluator , ResultProcessor processor ,
MongoParameterAccessor accessor , @Nullable Class < ? > typeToRead , ParameterBindingDocumentCodec codec ,
MongoParameterAccessor accessor , @Nullable Class < ? > typeToRead , ParameterBindingDocumentCodec codec ,
ParameterBindingContext context ) {
ParameterBindingContext context ) {
Integer numCandidates = null ;
// Field name handed to withSearchScore(...) and used as the score sort key further down.
String scoreField = "__score__" ;
Limit limit ;
// An explicitly requested typeToRead takes precedence over the return type reported by the ResultProcessor.
Class < ? > outputType = typeToRead ! = null ? typeToRead : processor . getReturnedType ( ) . getReturnedType ( ) ;
Class < ? > outputType = typeToRead ! = null ? typeToRead : processor . getReturnedType ( ) . getReturnedType ( ) ;
VectorSearchInput query = queryFactory . createQuery ( accessor , codec , context ) ;
VectorSearchInput vectorSearchInput = createSearchInput ( evaluator , accessor , codec , context ) ;
AggregationPipeline pipeline = createVectorSearchPipeline ( vectorSearchInput , scoreField , outputType , accessor ,
evaluator ) ;
// Limit precedence: the evaluated limit expression, then this.limit when it is limited, then the accessor's limit.
if ( this . limitExpression ! = null ) {
return new QueryContainer ( vectorSearchInput . path , scoreField , vectorSearchInput . query , pipeline , searchType ,
Object value = evaluator . evaluate ( this . limitExpression ) ;
outputType , getSimilarityFunction ( accessor ) , indexName ) ;
limit = value instanceof Limit l ? l : Limit . of ( ( ( Number ) value ) . intValue ( ) ) ;
}
} else if ( this . limit . isLimited ( ) ) {
limit = this . limit ;
} else {
limit = accessor . getLimit ( ) ;
}
if ( limit . isLimited ( ) ) {
@SuppressWarnings ( "NullAway" )
query . query ( ) . limit ( limit ) ;
AggregationPipeline createVectorSearchPipeline ( VectorSearchInput input , String scoreField , Class < ? > outputType ,
}
MongoParameterAccessor accessor , ValueExpressionEvaluator evaluator ) {
Vector vector = accessor . getVector ( ) ;
Score score = accessor . getScore ( ) ;
Range < Score > distance = accessor . getScoreRange ( ) ;
// Reuse the input query's limit for the $vectorSearch stage.
Limit limit = Limit . of ( input . query ( ) . getLimit ( ) ) ;
List < AggregationOperation > stages = new ArrayList < > ( ) ;
VectorSearchOperation $vectorSearch = Aggregation . vectorSearch ( indexName ) . path ( input . path ( ) ) . vector ( vector )
. limit ( limit ) ;
// numCandidates precedence: the evaluated expression, then the fixed value, then 20 x limit for ANN/DEFAULT
// searches on a limited query; in every other case it stays unset.
Integer candidates = null ;
if ( this . numCandidatesExpression ! = null ) {
if ( this . numCandidatesExpression ! = null ) {
numCandidates = ( ( Number ) evaluator . evaluate ( this . numCandidatesExpression ) ) . intValue ( ) ;
c andidates = ( ( Number ) evaluator . evaluate ( this . numCandidatesExpression ) ) . intValue ( ) ;
} else if ( this . numCandidates ! = null ) {
} else if ( this . numCandidates ! = null ) {
numCandidates = this . numCandidates ;
c andidates = this . numCandidates ;
} else if ( query . query ( ) . isLimited ( ) & & ( searchType = = VectorSearchOperation . SearchType . ANN
} else if ( input . query ( ) . isLimited ( ) & & ( searchType = = VectorSearchOperation . SearchType . ANN
| | searchType = = VectorSearchOperation . SearchType . DEFAULT ) ) {
| | searchType = = VectorSearchOperation . SearchType . DEFAULT ) ) {
/ *
/ *
MongoDB : We recommend that you specify a number at least 20 times higher than the number of documents to return ( limit ) to increase accuracy .
MongoDB : We recommend that you specify a number at least 20 times higher than the number of documents to return ( limit ) to increase accuracy .
* /
* /
numCandidates = query . query ( ) . getLimit ( ) * 20 ;
candidates = input . query ( ) . getLimit ( ) * 20 ;
}
}
return new QueryMetadata ( query . path , "__score__" , query . query , searchType , outputType , numCandidates ,
if ( candidates ! = null ) {
getSimilarityFunction ( accessor ) ) ;
$vectorSearch = $vectorSearch . numCandidates ( candidates ) ;
}
//
$vectorSearch = $vectorSearch . filter ( input . query . getQueryObject ( ) ) ;
$vectorSearch = $vectorSearch . searchType ( this . searchType ) ;
$vectorSearch = $vectorSearch . withSearchScore ( scoreField ) ;
// A Score argument keeps only results scoring strictly greater than it; without one, each bounded end of the
// score range is applied, honouring whether that bound is inclusive or exclusive.
if ( score ! = null ) {
$vectorSearch = $vectorSearch . withFilterBySore ( c - > {
c . gt ( score . getValue ( ) ) ;
} ) ;
} else if ( distance . getLowerBound ( ) . isBounded ( ) | | distance . getUpperBound ( ) . isBounded ( ) ) {
$vectorSearch = $vectorSearch . withFilterBySore ( c - > {
Range . Bound < Score > lower = distance . getLowerBound ( ) ;
if ( lower . isBounded ( ) ) {
double value = lower . getValue ( ) . get ( ) . getValue ( ) ;
if ( lower . isInclusive ( ) ) {
c . gte ( value ) ;
} else {
c . gt ( value ) ;
}
}
Range . Bound < Score > upper = distance . getUpperBound ( ) ;
if ( upper . isBounded ( ) ) {
double value = upper . getValue ( ) . get ( ) . getValue ( ) ;
if ( upper . isInclusive ( ) ) {
c . lte ( value ) ;
} else {
c . lt ( value ) ;
}
}
} ) ;
}
stages . add ( $vectorSearch ) ;
// A sorted query keeps its mapped sort with the score field appended as a descending key;
// an unsorted query is ordered by the score field descending only.
if ( input . query ( ) . isSorted ( ) ) {
stages . add ( ctx - > {
Document mappedSort = ctx . getMappedObject ( input . query ( ) . getSortObject ( ) , outputType ) ;
mappedSort . append ( scoreField , - 1 ) ;
return ctx . getMappedObject ( new Document ( "$sort" , mappedSort ) ) ;
} ) ;
} else {
stages . add ( Aggregation . sort ( Sort . Direction . DESC , scoreField ) ) ;
}
return new AggregationPipeline ( stages ) ;
}
private VectorSearchInput createSearchInput ( ValueExpressionEvaluator evaluator , MongoParameterAccessor accessor ,
ParameterBindingDocumentCodec codec , ParameterBindingContext context ) {
VectorSearchInput input = queryFactory . createQuery ( accessor , codec , context ) ;
Limit limit = getLimit ( evaluator , accessor ) ;
if ( ! input . query . isLimited ( ) | | ( input . query . isLimited ( ) & & ! limit . isUnlimited ( ) ) ) {
input . query ( ) . limit ( limit ) ;
}
return input ;
}
private Limit getLimit ( ValueExpressionEvaluator evaluator , MongoParameterAccessor accessor ) {
if ( this . limitExpression ! = null ) {
Object value = evaluator . evaluate ( this . limitExpression ) ;
if ( value ! = null ) {
if ( value instanceof Limit l ) {
return l ;
}
if ( value instanceof Number n ) {
return Limit . of ( n . intValue ( ) ) ;
}
if ( value instanceof String s ) {
return Limit . of ( NumberUtils . parseNumber ( s , Integer . class ) ) ;
}
throw new IllegalArgumentException ( "Invalid type for Limit. Found [%s], expected Limit or Number" ) ;
}
}
if ( this . limit . isLimited ( ) ) {
return this . limit ;
}
return accessor . getLimit ( ) ;
}
}
public String getQueryString ( ) {
public String getQueryString ( ) {
@ -192,82 +289,10 @@ class VectorSearchDelegate {
* @param query
* @param query
* @param searchType
* @param searchType
* @param outputType
* @param outputType
* @param numCandidates
* @param scoringFunction
* @param scoringFunction
* /
* /
public record QueryMetadata ( String path , String scoreField , Query query , VectorSearchOperation . SearchType searchType ,
record QueryContainer ( String path , String scoreField , Query query , AggregationPipeline pipeline ,
Class < ? > outputType , @Nullable Integer numCandidates , ScoringFunction scoringFunction ) {
VectorSearchOperation . SearchType searchType , Class < ? > outputType , ScoringFunction scoringFunction , String index ) {
/ * *
* Create the Aggregation Pipeline .
*
* @param queryMethod
* @param accessor
* @return
* /
public List < AggregationOperation > getAggregationPipeline ( MongoQueryMethod queryMethod ,
MongoParameterAccessor accessor ) {
Vector vector = accessor . getVector ( ) ;
Score score = accessor . getScore ( ) ;
Range < Score > distance = accessor . getScoreRange ( ) ;
Limit limit = Limit . unlimited ( ) ;
if ( query . isLimited ( ) ) {
limit = Limit . of ( query . getLimit ( ) ) ;
}
List < AggregationOperation > stages = new ArrayList < > ( ) ;
VectorSearchOperation $vectorSearch = Aggregation . vectorSearch ( queryMethod . getAnnotatedHint ( ) ) . path ( path ( ) )
. vector ( vector ) . limit ( limit ) ;
if ( numCandidates ( ) ! = null ) {
$vectorSearch = $vectorSearch . numCandidates ( numCandidates ( ) ) ;
}
$vectorSearch = $vectorSearch . filter ( query . getQueryObject ( ) ) ;
$vectorSearch = $vectorSearch . searchType ( searchType ( ) ) ;
$vectorSearch = $vectorSearch . withSearchScore ( scoreField ( ) ) ;
if ( score ! = null ) {
$vectorSearch = $vectorSearch . withFilterBySore ( c - > {
c . gt ( score . getValue ( ) ) ;
} ) ;
} else if ( distance . getLowerBound ( ) . isBounded ( ) | | distance . getUpperBound ( ) . isBounded ( ) ) {
$vectorSearch = $vectorSearch . withFilterBySore ( c - > {
Range . Bound < Score > lower = distance . getLowerBound ( ) ;
if ( lower . isBounded ( ) ) {
double value = lower . getValue ( ) . get ( ) . getValue ( ) ;
if ( lower . isInclusive ( ) ) {
c . gte ( value ) ;
} else {
c . gt ( value ) ;
}
}
Range . Bound < Score > upper = distance . getUpperBound ( ) ;
if ( upper . isBounded ( ) ) {
double value = upper . getValue ( ) . get ( ) . getValue ( ) ;
if ( upper . isInclusive ( ) ) {
c . lte ( value ) ;
} else {
c . lt ( value ) ;
}
}
} ) ;
}
stages . add ( $vectorSearch ) ;
if ( query . isSorted ( ) ) {
// TODO stages.add(Aggregation.sort(query.with()));
} else {
stages . add ( Aggregation . sort ( Sort . Direction . DESC , "__score__" ) ) ;
}
return stages ;
}
}
}
@ -368,11 +393,12 @@ class VectorSearchDelegate {
this . tree = tree ;
this . tree = tree ;
}
}
@SuppressWarnings ( "NullAway" )
public VectorSearchInput createQuery ( MongoParameterAccessor parameterAccessor , ParameterBindingDocumentCodec codec ,
public VectorSearchInput createQuery ( MongoParameterAccessor parameterAccessor , ParameterBindingDocumentCodec codec ,
ParameterBindingContext context ) {
ParameterBindingContext context ) {
MongoQueryCreator creator = new MongoQueryCreator ( tree , parameterAccessor , converter . getMappingContext ( ) ,
MongoQueryCreator creator = new MongoQueryCreator ( tree , parameterAccessor , converter . getMappingContext ( ) , false ,
false , true ) ;
true ) ;
Query query = creator . createQuery ( parameterAccessor . getSort ( ) ) ;
Query query = creator . createQuery ( parameterAccessor . getSort ( ) ) ;