diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java index db3dcc126..ab03b4142 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java @@ -486,7 +486,7 @@ public class MongoTemplate implements MongoOperations, ApplicationContextAware, return doStream(query, entityType, collectionName, returnType, QueryResultConverter.entity()); } - @SuppressWarnings("ConstantConditions") + @SuppressWarnings({"ConstantConditions", "NullAway"}) Stream doStream(Query query, Class entityType, String collectionName, Class returnType, QueryResultConverter resultConverter) { @@ -1086,34 +1086,29 @@ public class MongoTemplate implements MongoOperations, ApplicationContextAware, return new GeoResults<>(result, avgDistance); } - @Nullable - @Override - public T findAndModify(Query query, UpdateDefinition update, Class entityClass) { + public @Nullable T findAndModify(Query query, UpdateDefinition update, Class entityClass) { return findAndModify(query, update, new FindAndModifyOptions(), entityClass, getCollectionName(entityClass)); } - @Nullable @Override - public T findAndModify(Query query, UpdateDefinition update, Class entityClass, + public @Nullable T findAndModify(Query query, UpdateDefinition update, Class entityClass, String collectionName) { return findAndModify(query, update, new FindAndModifyOptions(), entityClass, collectionName); } - @Nullable @Override - public T findAndModify(Query query, UpdateDefinition update, FindAndModifyOptions options, + public @Nullable T findAndModify(Query query, UpdateDefinition update, FindAndModifyOptions options, Class entityClass) { return findAndModify(query, update, options, entityClass, getCollectionName(entityClass)); } - @Nullable @Override - public T findAndModify(Query query, UpdateDefinition update, FindAndModifyOptions options, + public @Nullable T findAndModify(Query query, UpdateDefinition update, FindAndModifyOptions options, Class entityClass, String collectionName) { return findAndModify(query, update, options, entityClass, collectionName, QueryResultConverter.entity()); } - T findAndModify(Query query, UpdateDefinition update, FindAndModifyOptions options, + @Nullable T findAndModify(Query query, UpdateDefinition update, FindAndModifyOptions options, Class entityClass, String collectionName, QueryResultConverter resultConverter) { Assert.notNull(query, "Query must not be null"); @@ -1185,15 +1180,13 @@ public class MongoTemplate implements MongoOperations, ApplicationContextAware, // Find methods that take a Query to express the query and that return a single object that is also removed from the // collection in the database. - @Nullable @Override - public T findAndRemove(Query query, Class entityClass) { + public @Nullable T findAndRemove(Query query, Class entityClass) { return findAndRemove(query, entityClass, getCollectionName(entityClass)); } - @Nullable @Override - public T findAndRemove(Query query, Class entityClass, String collectionName) { + public @Nullable T findAndRemove(Query query, Class entityClass, String collectionName) { Assert.notNull(query, "Query must not be null"); Assert.notNull(entityClass, "EntityClass must not be null"); @@ -2161,11 +2154,11 @@ public class MongoTemplate implements MongoOperations, ApplicationContextAware, * @param entityClass * @return */ - @SuppressWarnings("NullAway") protected List doFindAndDelete(String collectionName, Query query, Class entityClass) { return doFindAndDelete(collectionName, query, entityClass, QueryResultConverter.entity()); } + @SuppressWarnings("NullAway") List doFindAndDelete(String collectionName, Query query, Class entityClass, QueryResultConverter resultConverter) { @@ -2229,7 +2222,7 @@ public class MongoTemplate implements MongoOperations, ApplicationContextAware, return doAggregate(aggregation, collectionName, outputType, QueryResultConverter.entity(), context); } - @SuppressWarnings("ConstantConditions") + @SuppressWarnings({"ConstantConditions", "NullAway"}) AggregationResults doAggregate(Aggregation aggregation, String collectionName, Class outputType, QueryResultConverter resultConverter, AggregationOperationContext context) { diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/ReactiveMongoTemplate.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/ReactiveMongoTemplate.java index 935c55fc9..0ad473b8b 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/ReactiveMongoTemplate.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/ReactiveMongoTemplate.java @@ -2293,6 +2293,7 @@ public class ReactiveMongoTemplate implements ReactiveMongoOperations, Applicati .flatMapSequential(deleteResult -> Flux.fromIterable(list))); } + @SuppressWarnings({"rawtypes", "unchecked", "NullAway"}) Flux doFindAndDelete(String collectionName, Query query, Class entityClass, QueryResultConverter resultConverter) { diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationPipeline.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationPipeline.java index 40966bcf3..f06803997 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationPipeline.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationPipeline.java @@ -22,8 +22,10 @@ import java.util.List; import java.util.function.Predicate; import org.bson.Document; +import org.jspecify.annotations.Nullable; import org.springframework.lang.Contract; import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; /** * The {@link AggregationPipeline} holds the collection of {@link AggregationOperation aggregation stages}. @@ -82,6 +84,14 @@ public class AggregationPipeline { return Collections.unmodifiableList(pipeline); } + public @Nullable AggregationOperation firstOperation() { + return CollectionUtils.firstElement(pipeline); + } + + public @Nullable AggregationOperation lastOperation() { + return CollectionUtils.lastElement(pipeline); + } + List toDocuments(AggregationOperationContext context) { verify(); @@ -97,8 +107,8 @@ public class AggregationPipeline { return false; } - AggregationOperation operation = pipeline.get(pipeline.size() - 1); - return isOut(operation) || isMerge(operation); + AggregationOperation operation = lastOperation(); + return operation != null && (isOut(operation) || isMerge(operation)); } void verify() { diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ArrayOperators.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ArrayOperators.java index 02b805d5e..85952d8f3 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ArrayOperators.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ArrayOperators.java @@ -356,6 +356,7 @@ public class ArrayOperators { * @return new instance of {@link SortArray}. * @since 4.5 */ + @SuppressWarnings("NullAway") public SortArray sort(Direction direction) { if (usesFieldRef()) { diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ConvertingParameterAccessor.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ConvertingParameterAccessor.java index e51d4435a..f203b67e6 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ConvertingParameterAccessor.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ConvertingParameterAccessor.java @@ -77,7 +77,7 @@ public class ConvertingParameterAccessor implements MongoParameterAccessor { } @Override - public Vector getVector() { + public @Nullable Vector getVector() { return delegate.getVector(); } @@ -104,12 +104,12 @@ public class ConvertingParameterAccessor implements MongoParameterAccessor { } @Override - public @org.jspecify.annotations.Nullable Score getScore() { + public @Nullable Score getScore() { return delegate.getScore(); } @Override - public @org.jspecify.annotations.Nullable Range getScoreRange() { + public @Nullable Range getScoreRange() { return delegate.getScoreRange(); } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoParametersParameterAccessor.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoParametersParameterAccessor.java index 41cf084d4..0f5622349 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoParametersParameterAccessor.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoParametersParameterAccessor.java @@ -61,14 +61,13 @@ public class MongoParametersParameterAccessor extends ParametersParameterAccesso public Range getScoreRange() { MongoParameters mongoParameters = method.getParameters(); - int rangeIndex = mongoParameters.getScoreRangeIndex(); - if (rangeIndex != -1) { - return getValue(rangeIndex); + if (mongoParameters.hasScoreRangeParameter()) { + return getValue(mongoParameters.getScoreRangeIndex()); } - int scoreIndex = mongoParameters.getScoreIndex(); - Bound maxDistance = scoreIndex == -1 ? Bound.unbounded() : Bound.inclusive((Score) getScore()); + Score score = getScore(); + Bound maxDistance = score != null ? Bound.inclusive(score) : Bound.unbounded(); return Range.of(Bound.unbounded(), maxDistance); } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoQueryCreator.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoQueryCreator.java index 1f742ec32..ba7394ec1 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoQueryCreator.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoQueryCreator.java @@ -15,7 +15,8 @@ */ package org.springframework.data.mongodb.repository.query; -import static org.springframework.data.mongodb.core.query.Criteria.*; +import static org.springframework.data.mongodb.core.query.Criteria.Placeholder; +import static org.springframework.data.mongodb.core.query.Criteria.where; import java.util.Arrays; import java.util.Collection; @@ -27,7 +28,6 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.bson.BsonRegularExpression; import org.jspecify.annotations.Nullable; - import org.springframework.data.domain.Range; import org.springframework.data.domain.Range.Bound; import org.springframework.data.domain.Sort; @@ -118,8 +118,9 @@ public class MongoQueryCreator extends AbstractQueryCreator { return new Criteria(); } - if (isSearchQuery && (part.getType().equals(Type.NEAR) || part.getType().equals(Type.WITHIN))) { - return null; + if (isPartOfSearchQuery(part)) { + skip(part, iterator); + return new Criteria(); } PersistentPropertyPath path = context.getPersistentPropertyPath(part.getProperty()); @@ -135,7 +136,8 @@ public class MongoQueryCreator extends AbstractQueryCreator { return create(part, iterator); } - if (isSearchQuery && (part.getType().equals(Type.NEAR) || part.getType().equals(Type.WITHIN))) { + if (isPartOfSearchQuery(part)) { + skip(part, iterator); return base; } @@ -176,15 +178,6 @@ public class MongoQueryCreator extends AbstractQueryCreator { @SuppressWarnings("NullAway") private Criteria from(Part part, MongoPersistentProperty property, Criteria criteria, Iterator parameters) { - if (isSearchQuery && (part.getType().equals(Type.NEAR) || part.getType().equals(Type.WITHIN))) { - - int numberOfArguments = part.getType().getNumberOfArguments(); - for (int i = 0; i < numberOfArguments; i++) { - parameters.next(); - } - return null; - } - Type type = part.getType(); switch (type) { @@ -206,13 +199,13 @@ public class MongoQueryCreator extends AbstractQueryCreator { return criteria.is(null); case NOT_IN: Object ninValue = parameters.next(); - if(ninValue instanceof Placeholder) { + if (ninValue instanceof Placeholder) { return criteria.raw("$nin", ninValue); } return criteria.nin(valueAsList(ninValue, part)); case IN: Object inValue = parameters.next(); - if(inValue instanceof Placeholder) { + if (inValue instanceof Placeholder) { return criteria.raw("$in", inValue); } return criteria.in(valueAsList(inValue, part)); @@ -231,7 +224,7 @@ public class MongoQueryCreator extends AbstractQueryCreator { return param instanceof Pattern pattern ? criteria.regex(pattern) : criteria.regex(param.toString()); case EXISTS: Object next = parameters.next(); - if(next instanceof Placeholder placeholder) { + if (next instanceof Placeholder placeholder) { return criteria.raw("$exists", placeholder); } else { return criteria.exists((Boolean) next); @@ -355,7 +348,7 @@ public class MongoQueryCreator extends AbstractQueryCreator { if (property.isCollectionLike()) { Object next = parameters.next(); - if(next instanceof Placeholder) { + if (next instanceof Placeholder) { return criteria.raw("$in", next); } return criteria.in(valueAsList(next, part)); @@ -433,8 +426,7 @@ public class MongoQueryCreator extends AbstractQueryCreator { streamable = streamable.map(it -> { if (it instanceof String sv) { - return new BsonRegularExpression(MongoRegexCreator.INSTANCE.toRegularExpression(sv, matchMode), - regexOptions); + return new BsonRegularExpression(MongoRegexCreator.INSTANCE.toRegularExpression(sv, matchMode), regexOptions); } return it; }); @@ -468,10 +460,23 @@ public class MongoQueryCreator extends AbstractQueryCreator { return false; } + private boolean isPartOfSearchQuery(Part part) { + return isSearchQuery && (part.getType().equals(Type.NEAR) || part.getType().equals(Type.WITHIN)); + } + + private static void skip(Part part, Iterator parameters) { + + int total = part.getNumberOfArguments(); + int i = 0; + while (parameters.hasNext() && i < total) { + parameters.next(); + i++; + } + } + /** * Compute a {@link Type#BETWEEN} typed {@link Part} using {@link Criteria#gt(Object) $gt}, - * {@link Criteria#gte(Object) $gte}, {@link Criteria#lt(Object) $lt} and {@link Criteria#lte(Object) $lte}. - *
+ * {@link Criteria#gte(Object) $gte}, {@link Criteria#lt(Object) $lt} and {@link Criteria#lte(Object) $lte}.
* In case the first {@literal value} is actually a {@link Range} the lower and upper bounds of the {@link Range} are * used according to their {@link Bound#isInclusive() inclusion} definition. Otherwise the {@literal value} is used * for {@literal $gt} and {@link Iterator#next() parameters.next()} as {@literal $lt}. diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoQueryExecution.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoQueryExecution.java index d9a91434c..c0531e0e1 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoQueryExecution.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoQueryExecution.java @@ -15,18 +15,16 @@ */ package org.springframework.data.mongodb.repository.query; -import java.util.ArrayList; import java.util.Collection; import java.util.Iterator; import java.util.List; import java.util.function.Supplier; -import org.bson.Document; import org.jspecify.annotations.Nullable; - import org.springframework.data.domain.Page; import org.springframework.data.domain.Pageable; import org.springframework.data.domain.Range; +import org.springframework.data.domain.ScoringFunction; import org.springframework.data.domain.SearchResult; import org.springframework.data.domain.SearchResults; import org.springframework.data.domain.Similarity; @@ -37,6 +35,7 @@ import org.springframework.data.geo.GeoPage; import org.springframework.data.geo.GeoResult; import org.springframework.data.geo.GeoResults; import org.springframework.data.geo.Point; +import org.springframework.data.mongodb.core.ExecutableAggregationOperation.TerminatingAggregation; import org.springframework.data.mongodb.core.ExecutableFindOperation; import org.springframework.data.mongodb.core.ExecutableFindOperation.FindWithQuery; import org.springframework.data.mongodb.core.ExecutableFindOperation.TerminatingFind; @@ -45,12 +44,13 @@ import org.springframework.data.mongodb.core.ExecutableRemoveOperation.Executabl import org.springframework.data.mongodb.core.ExecutableRemoveOperation.TerminatingRemove; import org.springframework.data.mongodb.core.ExecutableUpdateOperation.ExecutableUpdate; import org.springframework.data.mongodb.core.MongoOperations; -import org.springframework.data.mongodb.core.aggregation.AggregationOperation; +import org.springframework.data.mongodb.core.aggregation.AggregationPipeline; import org.springframework.data.mongodb.core.aggregation.AggregationResults; import org.springframework.data.mongodb.core.aggregation.TypedAggregation; import org.springframework.data.mongodb.core.query.NearQuery; import org.springframework.data.mongodb.core.query.Query; import org.springframework.data.mongodb.core.query.UpdateDefinition; +import org.springframework.data.mongodb.repository.query.VectorSearchDelegate.QueryContainer; import org.springframework.data.mongodb.repository.util.SliceUtils; import org.springframework.data.repository.query.QueryMethod; import org.springframework.data.support.PageableExecutionUtils; @@ -186,7 +186,7 @@ public interface MongoQueryExecution { return isListOfGeoResult(method.getReturnType()) ? results.getContent() : results; } - @SuppressWarnings({"unchecked","NullAway"}) + @SuppressWarnings({ "unchecked", "NullAway" }) GeoResults doExecuteQuery(Query query) { Point nearLocation = accessor.getGeoNearLocation(); @@ -225,52 +225,53 @@ public interface MongoQueryExecution { * {@link MongoQueryExecution} to execute vector search. * * @author Mark Paluch + * @author Chistoph Strobl * @since 5.0 */ class VectorSearchExecution implements MongoQueryExecution { private final MongoOperations operations; - private final MongoQueryMethod method; + private final TypeInformation returnType; private final String collectionName; - private final VectorSearchDelegate.QueryMetadata queryMetadata; - private final List pipeline; + private final Class targetType; + private final ScoringFunction scoringFunction; + private final AggregationPipeline pipeline; + + VectorSearchExecution(MongoOperations operations, MongoQueryMethod method, String collectionName, + QueryContainer queryContainer) { + this(operations, queryContainer.outputType(), collectionName, method.getReturnType(), queryContainer.pipeline(), + queryContainer.scoringFunction()); + } - public VectorSearchExecution(MongoOperations operations, MongoQueryMethod method, String collectionName, - VectorSearchDelegate.QueryMetadata queryMetadata, MongoParameterAccessor accessor) { + public VectorSearchExecution(MongoOperations operations, Class targetType, String collectionName, + TypeInformation returnType, AggregationPipeline pipeline, ScoringFunction scoringFunction) { this.operations = operations; + this.returnType = returnType; this.collectionName = collectionName; - this.queryMetadata = queryMetadata; - this.method = method; - this.pipeline = queryMetadata.getAggregationPipeline(method, accessor); + this.targetType = targetType; + this.scoringFunction = scoringFunction; + this.pipeline = pipeline; } @Override + @SuppressWarnings({ "unchecked", "rawtypes" }) public Object execute(Query query) { - AggregationResults aggregated = operations.aggregate( - TypedAggregation.newAggregation(queryMetadata.outputType(), pipeline), collectionName, - queryMetadata.outputType()); - - List mappedResults = aggregated.getMappedResults(); + TerminatingAggregation executableAggregation = operations.aggregateAndReturn(targetType) + .inCollection(collectionName).by(TypedAggregation.newAggregation(targetType, pipeline.getOperations())); - if (isSearchResult(method.getReturnType())) { - - List rawResults = aggregated.getRawResults().getList("results", org.bson.Document.class); - List> result = new ArrayList<>(mappedResults.size()); - - for (int i = 0; i < mappedResults.size(); i++) { - Document document = rawResults.get(i); - SearchResult searchResult = new SearchResult<>(mappedResults.get(i), - Similarity.raw(document.getDouble("__score__"), queryMetadata.scoringFunction())); - - result.add(searchResult); - } - - return isListOfSearchResult(method.getReturnType()) ? result : new SearchResults<>(result); + if (!isSearchResult(returnType)) { + return executableAggregation.all().getMappedResults(); } - return mappedResults; + AggregationResults> result = executableAggregation + .map((raw, container) -> new SearchResult<>(container.get(), + Similarity.raw(raw.getDouble("__score__"), scoringFunction))) + .all(); + + return isListOfSearchResult(returnType) ? result.getMappedResults() + : new SearchResults(result.getMappedResults()); } private static boolean isListOfSearchResult(TypeInformation returnType) { diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ReactiveMongoQueryExecution.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ReactiveMongoQueryExecution.java index 389f4e871..29e2127e1 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ReactiveMongoQueryExecution.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ReactiveMongoQueryExecution.java @@ -18,12 +18,9 @@ package org.springframework.data.mongodb.repository.query; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; -import java.util.List; - import org.bson.Document; import org.jspecify.annotations.Nullable; import org.reactivestreams.Publisher; - import org.springframework.core.convert.converter.Converter; import org.springframework.data.convert.DtoInstantiatingConverter; import org.springframework.data.domain.Pageable; @@ -36,11 +33,12 @@ import org.springframework.data.geo.Point; import org.springframework.data.mapping.model.EntityInstantiators; import org.springframework.data.mongodb.core.ReactiveMongoOperations; import org.springframework.data.mongodb.core.ReactiveUpdateOperation.ReactiveUpdate; -import org.springframework.data.mongodb.core.aggregation.AggregationOperation; +import org.springframework.data.mongodb.core.aggregation.AggregationPipeline; import org.springframework.data.mongodb.core.aggregation.TypedAggregation; import org.springframework.data.mongodb.core.query.NearQuery; import org.springframework.data.mongodb.core.query.Query; import org.springframework.data.mongodb.core.query.UpdateDefinition; +import org.springframework.data.mongodb.repository.query.VectorSearchDelegate.QueryContainer; import org.springframework.data.repository.query.ResultProcessor; import org.springframework.data.repository.query.ReturnedType; import org.springframework.data.util.ReactiveWrappers; @@ -134,24 +132,24 @@ interface ReactiveMongoQueryExecution { class VectorSearchExecution implements ReactiveMongoQueryExecution { private final ReactiveMongoOperations operations; - private final VectorSearchDelegate.QueryMetadata queryMetadata; - private final List pipeline; + private final QueryContainer queryMetadata; + private final AggregationPipeline pipeline; private final boolean returnSearchResult; - public VectorSearchExecution(ReactiveMongoOperations operations, MongoQueryMethod method, - VectorSearchDelegate.QueryMetadata queryMetadata, MongoParameterAccessor accessor) { + VectorSearchExecution(ReactiveMongoOperations operations, MongoQueryMethod method, QueryContainer queryMetadata) { this.operations = operations; this.queryMetadata = queryMetadata; - this.pipeline = queryMetadata.getAggregationPipeline(method, accessor); + this.pipeline = queryMetadata.pipeline(); this.returnSearchResult = isSearchResult(method.getReturnType()); } @Override public Publisher execute(Query query, Class type, String collection) { - Flux aggregate = operations - .aggregate(TypedAggregation.newAggregation(queryMetadata.outputType(), pipeline), collection, Document.class); + Flux aggregate = operations.aggregate( + TypedAggregation.newAggregation(queryMetadata.outputType(), pipeline.getOperations()), collection, + Document.class); return aggregate.map(document -> { diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ReactiveVectorSearchAggregation.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ReactiveVectorSearchAggregation.java index 1ecbb0235..cf75c7db9 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ReactiveVectorSearchAggregation.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ReactiveVectorSearchAggregation.java @@ -19,13 +19,13 @@ import reactor.core.publisher.Mono; import org.bson.Document; import org.reactivestreams.Publisher; - import org.springframework.data.mongodb.InvalidMongoDbApiUsageException; import org.springframework.data.mongodb.core.MongoOperations; import org.springframework.data.mongodb.core.ReactiveMongoOperations; import org.springframework.data.mongodb.core.mapping.MongoPersistentEntity; import org.springframework.data.mongodb.core.query.Query; import org.springframework.data.mongodb.repository.VectorSearch; +import org.springframework.data.mongodb.repository.query.VectorSearchDelegate.QueryContainer; import org.springframework.data.mongodb.util.json.ParameterBindingContext; import org.springframework.data.repository.query.ResultProcessor; import org.springframework.data.repository.query.ValueExpressionDelegate; @@ -84,11 +84,11 @@ public class ReactiveVectorSearchAggregation extends AbstractReactiveMongoQuery ParameterBindingContext bindingContext = new ParameterBindingContext(accessor::getBindableValue, expressionEvaluator); - VectorSearchDelegate.QueryMetadata query = delegate.createQuery(expressionEvaluator, processor, accessor, - typeToRead, codec, bindingContext); + QueryContainer query = delegate.createQuery(expressionEvaluator, processor, accessor, typeToRead, codec, + bindingContext); ReactiveMongoQueryExecution.VectorSearchExecution execution = new ReactiveMongoQueryExecution.VectorSearchExecution( - mongoOperations, method, query, accessor); + mongoOperations, method, query); return execution.execute(query.query(), Document.class, collectionEntity.getCollection()); }); diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/VectorSearchAggregation.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/VectorSearchAggregation.java index 9740c0696..eb8dc2e52 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/VectorSearchAggregation.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/VectorSearchAggregation.java @@ -15,16 +15,17 @@ */ package org.springframework.data.mongodb.repository.query; +import org.jspecify.annotations.Nullable; import org.springframework.data.mapping.model.ValueExpressionEvaluator; import org.springframework.data.mongodb.InvalidMongoDbApiUsageException; import org.springframework.data.mongodb.core.MongoOperations; import org.springframework.data.mongodb.core.mapping.MongoPersistentEntity; import org.springframework.data.mongodb.core.query.Query; import org.springframework.data.mongodb.repository.VectorSearch; +import org.springframework.data.mongodb.repository.query.VectorSearchDelegate.QueryContainer; import org.springframework.data.mongodb.util.json.ParameterBindingContext; import org.springframework.data.repository.query.ResultProcessor; import org.springframework.data.repository.query.ValueExpressionDelegate; -import org.springframework.lang.Nullable; /** * {@link AbstractMongoQuery} implementation to run a {@link VectorSearchAggregation}. The pre-filter is either derived @@ -62,20 +63,19 @@ public class VectorSearchAggregation extends AbstractMongoQuery { this.delegate = new VectorSearchDelegate(method, mongoOperations.getConverter(), delegate); } - @SuppressWarnings("unchecked") @Override protected Object doExecute(MongoQueryMethod method, ResultProcessor processor, ConvertingParameterAccessor accessor, @Nullable Class typeToRead) { - VectorSearchDelegate.QueryMetadata query = createVectorSearchQuery(processor, accessor, typeToRead); + QueryContainer query = createVectorSearchQuery(processor, accessor, typeToRead); MongoQueryExecution.VectorSearchExecution execution = new MongoQueryExecution.VectorSearchExecution(mongoOperations, - method, collectionEntity.getCollection(), query, accessor); + method, collectionEntity.getCollection(), query); return execution.execute(query.query()); } - VectorSearchDelegate.QueryMetadata createVectorSearchQuery(ResultProcessor processor, MongoParameterAccessor accessor, + QueryContainer createVectorSearchQuery(ResultProcessor processor, MongoParameterAccessor accessor, @Nullable Class typeToRead) { ValueExpressionEvaluator evaluator = getExpressionEvaluatorFor(accessor); diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/VectorSearchDelegate.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/VectorSearchDelegate.java index 8932b85b1..0dbff2e93 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/VectorSearchDelegate.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/VectorSearchDelegate.java @@ -20,7 +20,6 @@ import java.util.List; import org.bson.Document; import org.jspecify.annotations.Nullable; - import org.springframework.data.domain.Limit; import org.springframework.data.domain.Range; import org.springframework.data.domain.Score; @@ -34,6 +33,7 @@ import org.springframework.data.mapping.model.ValueExpressionEvaluator; import org.springframework.data.mongodb.InvalidMongoDbApiUsageException; import org.springframework.data.mongodb.core.aggregation.Aggregation; import org.springframework.data.mongodb.core.aggregation.AggregationOperation; +import org.springframework.data.mongodb.core.aggregation.AggregationPipeline; import org.springframework.data.mongodb.core.aggregation.VectorSearchOperation; import org.springframework.data.mongodb.core.convert.MongoConverter; import org.springframework.data.mongodb.core.mapping.MongoPersistentEntity; @@ -47,6 +47,7 @@ import org.springframework.data.repository.query.ResultProcessor; import org.springframework.data.repository.query.ValueExpressionDelegate; import org.springframework.data.repository.query.parser.Part; import org.springframework.data.repository.query.parser.PartTree; +import org.springframework.util.NumberUtils; import org.springframework.util.StringUtils; /** @@ -58,32 +59,35 @@ class VectorSearchDelegate { private final VectorSearchQueryFactory queryFactory; private final VectorSearchOperation.SearchType searchType; + private final String indexName; private final @Nullable Integer numCandidates; private final @Nullable String numCandidatesExpression; private final Limit limit; private final @Nullable String limitExpression; private final MongoConverter converter; - public VectorSearchDelegate(MongoQueryMethod method, MongoConverter converter, ValueExpressionDelegate delegate) { + VectorSearchDelegate(MongoQueryMethod method, MongoConverter converter, ValueExpressionDelegate delegate) { VectorSearch vectorSearch = method.findAnnotatedVectorSearch().orElseThrow(); + this.searchType = vectorSearch.searchType(); + this.indexName = method.getAnnotatedHint(); if (StringUtils.hasText(vectorSearch.numCandidates())) { ValueExpression expression = delegate.getValueExpressionParser().parse(vectorSearch.numCandidates()); if (expression.isLiteral()) { - numCandidates = Integer.parseInt(vectorSearch.numCandidates()); - numCandidatesExpression = null; + this.numCandidates = Integer.parseInt(vectorSearch.numCandidates()); + this.numCandidatesExpression = null; } else { - numCandidates = null; - numCandidatesExpression = vectorSearch.numCandidates(); + this.numCandidates = null; + this.numCandidatesExpression = vectorSearch.numCandidates(); } } else { - numCandidates = null; - numCandidatesExpression = null; + this.numCandidates = null; + this.numCandidatesExpression = null; } if (StringUtils.hasText(vectorSearch.limit())) { @@ -91,26 +95,26 @@ class VectorSearchDelegate { ValueExpression expression = delegate.getValueExpressionParser().parse(vectorSearch.limit()); if (expression.isLiteral()) { - limit = Limit.of(Integer.parseInt(vectorSearch.limit())); - limitExpression = null; + this.limit = Limit.of(Integer.parseInt(vectorSearch.limit())); + this.limitExpression = null; } else { - limit = Limit.unlimited(); - limitExpression = vectorSearch.limit(); + this.limit = Limit.unlimited(); + this.limitExpression = vectorSearch.limit(); } } else { - limit = Limit.unlimited(); - limitExpression = null; + this.limit = Limit.unlimited(); + this.limitExpression = null; } this.converter = converter; if (StringUtils.hasText(vectorSearch.filter())) { - queryFactory = StringUtils.hasText(vectorSearch.path()) + this.queryFactory = StringUtils.hasText(vectorSearch.path()) ? new AnnotatedQueryFactory(vectorSearch.filter(), vectorSearch.path()) : new AnnotatedQueryFactory(vectorSearch.filter(), method.getEntityInformation().getCollectionEntity()); } else { - queryFactory = new PartTreeQueryFactory( + this.queryFactory = new PartTreeQueryFactory( new PartTree(method.getName(), method.getResultProcessor().getReturnedType().getDomainType()), converter.getMappingContext()); } @@ -119,43 +123,136 @@ class VectorSearchDelegate { /** * Create Query Metadata for {@code $vectorSearch}. */ - public QueryMetadata createQuery(ValueExpressionEvaluator evaluator, ResultProcessor processor, + QueryContainer createQuery(ValueExpressionEvaluator evaluator, ResultProcessor processor, MongoParameterAccessor accessor, @Nullable Class typeToRead, ParameterBindingDocumentCodec codec, ParameterBindingContext context) { - Integer numCandidates = null; - Limit limit; + String scoreField = "__score__"; Class outputType = typeToRead != null ? typeToRead : processor.getReturnedType().getReturnedType(); - VectorSearchInput query = queryFactory.createQuery(accessor, codec, context); + VectorSearchInput vectorSearchInput = createSearchInput(evaluator, accessor, codec, context); + AggregationPipeline pipeline = createVectorSearchPipeline(vectorSearchInput, scoreField, outputType, accessor, + evaluator); - if (this.limitExpression != null) { - Object value = evaluator.evaluate(this.limitExpression); - limit = value instanceof Limit l ? l : Limit.of(((Number) value).intValue()); - } else if (this.limit.isLimited()) { - limit = this.limit; - } else { - limit = accessor.getLimit(); - } + return new QueryContainer(vectorSearchInput.path, scoreField, vectorSearchInput.query, pipeline, searchType, + outputType, getSimilarityFunction(accessor), indexName); + } - if (limit.isLimited()) { - query.query().limit(limit); - } + @SuppressWarnings("NullAway") + AggregationPipeline createVectorSearchPipeline(VectorSearchInput input, String scoreField, Class outputType, + MongoParameterAccessor accessor, ValueExpressionEvaluator evaluator) { + + Vector vector = accessor.getVector(); + Score score = accessor.getScore(); + Range distance = accessor.getScoreRange(); + Limit limit = Limit.of(input.query().getLimit()); + + List stages = new ArrayList<>(); + VectorSearchOperation $vectorSearch = Aggregation.vectorSearch(indexName).path(input.path()).vector(vector) + .limit(limit); + Integer candidates = null; if (this.numCandidatesExpression != null) { - numCandidates = ((Number) evaluator.evaluate(this.numCandidatesExpression)).intValue(); + candidates = ((Number) evaluator.evaluate(this.numCandidatesExpression)).intValue(); } else if (this.numCandidates != null) { - numCandidates = this.numCandidates; - } else if (query.query().isLimited() && (searchType == VectorSearchOperation.SearchType.ANN + candidates = this.numCandidates; + } else if (input.query().isLimited() && (searchType == VectorSearchOperation.SearchType.ANN || searchType == VectorSearchOperation.SearchType.DEFAULT)) { /* MongoDB: We recommend that you specify a number at least 20 times higher than the number of documents to return (limit) to increase accuracy. */ - numCandidates = query.query().getLimit() * 20; + candidates = input.query().getLimit() * 20; } - return new QueryMetadata(query.path, "__score__", query.query, searchType, outputType, numCandidates, - getSimilarityFunction(accessor)); + if (candidates != null) { + $vectorSearch = $vectorSearch.numCandidates(candidates); + } + // + $vectorSearch = $vectorSearch.filter(input.query.getQueryObject()); + $vectorSearch = $vectorSearch.searchType(this.searchType); + $vectorSearch = $vectorSearch.withSearchScore(scoreField); + + if (score != null) { + $vectorSearch = $vectorSearch.withFilterBySore(c -> { + c.gt(score.getValue()); + }); + } else if (distance.getLowerBound().isBounded() || distance.getUpperBound().isBounded()) { + $vectorSearch = $vectorSearch.withFilterBySore(c -> { + Range.Bound lower = distance.getLowerBound(); + if (lower.isBounded()) { + double value = lower.getValue().get().getValue(); + if (lower.isInclusive()) { + c.gte(value); + } else { + c.gt(value); + } + } + + Range.Bound upper = distance.getUpperBound(); + if (upper.isBounded()) { + + double value = upper.getValue().get().getValue(); + if (upper.isInclusive()) { + c.lte(value); + } else { + c.lt(value); + } + } + }); + } + + stages.add($vectorSearch); + + if (input.query().isSorted()) { + + stages.add(ctx -> { + + Document mappedSort = ctx.getMappedObject(input.query().getSortObject(), outputType); + mappedSort.append(scoreField, -1); + return ctx.getMappedObject(new Document("$sort", mappedSort)); + }); + } else { + stages.add(Aggregation.sort(Sort.Direction.DESC, scoreField)); + } + + return new AggregationPipeline(stages); + } + + private VectorSearchInput createSearchInput(ValueExpressionEvaluator evaluator, MongoParameterAccessor accessor, + ParameterBindingDocumentCodec codec, ParameterBindingContext context) { + + VectorSearchInput input = queryFactory.createQuery(accessor, codec, context); + Limit limit = getLimit(evaluator, accessor); + if(!input.query.isLimited() || (input.query.isLimited() && !limit.isUnlimited())) { + input.query().limit(limit); + } + return input; + } + + private Limit getLimit(ValueExpressionEvaluator evaluator, MongoParameterAccessor accessor) { + + if (this.limitExpression != null) { + + Object value = evaluator.evaluate(this.limitExpression); + if (value != null) { + if (value instanceof Limit l) { + return l; + } + if (value instanceof Number n) { + return Limit.of(n.intValue()); + } + if (value instanceof String s) { + return Limit.of(NumberUtils.parseNumber(s, Integer.class)); + } + throw new IllegalArgumentException("Invalid type for Limit. Found [%s], expected Limit or Number"); + } + } + + if (this.limit.isLimited()) { + return this.limit; + } + + return accessor.getLimit(); } public String getQueryString() { @@ -192,82 +289,10 @@ class VectorSearchDelegate { * @param query * @param searchType * @param outputType - * @param numCandidates * @param scoringFunction */ - public record QueryMetadata(String path, String scoreField, Query query, VectorSearchOperation.SearchType searchType, - Class outputType, @Nullable Integer numCandidates, ScoringFunction scoringFunction) { - - /** - * Create the Aggregation Pipeline. - * - * @param queryMethod - * @param accessor - * @return - */ - public List getAggregationPipeline(MongoQueryMethod queryMethod, - MongoParameterAccessor accessor) { - - Vector vector = accessor.getVector(); - Score score = accessor.getScore(); - Range distance = accessor.getScoreRange(); - Limit limit = Limit.unlimited(); - - if (query.isLimited()) { - limit = Limit.of(query.getLimit()); - } - - List stages = new ArrayList<>(); - VectorSearchOperation $vectorSearch = Aggregation.vectorSearch(queryMethod.getAnnotatedHint()).path(path()) - .vector(vector).limit(limit); - - if (numCandidates() != null) { - $vectorSearch = $vectorSearch.numCandidates(numCandidates()); - } - - $vectorSearch = $vectorSearch.filter(query.getQueryObject()); - $vectorSearch = $vectorSearch.searchType(searchType()); - $vectorSearch = $vectorSearch.withSearchScore(scoreField()); - - if (score != null) { - $vectorSearch = $vectorSearch.withFilterBySore(c -> { - c.gt(score.getValue()); - }); - } else if (distance.getLowerBound().isBounded() || distance.getUpperBound().isBounded()) { - $vectorSearch = $vectorSearch.withFilterBySore(c -> { - Range.Bound lower = distance.getLowerBound(); - if (lower.isBounded()) { - double value = lower.getValue().get().getValue(); - if (lower.isInclusive()) { - c.gte(value); - } else { - c.gt(value); - } - } - - Range.Bound upper = distance.getUpperBound(); - if (upper.isBounded()) { - - double value = upper.getValue().get().getValue(); - if (upper.isInclusive()) { - c.lte(value); - } else { - c.lt(value); - } - } - }); - } - - stages.add($vectorSearch); - - if (query.isSorted()) { - // TODO stages.add(Aggregation.sort(query.with())); - } else { - stages.add(Aggregation.sort(Sort.Direction.DESC, "__score__")); - } - - return stages; - } + record QueryContainer(String path, String scoreField, Query query, AggregationPipeline pipeline, + VectorSearchOperation.SearchType searchType, Class outputType, ScoringFunction scoringFunction, String index) { } @@ -368,11 +393,12 @@ class VectorSearchDelegate { this.tree = tree; } + @SuppressWarnings("NullAway") public VectorSearchInput createQuery(MongoParameterAccessor parameterAccessor, ParameterBindingDocumentCodec codec, ParameterBindingContext context) { - MongoQueryCreator creator = new MongoQueryCreator(tree, parameterAccessor, converter.getMappingContext(), - false, true); + MongoQueryCreator creator = new MongoQueryCreator(tree, parameterAccessor, converter.getMappingContext(), false, + true); Query query = creator.createQuery(parameterAccessor.getSort()); diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/VectorSearchTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/VectorSearchTests.java index 028a6926f..a224481da 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/VectorSearchTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/VectorSearchTests.java @@ -81,16 +81,15 @@ public class VectorSearchTests { @Override public MongoClient mongoClient() { - atlasLocal.start(); return MongoClients.create(atlasLocal.getConnectionString()); } } @BeforeAll static void beforeAll() throws InterruptedException { + atlasLocal.start(); - System.out.println(atlasLocal.getConnectionString()); client = MongoClients.create(atlasLocal.getConnectionString()); template = new MongoTestTemplate(client, "vector-search-tests"); diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/VectorSearchAggregationUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/VectorSearchAggregationUnitTests.java index c347936df..819bba5a4 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/VectorSearchAggregationUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/VectorSearchAggregationUnitTests.java @@ -34,6 +34,7 @@ import org.springframework.data.mongodb.core.convert.MappingMongoConverter; import org.springframework.data.mongodb.core.convert.NoOpDbRefResolver; import org.springframework.data.mongodb.core.mapping.MongoMappingContext; import org.springframework.data.mongodb.repository.VectorSearch; +import org.springframework.data.mongodb.repository.query.VectorSearchDelegate.QueryContainer; import org.springframework.data.projection.ProjectionFactory; import org.springframework.data.projection.SpelAwareProxyProjectionFactory; import org.springframework.data.repository.CrudRepository; @@ -68,7 +69,7 @@ class VectorSearchAggregationUnitTests { VectorSearchAggregation aggregation = aggregation(SampleRepository.class, "searchByCountryAndEmbeddingNear", String.class, Vector.class, Score.class, Limit.class); - VectorSearchDelegate.QueryMetadata query = aggregation.createVectorSearchQuery( + QueryContainer query = aggregation.createVectorSearchQuery( aggregation.getQueryMethod().getResultProcessor(), new MongoParametersParameterAccessor(aggregation.getQueryMethod(), new Object[] { "de", Vector.of(1f), Score.of(1), Limit.unlimited() }), diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/VectorSearchDelegateUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/VectorSearchDelegateUnitTests.java index 06a80e78f..078c01eec 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/VectorSearchDelegateUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/VectorSearchDelegateUnitTests.java @@ -15,23 +15,30 @@ */ package org.springframework.data.mongodb.repository.query; -import static org.assertj.core.api.Assertions.*; -import static org.mockito.Mockito.*; +import static org.mockito.Mockito.mock; +import static org.springframework.data.mongodb.test.util.Assertions.assertThat; import java.lang.reflect.Method; +import java.util.List; +import org.bson.Document; +import org.jspecify.annotations.Nullable; import org.junit.jupiter.api.Test; - import org.springframework.data.domain.Limit; import org.springframework.data.domain.Score; import org.springframework.data.domain.SearchResults; import org.springframework.data.domain.Vector; import org.springframework.data.mapping.model.ValueExpressionEvaluator; +import org.springframework.data.mongodb.core.aggregation.Aggregation; +import org.springframework.data.mongodb.core.aggregation.AggregationPipeline; import org.springframework.data.mongodb.core.aggregation.VectorSearchOperation; import org.springframework.data.mongodb.core.convert.MappingMongoConverter; import org.springframework.data.mongodb.core.convert.NoOpDbRefResolver; +import org.springframework.data.mongodb.core.mapping.Field; import org.springframework.data.mongodb.core.mapping.MongoMappingContext; import org.springframework.data.mongodb.repository.VectorSearch; +import org.springframework.data.mongodb.repository.query.VectorSearchDelegate.QueryContainer; +import org.springframework.data.mongodb.util.aggregation.TestAggregationContext; import org.springframework.data.mongodb.util.json.ParameterBindingContext; import org.springframework.data.mongodb.util.json.ParameterBindingDocumentCodec; import org.springframework.data.projection.SpelAwareProxyProjectionFactory; @@ -44,6 +51,7 @@ import org.springframework.data.repository.query.ValueExpressionDelegate; * Unit tests for {@link VectorSearchDelegate}. * * @author Mark Paluch + * @author Christoph Strobl */ class VectorSearchDelegateUnitTests { @@ -57,10 +65,10 @@ class VectorSearchDelegateUnitTests { MongoQueryMethod queryMethod = getMongoQueryMethod(method); MongoParametersParameterAccessor accessor = getAccessor(queryMethod, Vector.of(1, 2), Score.of(1)); - VectorSearchDelegate.QueryMetadata query = createQueryMetadata(queryMethod, accessor); + QueryContainer container = createQueryContainer(queryMethod, accessor); - assertThat(query.query().getLimit()).isEqualTo(10); - assertThat(query.numCandidates()).isEqualTo(10 * 20); + assertThat(container.query().getLimit()).isEqualTo(10); + assertThat(numCandidates(container.pipeline())).isEqualTo(10 * 20); } @Test @@ -71,10 +79,10 @@ class VectorSearchDelegateUnitTests { MongoQueryMethod queryMethod = getMongoQueryMethod(method); MongoParametersParameterAccessor accessor = getAccessor(queryMethod, Vector.of(1, 2), Score.of(1)); - VectorSearchDelegate.QueryMetadata query = createQueryMetadata(queryMethod, accessor); + QueryContainer container = createQueryContainer(queryMethod, accessor); - assertThat(query.query().getLimit()).isEqualTo(10); - assertThat(query.numCandidates()).isNull(); + assertThat(container.query().getLimit()).isEqualTo(10); + assertThat(numCandidates(container.pipeline())).isNull(); } @Test @@ -86,19 +94,87 @@ class VectorSearchDelegateUnitTests { MongoQueryMethod queryMethod = getMongoQueryMethod(method); MongoParametersParameterAccessor accessor = getAccessor(queryMethod, Vector.of(1, 2), Score.of(1), Limit.of(11)); - VectorSearchDelegate.QueryMetadata query = createQueryMetadata(queryMethod, accessor); + QueryContainer container = createQueryContainer(queryMethod, accessor); - assertThat(query.query().getLimit()).isEqualTo(11); - assertThat(query.numCandidates()).isEqualTo(11 * 20); + assertThat(container.query().getLimit()).isEqualTo(11); + assertThat(numCandidates(container.pipeline())).isEqualTo(11 * 20); } - private VectorSearchDelegate.QueryMetadata createQueryMetadata(MongoQueryMethod queryMethod, - MongoParametersParameterAccessor accessor) { + @Test + void considersDerivedQueryPart() throws ReflectiveOperationException { + + Method method = VectorSearchRepository.class.getMethod("searchTop10ByFirstNameAndEmbeddingNear", String.class, + Vector.class, Score.class); + + MongoQueryMethod queryMethod = getMongoQueryMethod(method); + MongoParametersParameterAccessor accessor = getAccessor(queryMethod, "spring", Vector.of(1, 2), Score.of(1)); + + QueryContainer container = createQueryContainer(queryMethod, accessor); + + assertThat(vectorSearchStageOf(container.pipeline())).containsEntry("$vectorSearch.filter", + new Document("first_name", "spring")); + } + + @Test + void considersDerivedQueryPartInDifferentOrder() throws ReflectiveOperationException { + + Method method = VectorSearchRepository.class.getMethod("searchTop10ByEmbeddingNearAndFirstName", Vector.class, + Score.class, String.class); + + MongoQueryMethod queryMethod = getMongoQueryMethod(method); + MongoParametersParameterAccessor accessor = getAccessor(queryMethod, Vector.of(1, 2), Score.of(1), "spring"); + + QueryContainer container = createQueryContainer(queryMethod, accessor); + + assertThat(vectorSearchStageOf(container.pipeline())).containsEntry("$vectorSearch.filter", + new Document("first_name", "spring")); + } + + @Test + void defaultSortsByScore() throws NoSuchMethodException { + + Method method = VectorSearchRepository.class.getMethod("searchTop10ByEmbeddingNear", Vector.class, Score.class, + Limit.class); + + MongoQueryMethod queryMethod = getMongoQueryMethod(method); + MongoParametersParameterAccessor accessor = getAccessor(queryMethod, Vector.of(1, 2), Score.of(1), Limit.of(10)); + + QueryContainer container = createQueryContainer(queryMethod, accessor); + + List stages = container.pipeline().lastOperation() + .toPipelineStages(TestAggregationContext.contextFor(WithVector.class)); + + assertThat(stages).containsExactly(new Document("$sort", new Document("__score__", -1))); + } + + @Test + void usesDerivedSort() throws NoSuchMethodException { + + Method method = VectorSearchRepository.class.getMethod("searchByEmbeddingNearOrderByFirstName", Vector.class, + Score.class, Limit.class); + + MongoQueryMethod queryMethod = getMongoQueryMethod(method); + MongoParametersParameterAccessor accessor = getAccessor(queryMethod, Vector.of(1, 2), Score.of(1), Limit.of(11)); + + QueryContainer container = createQueryContainer(queryMethod, accessor); + AggregationPipeline aggregationPipeline = container.pipeline(); + + List stages = aggregationPipeline.lastOperation() + .toPipelineStages(TestAggregationContext.contextFor(WithVector.class)); + + assertThat(stages).containsExactly(new Document("$sort", new Document("first_name", 1).append("__score__", -1))); + } + + Document vectorSearchStageOf(AggregationPipeline pipeline) { + return pipeline.firstOperation().toPipelineStages(TestAggregationContext.contextFor(WithVector.class)).get(0); + } + + private QueryContainer createQueryContainer(MongoQueryMethod queryMethod, MongoParametersParameterAccessor accessor) { VectorSearchDelegate delegate = new VectorSearchDelegate(queryMethod, converter, ValueExpressionDelegate.create()); - return delegate.createQuery(mock(ValueExpressionEvaluator.class), queryMethod.getResultProcessor(), accessor, - Object.class, new ParameterBindingDocumentCodec(), mock(ParameterBindingContext.class)); + return delegate.createQuery(mock(ValueExpressionEvaluator.class), queryMethod.getResultProcessor(), accessor, null, + new ParameterBindingDocumentCodec(), mock(ParameterBindingContext.class)); } private MongoQueryMethod getMongoQueryMethod(Method method) { @@ -110,21 +186,69 @@ class VectorSearchDelegateUnitTests { return new MongoParametersParameterAccessor(queryMethod, values); } + @Nullable + private static Integer numCandidates(AggregationPipeline pipeline) { + + Document $vectorSearch = pipeline.firstOperation().toPipelineStages(Aggregation.DEFAULT_CONTEXT).get(0); + if ($vectorSearch.containsKey("$vectorSearch")) { + Object value = $vectorSearch.get("$vectorSearch", Document.class).get("numCandidates"); + return value instanceof Number i ? i.intValue() : null; + } + return null; + } + interface VectorSearchRepository extends Repository { @VectorSearch(indexName = "cos-index", searchType = VectorSearchOperation.SearchType.ANN) SearchResults searchTop10ByEmbeddingNear(Vector vector, Score similarity); + @VectorSearch(indexName = "cos-index", searchType = VectorSearchOperation.SearchType.ANN) + SearchResults searchTop10ByFirstNameAndEmbeddingNear(String firstName, Vector vector, Score similarity); + + @VectorSearch(indexName = "cos-index", searchType = VectorSearchOperation.SearchType.ANN) + SearchResults searchTop10ByEmbeddingNearAndFirstName(Vector vector, Score similarity, String firstname); + @VectorSearch(indexName = "cos-index", searchType = VectorSearchOperation.SearchType.ENN) SearchResults searchTop10EnnByEmbeddingNear(Vector vector, Score similarity); @VectorSearch(indexName = "cos-index", searchType = VectorSearchOperation.SearchType.ANN) SearchResults searchTop10ByEmbeddingNear(Vector vector, Score similarity, Limit limit); + @VectorSearch(indexName = "cos-index", searchType = VectorSearchOperation.SearchType.ANN) + SearchResults searchByEmbeddingNearOrderByFirstName(Vector vector, Score similarity, Limit limit); + } static class WithVector { Vector embedding; + + String lastName; + + @Field("first_name") String firstName; + + public Vector getEmbedding() { + return embedding; + } + + public void setEmbedding(Vector embedding) { + this.embedding = embedding; + } + + public String getLastName() { + return lastName; + } + + public void setLastName(String lastName) { + this.lastName = lastName; + } + + public String getFirstName() { + return firstName; + } + + public void setFirstName(String firstName) { + this.firstName = firstName; + } } } diff --git a/src/main/antora/modules/ROOT/pages/mongodb/mongo-search-indexes.adoc b/src/main/antora/modules/ROOT/pages/mongodb/mongo-search-indexes.adoc index 345b5dbb6..7fc51de00 100644 --- a/src/main/antora/modules/ROOT/pages/mongodb/mongo-search-indexes.adoc +++ b/src/main/antora/modules/ROOT/pages/mongodb/mongo-search-indexes.adoc @@ -25,7 +25,7 @@ Java:: [source,java,indent=0,subs="verbatim,quotes",role="primary"] ---- VectorIndex index = new VectorIndex("vector_index") - .addVector("plotEmbedding"), vector -> vector.dimensions(1536).similarity(COSINE)) <1> + .addVector("plotEmbedding", vector -> vector.dimensions(1536).similarity(COSINE)) <1> .addFilter("year"); <2> mongoTemplate.searchIndexOps(Movie.class) <3> diff --git a/src/main/antora/modules/ROOT/partials/vector-search-method-annotated-include.adoc b/src/main/antora/modules/ROOT/partials/vector-search-method-annotated-include.adoc index 752ffad62..252437f0b 100644 --- a/src/main/antora/modules/ROOT/partials/vector-search-method-annotated-include.adoc +++ b/src/main/antora/modules/ROOT/partials/vector-search-method-annotated-include.adoc @@ -6,13 +6,13 @@ Annotated search methods use the `@VectorSearch` annotation to define parameters ---- interface CommentRepository extends Repository { - @VectorSearch(indexName = "cos-index", filter = "{country: ?0}") - SearchResults searchAnnotatedByCountryAndEmbeddingWithin(String country, Vector embedding, + @VectorSearch(indexName = "cos-index", filter = "{country: ?0}", limit="100", numCandidates="2000") + SearchResults searchAnnotatedByCountryAndEmbeddingWithin(String country, Vector embedding, Score distance); - @VectorSearch(indexName = "my-index", filter = "{country: ?0}", numCandidates = "#{#limit * 20}", + @VectorSearch(indexName = "my-index", filter = "{country: ?0}", limit="?3", numCandidates = "#{#limit * 20}", searchType = VectorSearchOperation.SearchType.ANN) - List findAnnotatedByCountryAndEmbeddingWithin(String country, Vector embedding, Score distance, int limit); + List findAnnotatedByCountryAndEmbeddingWithin(String country, Vector embedding, Score distance, int limit); } ---- ==== diff --git a/src/main/antora/modules/ROOT/partials/vector-search-method-derived-include.adoc b/src/main/antora/modules/ROOT/partials/vector-search-method-derived-include.adoc index dd06ee699..f2b006b8e 100644 --- a/src/main/antora/modules/ROOT/partials/vector-search-method-derived-include.adoc +++ b/src/main/antora/modules/ROOT/partials/vector-search-method-derived-include.adoc @@ -6,14 +6,14 @@ MongoDB Search methods must use the `@VectorSearch` annotation to define the ind ---- interface CommentRepository extends Repository { - @VectorSearch(indexName = "my-index") - SearchResults searchByEmbeddingNear(Vector vector, Score score); + @VectorSearch(indexName = "my-index", numCandidates="200") + SearchResults searchTop10ByEmbeddingNear(Vector vector, Score score); - @VectorSearch(indexName = "my-index") - SearchResults searchByEmbeddingWithin(Vector vector, Range range); + @VectorSearch(indexName = "my-index", numCandidates="200") + SearchResults searchTop10ByEmbeddingWithin(Vector vector, Range range); - @VectorSearch(indexName = "my-index") - SearchResults searchByCountryAndEmbeddingWithin(String country, Vector vector, Range range); + @VectorSearch(indexName = "my-index", numCandidates="200") + SearchResults searchTop10ByCountryAndEmbeddingWithin(String country, Vector vector, Range range); } ---- ==== diff --git a/src/main/antora/modules/ROOT/partials/vector-search-repository-include.adoc b/src/main/antora/modules/ROOT/partials/vector-search-repository-include.adoc index c7ad91c9d..0e987fc1c 100644 --- a/src/main/antora/modules/ROOT/partials/vector-search-repository-include.adoc +++ b/src/main/antora/modules/ROOT/partials/vector-search-repository-include.adoc @@ -4,12 +4,12 @@ ---- interface CommentRepository extends Repository { - @VectorSearch(indexName = "my-index") + @VectorSearch(indexName = "my-index", numCandidates="#{#limit.max() * 20}") SearchResults searchByCountryAndEmbeddingNear(String country, Vector vector, Score score, Limit limit); - @VectorSearch(indexName = "my-index") - SearchResults searchAnnotatedByCountryAndEmbeddingWithin(String country, Vector embedding, + @VectorSearch(indexName = "my-index", limit="10", numCandidates="200") + SearchResults searchByCountryAndEmbeddingWithin(String country, Vector embedding, Score score); } @@ -17,3 +17,9 @@ interface CommentRepository extends Repository { SearchResults results = repository.searchByCountryAndEmbeddingNear("en", Vector.of(…), Score.of(0.9), Limit.of(10)); ---- ==== + +[TIP] +==== +The MongoDB https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/[vector search aggregation] stage defines a set of required arguments and restrictions. +Please make sure to follow the guidelines and make sure to provide required arguments like `limit`. +==== diff --git a/src/main/antora/modules/ROOT/partials/vector-search-scoring-include.adoc b/src/main/antora/modules/ROOT/partials/vector-search-scoring-include.adoc index b97475b46..313d8bf39 100644 --- a/src/main/antora/modules/ROOT/partials/vector-search-scoring-include.adoc +++ b/src/main/antora/modules/ROOT/partials/vector-search-scoring-include.adoc @@ -9,13 +9,13 @@ The scoring function defaults to `ScoringFunction.unspecified()` as there is no interface CommentRepository extends Repository { @VectorSearch(…) - SearchResults searchByEmbeddingNear(Vector vector, Score similarity); + SearchResults searchTop10ByEmbeddingNear(Vector vector, Score similarity); @VectorSearch(…) - SearchResults searchByEmbeddingNear(Vector vector, Similarity similarity); + SearchResults searchTop10ByEmbeddingNear(Vector vector, Similarity similarity); @VectorSearch(…) - SearchResults searchByEmbeddingNear(Vector vector, Range range); + SearchResults searchTop10ByEmbeddingNear(Vector vector, Range range); } repository.searchByEmbeddingNear(Vector.of(…), Score.of(0.9)); <1>