@ -15,7 +15,6 @@
@@ -15,7 +15,6 @@
* /
package org.springframework.data.mongodb.repository.aot ;
import java.lang.reflect.Field ;
import java.util.LinkedHashMap ;
import java.util.List ;
import java.util.Map ;
@ -25,15 +24,14 @@ import org.springframework.core.annotation.MergedAnnotation;
@@ -25,15 +24,14 @@ import org.springframework.core.annotation.MergedAnnotation;
import org.springframework.data.domain.Limit ;
import org.springframework.data.domain.ScoringFunction ;
import org.springframework.data.domain.Sort ;
import org.springframework.data.domain.Vector ;
import org.springframework.data.mongodb.core.MongoOperations ;
import org.springframework.data.mongodb.core.aggregation.Aggregation ;
import org.springframework.data.mongodb.core.aggregation.AggregationOperation ;
import org.springframework.data.mongodb.core.aggregation.AggregationOperationContext ;
import org.springframework.data.mongodb.core.aggregation.AggregationPipeline ;
import org.springframework.data.mongodb.core.aggregation.VectorSearchOperation ;
import org.springframework.data.mongodb.core.aggregation.VectorSearchOperation.SearchType ;
import org.springframework.data.mongodb.repository.VectorSearch ;
import org.springframework.data.mongodb.repository.aot.Snippet.BuilderStyleBuilder ;
import org.springframework.data.mongodb.repository.query.MongoQueryExecution.VectorSearchExecution ;
import org.springframework.data.mongodb.repository.query.MongoQueryMethod ;
import org.springframework.data.repository.aot.generate.AotQueryMethodGenerationContext ;
@ -55,11 +53,14 @@ class VectorSearchBocks {
@@ -55,11 +53,14 @@ class VectorSearchBocks {
private String searchQueryVariableName ;
private StringQuery filter ;
private final Map < String , CodeBlock > arguments ;
private final String searchPath ;
VectorSearchQueryCodeBlockBuilder ( AotQueryMethodGenerationContext context , MongoQueryMethod queryMethod ) {
VectorSearchQueryCodeBlockBuilder ( AotQueryMethodGenerationContext context , MongoQueryMethod queryMethod ,
String searchPath ) {
this . context = context ;
this . queryMethod = queryMethod ;
this . searchPath = searchPath ;
this . arguments = new LinkedHashMap < > ( ) ;
context . getBindableParameterNames ( ) . forEach ( it - > arguments . put ( it , CodeBlock . of ( it ) ) ) ;
}
@ -77,135 +78,173 @@ class VectorSearchBocks {
@@ -77,135 +78,173 @@ class VectorSearchBocks {
String vectorParameterName = context . getVectorParameterName ( ) ;
MergedAnnotation < VectorSearch > annotation = context . getAnnotation ( VectorSearch . class ) ;
String searchPath = annotation . getString ( "path" ) ;
String indexName = annotation . getString ( "indexName" ) ;
String numCandidates = annotation . getString ( "numCandidates" ) ;
SearchType searchType = annotation . getEnum ( "searchType" , SearchType . class ) ;
String limit = annotation . getString ( "limit" ) ;
if ( ! StringUtils . hasText ( searchPath ) ) { // FIXME: somehow duplicate logic of AnnotatedQueryFactory
ExpressionSnippet limit = getLimitExpression ( ) ;
Field [ ] declaredFields = context . getRepositoryInformation ( ) . getDomainType ( ) . getDeclaredFields ( ) ;
for ( Field field : declaredFields ) {
if ( Vector . class . isAssignableFrom ( field . getType ( ) ) ) {
searchPath = field . getName ( ) ;
break ;
}
}
if ( limit . requiresEvaluation ( ) & & ! StringUtils . hasText ( annotation . getString ( "numCandidates" ) )
& & ( searchType = = VectorSearchOperation . SearchType . ANN
| | searchType = = VectorSearchOperation . SearchType . DEFAULT ) ) {
VariableSnippet variableBlock = limit . as ( VariableSnippet : : create )
. variableName ( context . localVariable ( "limitToUse" ) ) ;
variableBlock . renderDeclaration ( builder ) ;
limit = variableBlock ;
}
String vectorSearchVar = context . localVariable ( "$vectorSearch" ) ;
builder . add ( "$T $L = $T.vectorSearch($S).path($S).vector($L)" , VectorSearchOperation . class , vectorSearchVar ,
Aggregation . class , indexName , searchPath , vectorParameterName ) ;
BuilderStyleBuilder vectorSearchOperationBuilder = Snippet . declare ( builder )
. variableBuilder ( VectorSearchOperation . class , context . localVariable ( "$vectorSearch" ) )
. as ( "$T.vectorSearch($S).path($S).vector($L).limit($L)" , Aggregation . class , indexName , searchPath ,
vectorParameterName , limit . code ( ) ) ;
if ( StringUtils . hasText ( context . getLimitParameterName ( ) ) ) {
builder . add ( ".limit($L);\n" , context . getLimitParameterName ( ) ) ;
} else if ( filter . isLimited ( ) ) {
builder . add ( ".limit($L);\n" , filter . getLimit ( ) ) ;
} else if ( StringUtils . hasText ( limit ) ) {
if ( MongoCodeBlocks . containsPlaceholder ( limit ) | | MongoCodeBlocks . containsExpression ( limit ) ) {
builder . add ( ".limit(" ) ;
builder . add ( MongoCodeBlocks . evaluateNumberPotentially ( limit , Integer . class , arguments ) ) ;
builder . add ( ");\n" ) ;
} else {
builder . add ( ".limit($L);\n" , limit ) ;
if ( ! searchType . equals ( SearchType . DEFAULT ) ) {
vectorSearchOperationBuilder . call ( "searchType" ) . with ( "$T.$L" , SearchType . class , searchType . name ( ) ) ;
}
} else {
builder . add ( ".limit($T.unlimited());\n" , Limit . class ) ;
ExpressionSnippet numCandidates = getNumCandidatesExpression ( searchType , limit ) ;
if ( ! numCandidates . isEmpty ( ) ) {
vectorSearchOperationBuilder . call ( "numCandidates" ) . with ( numCandidates ) ;
}
if ( ! searchType . equals ( SearchType . DEFAULT ) ) {
builder . addStatement ( "$1L = $1L.searchType($2T.$3L)" , vectorSearchVar , SearchType . class , searchType . name ( ) ) ;
vectorSearchOperationBuilder . call ( "withSearchScore" ) . with ( "\"__score__\"" ) ;
if ( StringUtils . hasText ( context . getScoreParameterName ( ) ) ) {
vectorSearchOperationBuilder . call ( "withFilterBySore" ) . with ( "$1L -> { $1L.gt($2L.getValue()); }" ,
context . localVariable ( "criteria" ) , context . getScoreParameterName ( ) ) ;
} else if ( StringUtils . hasText ( context . getScoreRangeParameterName ( ) ) ) {
vectorSearchOperationBuilder . call ( "withFilterBySore" )
. with ( "scoreBetween($1L.getLowerBound(), $1L.getUpperBound())" , context . getScoreRangeParameterName ( ) ) ;
}
if ( StringUtils . hasText ( numCandidates ) ) {
builder . add ( "$1L = $1L.numCandidates(" , vectorSearchVar ) ;
builder . add ( MongoCodeBlocks . evaluateNumberPotentially ( numCandidates , Integer . class , arguments ) ) ;
builder . add ( ");\n" ) ;
} else if ( searchType = = VectorSearchOperation . SearchType . ANN
| | searchType = = VectorSearchOperation . SearchType . DEFAULT ) {
VariableSnippet vectorSearchOperation = vectorSearchOperationBuilder . variable ( ) ;
getFilter ( vectorSearchOperation . getVariableName ( ) ) . appendTo ( builder ) ;
builder . add (
"// MongoDB: We recommend that you specify a number at least 20 times higher than the number of documents to return\n" ) ;
if ( StringUtils . hasText ( context . getLimitParameterName ( ) ) ) {
builder . addStatement ( "$1L = $1L.numCandidates($2L.max() * 20)" , vectorSearchVar ,
context . getLimitParameterName ( ) ) ;
} else if ( StringUtils . hasText ( limit ) ) {
if ( MongoCodeBlocks . containsPlaceholder ( limit ) | | MongoCodeBlocks . containsExpression ( limit ) ) {
VariableSnippet sortStage = getSort ( ) . as ( VariableSnippet : : create ) . variableName ( context . localVariable ( "$sort" ) ) ;
sortStage . renderDeclaration ( builder ) ;
builder . add ( "$1L = $1L.numCandidates((" , vectorSearchVar ) ;
builder . add ( MongoCodeBlocks . evaluateNumberPotentially ( limit , Integer . class , arguments ) ) ;
builder . add ( ") * 20);\n" ) ;
builder . add ( "\n" ) ;
VariableSnippet aggregationPipeline = Snippet . declare ( builder )
. variable ( AggregationPipeline . class , searchQueryVariableName ) . as ( "new $T($T.of($L, $L))" ,
AggregationPipeline . class , List . class , vectorSearchOperation . getVariableName ( ) , sortStage . code ( ) ) ;
String scoringFunctionVar = context . localVariable ( "scoringFunction" ) ;
builder . add ( "$1T $2L = " , ScoringFunction . class , scoringFunctionVar ) ;
if ( StringUtils . hasText ( context . getScoreParameterName ( ) ) ) {
builder . add ( "$L.getFunction();\n" , context . getScoreParameterName ( ) ) ;
} else if ( StringUtils . hasText ( context . getScoreRangeParameterName ( ) ) ) {
builder . add ( "scoringFunction($L);\n" , context . getScoreRangeParameterName ( ) ) ;
} else {
builder . addStatement ( "$1L = $1L.numCandidates($2L * 20)" , vectorSearchVar , limit ) ;
builder . add ( "$1T.unspecified();\n" , ScoringFunction . class ) ;
}
} else {
builder . addStatement ( "$1L = $1L.numCandidates($2L)" , vectorSearchVar , filter . getLimit ( ) * 20 ) ;
builder . addStatement (
"return ($5T) new $1T($2L, $3T.class, $2L.getCollectionName($3T.class), $4T.of($5T.class), $6L, $7L).execute(null)" ,
VectorSearchExecution . class , context . fieldNameOf ( MongoOperations . class ) ,
context . getRepositoryInformation ( ) . getDomainType ( ) , TypeInformation . class ,
queryMethod . getReturnType ( ) . getType ( ) , aggregationPipeline . getVariableName ( ) , scoringFunctionVar ) ;
return builder . build ( ) ;
}
private ExpressionSnippet getSort ( ) {
if ( ! filter . isSorted ( ) ) {
return new ExpressionSnippet (
CodeBlock . of ( "$T.sort($T.Direction.DESC, $S)" , Aggregation . class , Sort . class , "__score__" ) ) ;
}
builder . addStatement ( "$1L = $1L.withSearchScore(\"__score__\")" , vectorSearchVar ) ;
if ( StringUtils . hasText ( context . getScoreParameterName ( ) ) ) {
Builder builder = CodeBlock . builder ( ) ;
String scoreCriteriaVar = context . localVariable ( "criteria" ) ;
builder . addStatement ( "$1L = $1L.withFilterBySore($2L -> { $2L.gt($3L.getValue()); })" , vectorSearchVar ,
scoreCriteriaVar , context . getScoreParameterName ( ) ) ;
} else if ( StringUtils . hasText ( context . getScoreRangeParameterName ( ) ) ) {
builder . addStatement ( "$1L = $1L.withFilterBySore(scoreBetween($2L.getLowerBound(), $2L.getUpperBound()))" ,
vectorSearchVar , context . getScoreRangeParameterName ( ) ) ;
builder . add ( "($T) (_ctx) -> {\n" , AggregationOperation . class ) ;
builder . indent ( ) ;
builder . add ( "$1T _mappedSort = _ctx.getMappedObject($1T.parse($2S), $3T.class);\n" , Document . class ,
filter . getSortString ( ) , context . getActualReturnType ( ) . getType ( ) ) ;
builder . add ( "return new $T($S, _mappedSort.append(\"__score__\", -1));\n" , Document . class , "$sort" ) ;
builder . unindent ( ) ;
builder . add ( "};" ) ;
return new ExpressionSnippet ( builder . build ( ) ) ;
}
if ( StringUtils . hasText ( filter . getQueryString ( ) ) ) {
private Snippet getFilter ( String vectorSearchVar ) {
if ( ! StringUtils . hasText ( filter . getQueryString ( ) ) ) {
return ExpressionSnippet . empty ( ) ;
}
Builder builder = CodeBlock . builder ( ) ;
String filterVar = context . localVariable ( "filter" ) ;
builder . add ( MongoCodeBlocks . queryBlockBuilder ( context , queryMethod ) . usingQueryVariableName ( "filter" )
. filter ( new QueryInteraction ( this . filter , false , false , false ) ) . buildJustTheQuery ( ) ) ;
builder . addStatement ( "$1L = $1L.filter($2L.getQueryObject())" , vectorSearchVar , filterVar ) ;
builder . add ( "\n" ) ;
}
return new ExpressionSnippet ( builder . build ( ) ) ;
}
String sortStageVar = context . localVariable ( "$sort" ) ;
if ( filter . isSorted ( ) ) {
public VectorSearchQueryCodeBlockBuilder withFilter ( StringQuery filter ) {
this . filter = filter ;
return this ;
}
builder . add ( "$T $L = (_ctx) -> {\n" , AggregationOperation . class , sortStageVar ) ;
builder . indent ( ) ;
private ExpressionSnippet getNumCandidatesExpression ( SearchType searchType , ExpressionSnippet limit ) {
builder . addStatement ( "$1T _mappedSort = _ctx.getMappedObject($1T.parse($2S), $3T.class)" , Document . class , filter . getSortString ( ) , context . getActualReturnType ( ) . getType ( ) ) ;
builder . addStatement ( "return new $T($S, _mappedSort.append(\"__score__\", -1))" , Document . class , "$sort" ) ;
builder . unindent ( ) ;
builder . add ( "};" ) ;
MergedAnnotation < VectorSearch > annotation = context . getAnnotation ( VectorSearch . class ) ;
String numCandidates = annotation . getString ( "numCandidates" ) ;
if ( StringUtils . hasText ( numCandidates ) ) {
if ( MongoCodeBlocks . containsPlaceholder ( numCandidates ) | | MongoCodeBlocks . containsExpression ( numCandidates ) ) {
return new ExpressionSnippet (
MongoCodeBlocks . evaluateNumberPotentially ( numCandidates , Integer . class , arguments ) , true ) ;
} else {
builder . addStatement ( "var $L = $T.sort($T.Direction.DESC, $S)" , sortStageVar , Aggregation . class , Sort . class , "__score__" ) ;
return new ExpressionSnippet ( CodeBlock . of ( "$L" , numCandidates ) ) ;
}
}
builder . add ( "\n" ) ;
builder . addStatement ( "$1T $2L = new $1T($3T.of($4L, $5L))" , AggregationPipeline . class , searchQueryVariableName ,
List . class , vectorSearchVar , sortStageVar ) ;
if ( searchType = = VectorSearchOperation . SearchType . ANN
| | searchType = = VectorSearchOperation . SearchType . DEFAULT ) {
String scoringFunctionVar = context . localVariable ( "scoringFunction" ) ;
builder . add ( "$1T $2L = " , ScoringFunction . class , scoringFunctionVar ) ;
if ( StringUtils . hasText ( context . getScore ParameterName ( ) ) ) {
builder . add ( "$L.getFunction();\n " , context . getScore ParameterName ( ) ) ;
} else if ( StringUtils . hasText ( context . getScoreRangeParameterName ( ) ) ) {
builder . add ( "scoringFunction( $L);\n " , context . getScoreRangeParameterName ( ) ) ;
Builder builder = CodeBlock . builder ( ) ;
if ( StringUtils . hasText ( context . getLimit ParameterName ( ) ) ) {
builder . add ( "$L.max() * 20 " , context . getLimit ParameterName ( ) ) ;
} else if ( filter . isLimited ( ) ) {
builder . add ( "$L" , filter . getLimit ( ) * 20 ) ;
} else {
builder . add ( "$1T.unspecified();\n" , ScoringFunction . class ) ;
builder . add ( "$L * 20" , limit . code ( ) ) ;
}
builder . addStatement (
"return ($5T) new $1T($2L, $3T.class, $2L.getCollectionName($3T.class), $4T.of($5T.class), $6L, $7L).execute(null)" ,
VectorSearchExecution . class , context . fieldNameOf ( MongoOperations . class ) ,
context . getRepositoryInformation ( ) . getDomainType ( ) , TypeInformation . class ,
queryMethod . getReturnType ( ) . getType ( ) , searchQueryVariableName , scoringFunctionVar ) ;
return builder . build ( ) ;
return new ExpressionSnippet ( builder . build ( ) ) ;
}
public VectorSearchQueryCodeBlockBuilder withFilter ( StringQuery filter ) {
this . filter = filter ;
return this ;
return ExpressionSnippet . empty ( ) ;
}
private ExpressionSnippet getLimitExpression ( ) {
if ( StringUtils . hasText ( context . getLimitParameterName ( ) ) ) {
return new ExpressionSnippet ( CodeBlock . of ( "$L" , context . getLimitParameterName ( ) ) ) ;
}
if ( filter . isLimited ( ) ) {
return new ExpressionSnippet ( CodeBlock . of ( "$L" , filter . getLimit ( ) ) ) ;
}
MergedAnnotation < VectorSearch > annotation = context . getAnnotation ( VectorSearch . class ) ;
String limit = annotation . getString ( "limit" ) ;
if ( StringUtils . hasText ( limit ) ) {
if ( MongoCodeBlocks . containsPlaceholder ( limit ) | | MongoCodeBlocks . containsExpression ( limit ) ) {
return new ExpressionSnippet ( MongoCodeBlocks . evaluateNumberPotentially ( limit , Integer . class , arguments ) ,
true ) ;
} else {
return new ExpressionSnippet ( CodeBlock . of ( "$L" , limit ) ) ;
}
}
return new ExpressionSnippet ( CodeBlock . of ( "$T.unlimited()" , Limit . class ) ) ;
}
}
}