Browse Source

Introduce reusable code fragments for generating code.

See #5004
Original pull request: #5005
pull/5026/head
Christoph Strobl 6 months ago committed by Mark Paluch
parent
commit
8f7fef4c5a
No known key found for this signature in database
GPG Key ID: 55BC6374BAA9D973
  1. 563
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/AggregationBlocks.java
  2. 44
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/BuilderStyleSnippet.java
  3. 127
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/DeleteBlocks.java
  4. 53
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/ExpressionSnippet.java
  5. 63
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/GeoBlocks.java
  6. 19
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java
  7. 10
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java
  8. 47
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/QueryBlocks.java
  9. 91
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/SearchInteraction.java
  10. 224
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/Snippet.java
  11. 42
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/UpdateBlocks.java
  12. 119
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/VariableSnippet.java
  13. 221
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/VectorSearchBocks.java
  14. 11
      spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/QueryMethodContributionUnitTests.java
  15. 2
      spring-data-mongodb/src/test/resources/logback.xml

563
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/AggregationBlocks.java

@ -19,11 +19,11 @@ import java.util.ArrayList; @@ -19,11 +19,11 @@ import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.bson.Document;
import org.jspecify.annotations.NullUnmarked;
import org.springframework.core.ResolvableType;
import org.springframework.core.annotation.MergedAnnotation;
import org.springframework.data.domain.SliceImpl;
import org.springframework.data.domain.Sort.Order;
@ -52,312 +52,323 @@ import org.springframework.util.StringUtils; @@ -52,312 +52,323 @@ import org.springframework.util.StringUtils;
*/
class AggregationBlocks {
@NullUnmarked
static class AggregationExecutionCodeBlockBuilder {
private final AotQueryMethodGenerationContext context;
private final MongoQueryMethod queryMethod;
private String aggregationVariableName;
@NullUnmarked
static class AggregationExecutionCodeBlockBuilder {
AggregationExecutionCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) {
private final AotQueryMethodGenerationContext context;
private final MongoQueryMethod queryMethod;
private String aggregationVariableName;
this.context = context;
this.queryMethod = queryMethod;
}
AggregationExecutionCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) {
AggregationExecutionCodeBlockBuilder referencing(String aggregationVariableName) {
this.context = context;
this.queryMethod = queryMethod;
}
this.aggregationVariableName = aggregationVariableName;
return this;
}
AggregationExecutionCodeBlockBuilder referencing(String aggregationVariableName) {
CodeBlock build() {
this.aggregationVariableName = aggregationVariableName;
return this;
}
String mongoOpsRef = context.fieldNameOf(MongoOperations.class);
Builder builder = CodeBlock.builder();
CodeBlock build() {
builder.add("\n");
String mongoOpsRef = context.fieldNameOf(MongoOperations.class);
Builder builder = CodeBlock.builder();
Class<?> outputType = queryMethod.getReturnedObjectType();
if (MongoSimpleTypes.HOLDER.isSimpleType(outputType)) {
outputType = Document.class;
} else if (ClassUtils.isAssignable(AggregationResults.class, outputType)) {
outputType = queryMethod.getReturnType().getComponentType().getType();
}
builder.add("\n");
if (ReflectionUtils.isVoid(queryMethod.getReturnedObjectType())) {
builder.addStatement("$L.aggregate($L, $T.class)", mongoOpsRef, aggregationVariableName, outputType);
return builder.build();
}
Class<?> outputType = queryMethod.getReturnedObjectType();
if (MongoSimpleTypes.HOLDER.isSimpleType(outputType)) {
outputType = Document.class;
} else if (ClassUtils.isAssignable(AggregationResults.class, outputType)) {
outputType = queryMethod.getReturnType().getComponentType().getType();
}
if (ClassUtils.isAssignable(AggregationResults.class, context.getMethod().getReturnType())) {
builder.addStatement("return $L.aggregate($L, $T.class)", mongoOpsRef, aggregationVariableName, outputType);
return builder.build();
}
if (ReflectionUtils.isVoid(queryMethod.getReturnedObjectType())) {
builder.addStatement("$L.aggregate($L, $T.class)", mongoOpsRef, aggregationVariableName, outputType);
return builder.build();
}
if (outputType == Document.class) {
if (ClassUtils.isAssignable(AggregationResults.class, context.getMethod().getReturnType())) {
builder.addStatement("return $L.aggregate($L, $T.class)", mongoOpsRef, aggregationVariableName, outputType);
return builder.build();
}
Class<?> returnType = ClassUtils.resolvePrimitiveIfNecessary(queryMethod.getReturnedObjectType());
if (outputType == Document.class) {
if (queryMethod.isStreamQuery()) {
Class<?> returnType = ClassUtils.resolvePrimitiveIfNecessary(queryMethod.getReturnedObjectType());
builder.addStatement("$T<$T> $L = $L.aggregateStream($L, $T.class)", Stream.class, Document.class,
context.localVariable("results"), mongoOpsRef, aggregationVariableName, outputType);
if (queryMethod.isStreamQuery()) {
builder.addStatement("return $1L.map(it -> ($2T) convertSimpleRawResult($2T.class, it))",
context.localVariable("results"), returnType);
} else {
VariableSnippet results = Snippet.declare(builder)
.variable(ResolvableType.forClassWithGenerics(Stream.class, Document.class),
context.localVariable("results"))
.as("$L.aggregateStream($L, $T.class)", mongoOpsRef, aggregationVariableName, outputType);
builder.addStatement("$T $L = $L.aggregate($L, $T.class)", AggregationResults.class,
context.localVariable("results"), mongoOpsRef, aggregationVariableName, outputType);
builder.addStatement("return $1L.map(it -> ($2T) convertSimpleRawResult($2T.class, it))",
results.getVariableName(), returnType);
} else {
if (!queryMethod.isCollectionQuery()) {
builder.addStatement(
"return $1T.<$2T>firstElement(convertSimpleRawResults($2T.class, $3L.getMappedResults()))",
CollectionUtils.class, returnType, context.localVariable("results"));
} else {
builder.addStatement("return convertSimpleRawResults($T.class, $L.getMappedResults())", returnType,
context.localVariable("results"));
}
}
} else {
if (queryMethod.isSliceQuery()) {
builder.addStatement("$T $L = $L.aggregate($L, $T.class)", AggregationResults.class,
context.localVariable("results"), mongoOpsRef, aggregationVariableName, outputType);
builder.addStatement("boolean $L = $L.getMappedResults().size() > $L.getPageSize()",
context.localVariable("hasNext"), context.localVariable("results"), context.getPageableParameterName());
builder.addStatement(
"return new $1T<>($2L ? $3L.getMappedResults().subList(0, $4L.getPageSize()) : $3L.getMappedResults(), $4L, $2L)",
SliceImpl.class, context.localVariable("hasNext"), context.localVariable("results"),
context.getPageableParameterName());
} else {
VariableSnippet results = Snippet.declare(builder)
.variable(AggregationResults.class, context.localVariable("results"))
.as("$L.aggregate($L, $T.class)", mongoOpsRef, aggregationVariableName, outputType);
if (queryMethod.isStreamQuery()) {
builder.addStatement("return $L.aggregateStream($L, $T.class)", mongoOpsRef, aggregationVariableName,
outputType);
} else {
if (!queryMethod.isCollectionQuery()) {
builder.addStatement(
"return $1T.<$2T>firstElement(convertSimpleRawResults($2T.class, $3L.getMappedResults()))",
CollectionUtils.class, returnType, results.getVariableName());
} else {
builder.addStatement("return convertSimpleRawResults($T.class, $L.getMappedResults())", returnType,
results.getVariableName());
}
}
} else {
if (queryMethod.isSliceQuery()) {
builder.addStatement("return $L.aggregate($L, $T.class).getMappedResults()", mongoOpsRef,
aggregationVariableName, outputType);
}
}
}
VariableSnippet results = Snippet.declare(builder)
.variable(AggregationResults.class, context.localVariable("results"))
.as("$L.aggregate($L, $T.class)", mongoOpsRef, aggregationVariableName, outputType);
return builder.build();
}
}
VariableSnippet hasNext = Snippet.declare(builder).variable("hasNext").as(
"$L.getMappedResults().size() > $L.getPageSize()", results.getVariableName(),
context.getPageableParameterName());
@NullUnmarked
static class AggregationCodeBlockBuilder {
builder.addStatement(
"return new $1T<>($2L ? $3L.getMappedResults().subList(0, $4L.getPageSize()) : $3L.getMappedResults(), $4L, $2L)",
SliceImpl.class, hasNext.getVariableName(), results.getVariableName(),
context.getPageableParameterName());
} else {
private final AotQueryMethodGenerationContext context;
private final MongoQueryMethod queryMethod;
private final Map<String, CodeBlock> arguments;
if (queryMethod.isStreamQuery()) {
builder.addStatement("return $L.aggregateStream($L, $T.class)", mongoOpsRef, aggregationVariableName,
outputType);
} else {
private AggregationInteraction source;
builder.addStatement("return $L.aggregate($L, $T.class).getMappedResults()", mongoOpsRef,
aggregationVariableName, outputType);
}
}
}
private String aggregationVariableName;
private boolean pipelineOnly;
return builder.build();
}
}
AggregationCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) {
@NullUnmarked
static class AggregationCodeBlockBuilder {
this.context = context;
this.arguments = new LinkedHashMap<>();
context.getBindableParameterNames().forEach(it -> arguments.put(it, CodeBlock.of(it)));
this.queryMethod = queryMethod;
}
private final AotQueryMethodGenerationContext context;
private final MongoQueryMethod queryMethod;
private final Map<String, CodeBlock> arguments;
AggregationCodeBlockBuilder stages(AggregationInteraction aggregation) {
this.source = aggregation;
return this;
}
private AggregationInteraction source;
AggregationCodeBlockBuilder usingAggregationVariableName(String aggregationVariableName) {
this.aggregationVariableName = aggregationVariableName;
return this;
}
AggregationCodeBlockBuilder pipelineOnly(boolean pipelineOnly) {
this.pipelineOnly = pipelineOnly;
return this;
}
CodeBlock build() {
Builder builder = CodeBlock.builder();
builder.add("\n");
String pipelineName = context.localVariable(aggregationVariableName + (pipelineOnly ? "" : "Pipeline"));
builder.add(pipeline(pipelineName));
if (!pipelineOnly) {
builder.addStatement("$1T<$2T> $3L = $4T.newAggregation($2T.class, $5L.getOperations())",
TypedAggregation.class, context.getRepositoryInformation().getDomainType(), aggregationVariableName,
Aggregation.class, pipelineName);
builder.add(aggregationOptions(aggregationVariableName));
}
return builder.build();
}
private CodeBlock pipeline(String pipelineVariableName) {
String sortParameter = context.getSortParameterName();
String limitParameter = context.getLimitParameterName();
String pageableParameter = context.getPageableParameterName();
boolean mightBeSorted = StringUtils.hasText(sortParameter);
boolean mightBeLimited = StringUtils.hasText(limitParameter);
boolean mightBePaged = StringUtils.hasText(pageableParameter);
int stageCount = source.stages().size();
if (mightBeSorted) {
stageCount++;
}
if (mightBeLimited) {
stageCount++;
}
if (mightBePaged) {
stageCount += 3;
}
Builder builder = CodeBlock.builder();
builder.add(aggregationStages(context.localVariable("stages"), source.stages(), stageCount, arguments));
private String aggregationVariableName;
private boolean pipelineOnly;
if (mightBeSorted) {
builder.add(sortingStage(sortParameter));
}
AggregationCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) {
if (mightBeLimited) {
builder.add(limitingStage(limitParameter));
}
this.context = context;
this.arguments = new LinkedHashMap<>();
context.getBindableParameterNames().forEach(it -> arguments.put(it, CodeBlock.of(it)));
this.queryMethod = queryMethod;
}
if (mightBePaged) {
builder.add(pagingStage(pageableParameter, queryMethod.isSliceQuery()));
}
builder.addStatement("$T $L = createPipeline($L)", AggregationPipeline.class, pipelineVariableName,
context.localVariable("stages"));
return builder.build();
}
private CodeBlock aggregationOptions(String aggregationVariableName) {
Builder builder = CodeBlock.builder();
List<CodeBlock> options = new ArrayList<>(5);
if (ReflectionUtils.isVoid(queryMethod.getReturnedObjectType())) {
options.add(CodeBlock.of(".skipOutput()"));
}
AggregationCodeBlockBuilder stages(AggregationInteraction aggregation) {
this.source = aggregation;
return this;
}
MergedAnnotation<Hint> hintAnnotation = context.getAnnotation(Hint.class);
String hint = hintAnnotation.isPresent() ? hintAnnotation.getString("value") : null;
if (StringUtils.hasText(hint)) {
options.add(CodeBlock.of(".hint($S)", hint));
}
MergedAnnotation<ReadPreference> readPreferenceAnnotation = context.getAnnotation(ReadPreference.class);
String readPreference = readPreferenceAnnotation.isPresent() ? readPreferenceAnnotation.getString("value") : null;
if (StringUtils.hasText(readPreference)) {
options.add(CodeBlock.of(".readPreference($T.valueOf($S))", com.mongodb.ReadPreference.class, readPreference));
}
AggregationCodeBlockBuilder usingAggregationVariableName(String aggregationVariableName) {
if (queryMethod.hasAnnotatedCollation()) {
options.add(CodeBlock.of(".collation($T.parse($S))", Collation.class, queryMethod.getAnnotatedCollation()));
}
if (!options.isEmpty()) {
Builder optionsBuilder = CodeBlock.builder();
optionsBuilder.add("$1T $2L = $1T.builder()\n", AggregationOptions.class,
context.localVariable("aggregationOptions"));
optionsBuilder.indent();
for (CodeBlock optionBlock : options) {
optionsBuilder.add(optionBlock);
optionsBuilder.add("\n");
}
optionsBuilder.add(".build();\n");
optionsBuilder.unindent();
builder.add(optionsBuilder.build());
builder.addStatement("$1L = $1L.withOptions($2L)", aggregationVariableName,
context.localVariable("aggregationOptions"));
}
return builder.build();
}
private CodeBlock aggregationStages(String stageListVariableName, Iterable<String> stages, int stageCount,
Map<String, CodeBlock> arguments) {
Builder builder = CodeBlock.builder();
builder.addStatement("$T<$T> $L = new $T($L)", List.class, Object.class, stageListVariableName, ArrayList.class,
stageCount);
int stageCounter = 0;
for (String stage : stages) {
String stageName = context.localVariable("stage_%s".formatted(stageCounter++));
builder.add(MongoCodeBlocks.renderExpressionToDocument(stage, stageName, arguments));
builder.addStatement("$L.add($L)", context.localVariable("stages"), stageName);
}
return builder.build();
}
private CodeBlock sortingStage(String sortProvider) {
Builder builder = CodeBlock.builder();
builder.beginControlFlow("if ($L.isSorted())", sortProvider);
builder.addStatement("$1T $2L = new $1T()", Document.class, context.localVariable("sortDocument"));
builder.beginControlFlow("for ($T $L : $L)", Order.class, context.localVariable("order"), sortProvider);
builder.addStatement("$1L.append($2L.getProperty(), $2L.isAscending() ? 1 : -1);",
context.localVariable("sortDocument"), context.localVariable("order"));
builder.endControlFlow();
builder.addStatement("stages.add(new $T($S, $L))", Document.class, "$sort",
context.localVariable("sortDocument"));
builder.endControlFlow();
return builder.build();
}
private CodeBlock pagingStage(String pageableProvider, boolean slice) {
Builder builder = CodeBlock.builder();
builder.add(sortingStage(pageableProvider + ".getSort()"));
builder.beginControlFlow("if ($L.isPaged())", pageableProvider);
builder.beginControlFlow("if ($L.getOffset() > 0)", pageableProvider);
builder.addStatement("$L.add($T.skip($L.getOffset()))", context.localVariable("stages"), Aggregation.class,
pageableProvider);
builder.endControlFlow();
if (slice) {
builder.addStatement("$L.add($T.limit($L.getPageSize() + 1))", context.localVariable("stages"),
Aggregation.class, pageableProvider);
} else {
builder.addStatement("$L.add($T.limit($L.getPageSize()))", context.localVariable("stages"), Aggregation.class,
pageableProvider);
}
builder.endControlFlow();
return builder.build();
}
private CodeBlock limitingStage(String limitProvider) {
Builder builder = CodeBlock.builder();
builder.beginControlFlow("if ($L.isLimited())", limitProvider);
builder.addStatement("$L.add($T.limit($L.max()))", context.localVariable("stages"), Aggregation.class,
limitProvider);
builder.endControlFlow();
return builder.build();
}
}
this.aggregationVariableName = aggregationVariableName;
return this;
}
AggregationCodeBlockBuilder pipelineOnly(boolean pipelineOnly) {
this.pipelineOnly = pipelineOnly;
return this;
}
CodeBlock build() {
Builder builder = CodeBlock.builder();
builder.add("\n");
String pipelineName = context.localVariable(aggregationVariableName + (pipelineOnly ? "" : "Pipeline"));
builder.add(pipeline(pipelineName));
if (!pipelineOnly) {
Class<?> domainType = context.getRepositoryInformation().getDomainType();
Snippet.declare(builder)
.variable(ResolvableType.forClassWithGenerics(TypedAggregation.class, domainType), aggregationVariableName)
.as("$T.newAggregation($T.class, $L.getOperations())", Aggregation.class, domainType, pipelineName);
builder.add(aggregationOptions(aggregationVariableName));
}
return builder.build();
}
private CodeBlock pipeline(String pipelineVariableName) {
String sortParameter = context.getSortParameterName();
String limitParameter = context.getLimitParameterName();
String pageableParameter = context.getPageableParameterName();
boolean mightBeSorted = StringUtils.hasText(sortParameter);
boolean mightBeLimited = StringUtils.hasText(limitParameter);
boolean mightBePaged = StringUtils.hasText(pageableParameter);
int stageCount = source.stages().size();
if (mightBeSorted) {
stageCount++;
}
if (mightBeLimited) {
stageCount++;
}
if (mightBePaged) {
stageCount += 3;
}
Builder builder = CodeBlock.builder();
builder.add(aggregationStages(context.localVariable("stages"), source.stages(), stageCount, arguments));
if (mightBeSorted) {
builder.add(sortingStage(sortParameter));
}
if (mightBeLimited) {
builder.add(limitingStage(limitParameter));
}
if (mightBePaged) {
builder.add(pagingStage(pageableParameter, queryMethod.isSliceQuery()));
}
builder.addStatement("$T $L = createPipeline($L)", AggregationPipeline.class, pipelineVariableName,
context.localVariable("stages"));
return builder.build();
}
private CodeBlock aggregationOptions(String aggregationVariableName) {
Builder builder = CodeBlock.builder();
List<CodeBlock> options = new ArrayList<>(5);
if (ReflectionUtils.isVoid(queryMethod.getReturnedObjectType())) {
options.add(CodeBlock.of(".skipOutput()"));
}
MergedAnnotation<Hint> hintAnnotation = context.getAnnotation(Hint.class);
String hint = hintAnnotation.isPresent() ? hintAnnotation.getString("value") : null;
if (StringUtils.hasText(hint)) {
options.add(CodeBlock.of(".hint($S)", hint));
}
MergedAnnotation<ReadPreference> readPreferenceAnnotation = context.getAnnotation(ReadPreference.class);
String readPreference = readPreferenceAnnotation.isPresent() ? readPreferenceAnnotation.getString("value") : null;
if (StringUtils.hasText(readPreference)) {
options.add(CodeBlock.of(".readPreference($T.valueOf($S))", com.mongodb.ReadPreference.class, readPreference));
}
if (queryMethod.hasAnnotatedCollation()) {
options.add(CodeBlock.of(".collation($T.parse($S))", Collation.class, queryMethod.getAnnotatedCollation()));
}
if (!options.isEmpty()) {
Builder optionsBuilder = CodeBlock.builder();
optionsBuilder.add("$1T $2L = $1T.builder()\n", AggregationOptions.class,
context.localVariable("aggregationOptions"));
optionsBuilder.indent();
for (CodeBlock optionBlock : options) {
optionsBuilder.add(optionBlock);
optionsBuilder.add("\n");
}
optionsBuilder.add(".build();\n");
optionsBuilder.unindent();
builder.add(optionsBuilder.build());
builder.addStatement("$1L = $1L.withOptions($2L)", aggregationVariableName,
context.localVariable("aggregationOptions"));
}
return builder.build();
}
private CodeBlock aggregationStages(String stageListVariableName, Iterable<String> stages, int stageCount,
Map<String, CodeBlock> arguments) {
Builder builder = CodeBlock.builder();
builder.addStatement("$T<$T> $L = new $T($L)", List.class, Object.class, stageListVariableName, ArrayList.class,
stageCount);
int stageCounter = 0;
for (String stage : stages) {
VariableSnippet stageSnippet = Snippet.declare(builder)
.variable(Document.class, context.localVariable("stage_%s".formatted(stageCounter++)))
.of(MongoCodeBlocks.asDocument(stage, arguments));
builder.addStatement("$L.add($L)", stageListVariableName, stageSnippet.getVariableName());
}
return builder.build();
}
private CodeBlock sortingStage(String sortProvider) {
Builder builder = CodeBlock.builder();
builder.beginControlFlow("if ($L.isSorted())", sortProvider);
builder.addStatement("$1T $2L = new $1T()", Document.class, context.localVariable("sortDocument"));
builder.beginControlFlow("for ($T $L : $L)", Order.class, context.localVariable("order"), sortProvider);
builder.addStatement("$1L.append($2L.getProperty(), $2L.isAscending() ? 1 : -1);",
context.localVariable("sortDocument"), context.localVariable("order"));
builder.endControlFlow();
builder.addStatement("stages.add(new $T($S, $L))", Document.class, "$sort",
context.localVariable("sortDocument"));
builder.endControlFlow();
return builder.build();
}
private CodeBlock pagingStage(String pageableProvider, boolean slice) {
Builder builder = CodeBlock.builder();
builder.add(sortingStage(pageableProvider + ".getSort()"));
builder.beginControlFlow("if ($L.isPaged())", pageableProvider);
builder.beginControlFlow("if ($L.getOffset() > 0)", pageableProvider);
builder.addStatement("$L.add($T.skip($L.getOffset()))", context.localVariable("stages"), Aggregation.class,
pageableProvider);
builder.endControlFlow();
if (slice) {
builder.addStatement("$L.add($T.limit($L.getPageSize() + 1))", context.localVariable("stages"),
Aggregation.class, pageableProvider);
} else {
builder.addStatement("$L.add($T.limit($L.getPageSize()))", context.localVariable("stages"), Aggregation.class,
pageableProvider);
}
builder.endControlFlow();
return builder.build();
}
private CodeBlock limitingStage(String limitProvider) {
Builder builder = CodeBlock.builder();
builder.beginControlFlow("if ($L.isLimited())", limitProvider);
builder.addStatement("$L.add($T.limit($L.max()))", context.localVariable("stages"), Aggregation.class,
limitProvider);
builder.endControlFlow();
return builder.build();
}
}
}

44
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/BuilderStyleSnippet.java

@ -0,0 +1,44 @@ @@ -0,0 +1,44 @@
/*
* Copyright 2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.data.mongodb.repository.aot;
import org.springframework.javapoet.CodeBlock;
import org.springframework.javapoet.CodeBlock.Builder;
/**
* @author Christoph Strobl
*/
public class BuilderStyleSnippet implements Snippet {
private final String targetVariableName;
private final String methodName;
private final Snippet argumentValue;
BuilderStyleSnippet(String targetVariableName, String methodName, Snippet argumentValue) {
this.targetVariableName = targetVariableName;
this.methodName = methodName;
this.argumentValue = argumentValue;
}
@Override
public CodeBlock code() {
Builder builder = CodeBlock.builder();
builder.add("$1L = $1L.$2L($3L);\n", targetVariableName, methodName, argumentValue.code());
return builder.build();
}
}

127
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/DeleteBlocks.java

@ -18,6 +18,7 @@ package org.springframework.data.mongodb.repository.aot; @@ -18,6 +18,7 @@ package org.springframework.data.mongodb.repository.aot;
import java.util.Optional;
import org.jspecify.annotations.NullUnmarked;
import org.springframework.core.ResolvableType;
import org.springframework.data.mongodb.core.ExecutableRemoveOperation.ExecutableRemove;
import org.springframework.data.mongodb.core.MongoOperations;
import org.springframework.data.mongodb.repository.query.MongoQueryExecution.DeleteExecution;
@ -35,66 +36,68 @@ import org.springframework.util.ObjectUtils; @@ -35,66 +36,68 @@ import org.springframework.util.ObjectUtils;
*/
class DeleteBlocks {
@NullUnmarked
static class DeleteExecutionCodeBlockBuilder {
private final AotQueryMethodGenerationContext context;
private final MongoQueryMethod queryMethod;
private String queryVariableName;
DeleteExecutionCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) {
this.context = context;
this.queryMethod = queryMethod;
}
DeleteExecutionCodeBlockBuilder referencing(String queryVariableName) {
this.queryVariableName = queryVariableName;
return this;
}
CodeBlock build() {
String mongoOpsRef = context.fieldNameOf(MongoOperations.class);
Builder builder = CodeBlock.builder();
Class<?> domainType = context.getRepositoryInformation().getDomainType();
boolean isProjecting = context.getActualReturnType() != null
&& !ObjectUtils.nullSafeEquals(TypeName.get(domainType), context.getActualReturnType());
Object actualReturnType = isProjecting ? context.getActualReturnType().getType() : domainType;
builder.add("\n");
builder.addStatement("$1T<$2T> $3L = $4L.remove($2T.class)", ExecutableRemove.class, domainType,
context.localVariable("remover"), mongoOpsRef);
DeleteExecution.Type type = DeleteExecution.Type.FIND_AND_REMOVE_ALL;
if (!queryMethod.isCollectionQuery()) {
if (!ClassUtils.isPrimitiveOrWrapper(context.getMethod().getReturnType())) {
type = DeleteExecution.Type.FIND_AND_REMOVE_ONE;
} else {
type = DeleteExecution.Type.ALL;
}
}
actualReturnType = ClassUtils.isPrimitiveOrWrapper(context.getMethod().getReturnType())
? TypeName.get(context.getMethod().getReturnType())
: queryMethod.isCollectionQuery() ? context.getReturnTypeName() : actualReturnType;
if (ClassUtils.isVoidType(context.getMethod().getReturnType())) {
builder.addStatement("new $T($L, $T.$L).execute($L)", DeleteExecution.class, context.localVariable("remover"),
DeleteExecution.Type.class, type.name(), queryVariableName);
} else if (context.getMethod().getReturnType() == Optional.class) {
builder.addStatement("return $T.ofNullable(($T) new $T($L, $T.$L).execute($L))", Optional.class,
actualReturnType, DeleteExecution.class, context.localVariable("remover"), DeleteExecution.Type.class,
type.name(), queryVariableName);
} else {
builder.addStatement("return ($T) new $T($L, $T.$L).execute($L)", actualReturnType, DeleteExecution.class,
context.localVariable("remover"), DeleteExecution.Type.class, type.name(), queryVariableName);
}
return builder.build();
}
}
@NullUnmarked
static class DeleteExecutionCodeBlockBuilder {
private final AotQueryMethodGenerationContext context;
private final MongoQueryMethod queryMethod;
private String queryVariableName;
DeleteExecutionCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) {
this.context = context;
this.queryMethod = queryMethod;
}
DeleteExecutionCodeBlockBuilder referencing(String queryVariableName) {
this.queryVariableName = queryVariableName;
return this;
}
CodeBlock build() {
String mongoOpsRef = context.fieldNameOf(MongoOperations.class);
Builder builder = CodeBlock.builder();
Class<?> domainType = context.getRepositoryInformation().getDomainType();
boolean isProjecting = context.getActualReturnType() != null
&& !ObjectUtils.nullSafeEquals(TypeName.get(domainType), context.getActualReturnType());
Object actualReturnType = isProjecting ? context.getActualReturnType().getType() : domainType;
builder.add("\n");
VariableSnippet remover = Snippet.declare(builder)
.variable(ResolvableType.forClassWithGenerics(ExecutableRemove.class, domainType),
context.localVariable("remover"))
.as("$L.remove($T.class)", mongoOpsRef, domainType);
DeleteExecution.Type type = DeleteExecution.Type.FIND_AND_REMOVE_ALL;
if (!queryMethod.isCollectionQuery()) {
if (!ClassUtils.isPrimitiveOrWrapper(context.getMethod().getReturnType())) {
type = DeleteExecution.Type.FIND_AND_REMOVE_ONE;
} else {
type = DeleteExecution.Type.ALL;
}
}
actualReturnType = ClassUtils.isPrimitiveOrWrapper(context.getMethod().getReturnType())
? TypeName.get(context.getMethod().getReturnType())
: queryMethod.isCollectionQuery() ? context.getReturnTypeName() : actualReturnType;
if (ClassUtils.isVoidType(context.getMethod().getReturnType())) {
builder.addStatement("new $T($L, $T.$L).execute($L)", DeleteExecution.class, remover.getVariableName(),
DeleteExecution.Type.class, type.name(), queryVariableName);
} else if (context.getMethod().getReturnType() == Optional.class) {
builder.addStatement("return $T.ofNullable(($T) new $T($L, $T.$L).execute($L))", Optional.class,
actualReturnType, DeleteExecution.class, remover.getVariableName(), DeleteExecution.Type.class, type.name(),
queryVariableName);
} else {
builder.addStatement("return ($T) new $T($L, $T.$L).execute($L)", actualReturnType, DeleteExecution.class,
context.localVariable("remover"), DeleteExecution.Type.class, type.name(), queryVariableName);
}
return builder.build();
}
}
}

53
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/ExpressionSnippet.java

@ -0,0 +1,53 @@ @@ -0,0 +1,53 @@
/*
* Copyright 2025-present the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.data.mongodb.repository.aot;
import org.springframework.javapoet.CodeBlock;
/**
* @author Christoph Strobl
* @since 5.0
*/
class ExpressionSnippet implements Snippet {
private final CodeBlock block;
private final boolean requiresEvaluation;
public ExpressionSnippet(CodeBlock block) {
this(block, false);
}
public ExpressionSnippet(Snippet block) {
this(block.code(), block instanceof ExpressionSnippet eb && eb.requiresEvaluation());
}
public ExpressionSnippet(CodeBlock block, boolean requiresEvaluation) {
this.block = block;
this.requiresEvaluation = requiresEvaluation;
}
public static ExpressionSnippet empty() {
return new ExpressionSnippet(CodeBlock.builder().build());
}
public boolean requiresEvaluation() {
return requiresEvaluation;
}
public CodeBlock code() {
return block;
}
}

63
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/GeoBlocks.java

@ -50,38 +50,37 @@ class GeoBlocks { @@ -50,38 +50,37 @@ class GeoBlocks {
CodeBlock.Builder builder = CodeBlock.builder();
builder.add("\n");
String locationParameterName = context.getParameterName(queryMethod.getParameters().getNearIndex());
builder.addStatement("$1T $2L = $1T.near($3L)", NearQuery.class, variableName, locationParameterName);
VariableSnippet query = Snippet.declare(builder).variable(NearQuery.class, variableName).as("$T.near($L)",
NearQuery.class, context.getParameterName(queryMethod.getParameters().getNearIndex()));
if (queryMethod.getParameters().getRangeIndex() != -1) {
String rangeParametername = context.getParameterName(queryMethod.getParameters().getRangeIndex());
String minVarName = context.localVariable("min");
String maxVarName = context.localVariable("max");
String rangeParameter = context.getParameterName(queryMethod.getParameters().getRangeIndex());
builder.beginControlFlow("if($L.getLowerBound().isBounded())", rangeParametername);
builder.addStatement("$1T $2L = $3L.getLowerBound().getValue().get()", Distance.class, minVarName,
rangeParametername);
builder.addStatement("$1L.minDistance($2L).in($2L.getMetric())", variableName, minVarName);
builder.beginControlFlow("if($L.getLowerBound().isBounded())", rangeParameter);
VariableSnippet min = Snippet.declare(builder).variable(Distance.class, context.localVariable("min"))
.as("$L.getLowerBound().getValue().get()", rangeParameter);
builder.addStatement("$1L.minDistance($2L).in($2L.getMetric())", query.getVariableName(),
min.getVariableName());
builder.endControlFlow();
builder.beginControlFlow("if($L.getUpperBound().isBounded())", rangeParametername);
builder.addStatement("$1T $2L = $3L.getUpperBound().getValue().get()", Distance.class, maxVarName,
rangeParametername);
builder.addStatement("$1L.maxDistance($2L).in($2L.getMetric())", variableName, maxVarName);
builder.beginControlFlow("if($L.getUpperBound().isBounded())", rangeParameter);
VariableSnippet max = Snippet.declare(builder).variable(Distance.class, context.localVariable("max"))
.as("$L.getUpperBound().getValue().get()", rangeParameter);
builder.addStatement("$1L.maxDistance($2L).in($2L.getMetric())", query.getVariableName(),
max.getVariableName());
builder.endControlFlow();
} else {
String distanceParametername = context.getParameterName(queryMethod.getParameters().getMaxDistanceIndex());
builder.addStatement("$1L.maxDistance($2L).in($2L.getMetric())", variableName, distanceParametername);
String distanceParameter = context.getParameterName(queryMethod.getParameters().getMaxDistanceIndex());
builder.addStatement("$1L.maxDistance($2L).in($2L.getMetric())", query.code(), distanceParameter);
}
if (context.getPageableParameterName() != null) {
builder.addStatement("$L.with($L)", variableName, context.getPageableParameterName());
builder.addStatement("$L.with($L)", query.code(), context.getPageableParameterName());
}
MongoCodeBlocks.appendReadPreference(context, builder, variableName);
MongoCodeBlocks.appendReadPreference(context, builder, query.getVariableName());
return builder.build();
}
@ -115,29 +114,29 @@ class GeoBlocks { @@ -115,29 +114,29 @@ class GeoBlocks {
CodeBlock.Builder builder = CodeBlock.builder();
builder.add("\n");
String executorVar = context.localVariable("nearFinder");
builder.addStatement("var $L = $L.query($T.class).near($L)", executorVar,
context.fieldNameOf(MongoOperations.class), context.getRepositoryInformation().getDomainType(),
queryVariableName);
VariableSnippet queryExecutor = Snippet.declare(builder).variable(context.localVariable("nearFinder")).as(
"$L.query($T.class).near($L)", context.fieldNameOf(MongoOperations.class),
context.getRepositoryInformation().getDomainType(), queryVariableName);
if (ClassUtils.isAssignable(GeoPage.class, context.getReturnType().getRawClass())) {
String geoResultVar = context.localVariable("geoResult");
builder.addStatement("var $L = $L.all()", geoResultVar, executorVar);
VariableSnippet geoResult = Snippet.declare(builder).variable(context.localVariable("geoResult")).as("$L.all()",
queryExecutor.getVariableName());
builder.beginControlFlow("if($L.isUnpaged())", context.getPageableParameterName());
builder.addStatement("return new $T<>($L)", GeoPage.class, geoResultVar);
builder.addStatement("return new $T<>($L)", GeoPage.class, geoResult.getVariableName());
builder.endControlFlow();
String pageVar = context.localVariable("resultPage");
builder.addStatement("var $L = $T.getPage($L.getContent(), $L, () -> $L.count())", pageVar,
PageableExecutionUtils.class, geoResultVar, context.getPageableParameterName(), executorVar);
builder.addStatement("return new $T<>($L, $L, $L.getTotalElements())", GeoPage.class, geoResultVar,
context.getPageableParameterName(), pageVar);
VariableSnippet resultPage = Snippet.declare(builder).variable(context.localVariable("resultPage")).as(
"$T.getPage($L.getContent(), $L, () -> $L.count())", PageableExecutionUtils.class,
geoResult.getVariableName(), context.getPageableParameterName(), queryExecutor.getVariableName());
builder.addStatement("return new $T<>($L, $L, $L.getTotalElements())", GeoPage.class,
geoResult.getVariableName(), context.getPageableParameterName(), resultPage.getVariableName());
} else if (ClassUtils.isAssignable(GeoResults.class, context.getReturnType().getRawClass())) {
builder.addStatement("return $L.all()", executorVar);
builder.addStatement("return $L.all()", queryExecutor.getVariableName());
} else {
builder.addStatement("return $L.all().getContent()", executorVar);
builder.addStatement("return $L.all().getContent()", queryExecutor.getVariableName());
}
return builder.build();
}

19
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java

@ -170,6 +170,25 @@ class MongoCodeBlocks { @@ -170,6 +170,25 @@ class MongoCodeBlocks {
return new GeoNearExecutionCodeBlockBuilder(context, queryMethod);
}
static CodeBlock asDocument(String source, Map<String, CodeBlock> arguments) {
Builder builder = CodeBlock.builder();
if (!StringUtils.hasText(source)) {
builder.add("new $T()", Document.class);
} else if (!containsPlaceholder(source)) {
builder.add("$T.parse($S)", Document.class, source);
} else {
builder.add("bindParameters($S, ", source);
if (containsNamedPlaceholder(source)) {
builder.add(renderArgumentMap(arguments));
} else {
builder.add(renderArgumentArray(arguments));
}
builder.add(");\n");
}
return builder.build();
}
static CodeBlock renderExpressionToDocument(@Nullable String source, String variableName,
Map<String, CodeBlock> arguments) {

10
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java

@ -111,7 +111,10 @@ public class MongoRepositoryContributor extends RepositoryContributor { @@ -111,7 +111,10 @@ public class MongoRepositoryContributor extends RepositoryContributor {
AnnotatedElementUtils.findMergedAnnotation(method, Query.class), method);
if (queryMethod.isSearchQuery() || method.isAnnotationPresent(VectorSearch.class)) {
return searchMethodContributor(queryMethod, new SearchInteraction(query.getQuery()));
VectorSearch vectorSearch = AnnotatedElementUtils.findMergedAnnotation(method, VectorSearch.class);
return searchMethodContributor(queryMethod, new SearchInteraction(getRepositoryInformation().getDomainType(),
vectorSearch, query.getQuery(), queryMethod.getParameters()));
}
if (queryMethod.isGeoNearQuery() || (queryMethod.getParameters().getMaxDistanceIndex() != -1
@ -233,8 +236,9 @@ public class MongoRepositoryContributor extends RepositoryContributor { @@ -233,8 +236,9 @@ public class MongoRepositoryContributor extends RepositoryContributor {
String variableName = "search";
builder.add(new VectorSearchBocks.VectorSearchQueryCodeBlockBuilder(context, queryMethod)
.usingVariableName(variableName).withFilter(interaction.getFilter()).build());
builder.add(
new VectorSearchBocks.VectorSearchQueryCodeBlockBuilder(context, queryMethod, interaction.getSearchPath())
.usingVariableName(variableName).withFilter(interaction.getFilter()).build());
return builder.build();
});

47
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/QueryBlocks.java

@ -22,7 +22,6 @@ import java.util.Optional; @@ -22,7 +22,6 @@ import java.util.Optional;
import org.bson.Document;
import org.jspecify.annotations.NullUnmarked;
import org.jspecify.annotations.Nullable;
import org.springframework.core.annotation.MergedAnnotation;
import org.springframework.data.geo.Box;
import org.springframework.data.geo.Circle;
@ -215,9 +214,9 @@ class QueryBlocks { @@ -215,9 +214,9 @@ class QueryBlocks {
if (StringUtils.hasText(source.getQuery().getFieldsString())) {
builder
.add(MongoCodeBlocks.renderExpressionToDocument(source.getQuery().getFieldsString(), "fields", arguments));
builder.addStatement("$L.setFieldsObject(fields)", queryVariableName);
VariableSnippet fields = Snippet.declare(builder).variable(Document.class, context.localVariable("fields"))
.of(MongoCodeBlocks.asDocument(source.getQuery().getFieldsString(), arguments));
builder.addStatement("$L.setFieldsObject($L)", queryVariableName, fields.getVariableName());
}
String sortParameter = context.getSortParameterName();
@ -225,8 +224,9 @@ class QueryBlocks { @@ -225,8 +224,9 @@ class QueryBlocks {
builder.addStatement("$L.with($L)", queryVariableName, sortParameter);
} else if (StringUtils.hasText(source.getQuery().getSortString())) {
builder.add(MongoCodeBlocks.renderExpressionToDocument(source.getQuery().getSortString(), "sort", arguments));
builder.addStatement("$L.setSortObject(sort)", queryVariableName);
VariableSnippet sort = Snippet.declare(builder).variable(Document.class, context.localVariable("sort"))
.of(MongoCodeBlocks.asDocument(source.getQuery().getSortString(), arguments));
builder.addStatement("$L.setSortObject($L)", queryVariableName, sort.getVariableName());
}
String limitParameter = context.getLimitParameterName();
@ -273,10 +273,10 @@ class QueryBlocks { @@ -273,10 +273,10 @@ class QueryBlocks {
if (collationAnnotation.isPresent()) {
String collationString = collationAnnotation.getString("value");
if(StringUtils.hasText(collationString)) {
if (StringUtils.hasText(collationString)) {
if (!MongoCodeBlocks.containsPlaceholder(collationString)) {
builder.addStatement("$L.collation($T.parse($S))", queryVariableName,
org.springframework.data.mongodb.core.query.Collation.class, collationString);
org.springframework.data.mongodb.core.query.Collation.class, collationString);
} else {
builder.add("$L.collation(collationOf(evaluate($S, ", queryVariableName, collationString);
builder.add(MongoCodeBlocks.renderArgumentMap(arguments));
@ -292,29 +292,28 @@ class QueryBlocks { @@ -292,29 +292,28 @@ class QueryBlocks {
Builder builder = CodeBlock.builder();
builder.add("\n");
builder.add(renderExpressionToQuery(source.getQuery().getQueryString(), queryVariableName));
Snippet.declare(builder).variable(BasicQuery.class, this.queryVariableName).of(renderExpressionToQuery());
return builder.build();
}
private CodeBlock renderExpressionToQuery(@Nullable String source, String variableName) {
private CodeBlock renderExpressionToQuery() {
Builder builder = CodeBlock.builder();
String source = this.source.getQuery().getQueryString();
if (!StringUtils.hasText(source)) {
builder.addStatement("$1T $2L = new $1T(new $3T())", BasicQuery.class, variableName, Document.class);
} else if (!MongoCodeBlocks.containsPlaceholder(source)) {
builder.addStatement("$1T $2L = new $1T($3T.parse($4S))", BasicQuery.class, variableName, Document.class,
source);
return CodeBlock.of("new $T(new $T())", BasicQuery.class, Document.class);
}
if (!MongoCodeBlocks.containsPlaceholder(source)) {
return CodeBlock.of("new $T($T.parse($S))", BasicQuery.class, Document.class, source);
}
Builder builder = CodeBlock.builder();
builder.add("createQuery($S, ", source);
if (MongoCodeBlocks.containsNamedPlaceholder(source)) {
builder.add(MongoCodeBlocks.renderArgumentMap(arguments));
} else {
builder.add("$T $L = createQuery($S, ", BasicQuery.class, variableName, source);
if (MongoCodeBlocks.containsNamedPlaceholder(source)) {
builder.add(MongoCodeBlocks.renderArgumentMap(arguments));
} else {
builder.add(MongoCodeBlocks.renderArgumentArray(arguments));
}
builder.add(");\n");
builder.add(MongoCodeBlocks.renderArgumentArray(arguments));
}
builder.add(")");
return builder.build();
}
}

91
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/SearchInteraction.java

@ -15,26 +15,55 @@ @@ -15,26 +15,55 @@
*/
package org.springframework.data.mongodb.repository.aot;
import java.lang.reflect.Field;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.bson.Document;
import org.bson.json.JsonMode;
import org.bson.json.JsonWriterSettings;
import org.jspecify.annotations.Nullable;
import org.springframework.data.domain.Vector;
import org.springframework.data.mongodb.core.aggregation.VectorSearchOperation.SearchType;
import org.springframework.data.mongodb.repository.VectorSearch;
import org.springframework.data.mongodb.repository.query.MongoParameters;
import org.springframework.data.repository.aot.generate.QueryMetadata;
import org.springframework.util.StringUtils;
/**
* @author Christoph Strobl
*/
public class SearchInteraction extends MongoInteraction implements QueryMetadata {
StringQuery filter;
private final Class<?> domainType;
private final StringQuery filter;
private final @Nullable VectorSearch vectorSearch;
private final MongoParameters parameters;
public SearchInteraction(Class<?> domainType, @Nullable VectorSearch vectorSearch, StringQuery filter,
MongoParameters parameters) {
this.domainType = domainType;
this.vectorSearch = vectorSearch;
public SearchInteraction(StringQuery filter) {
this.filter = filter;
this.parameters = parameters;
}
public StringQuery getFilter() {
return filter;
}
@Nullable
String getIndexName() {
return vectorSearch != null ? vectorSearch.indexName() : null;
}
public MongoParameters getParameters() {
return parameters;
}
@Override
InteractionType getExecutionType() {
return InteractionType.AGGREGATION;
@ -43,6 +72,62 @@ public class SearchInteraction extends MongoInteraction implements QueryMetadata @@ -43,6 +72,62 @@ public class SearchInteraction extends MongoInteraction implements QueryMetadata
@Override
public Map<String, Object> serialize() {
return Map.of("FIXME", "please!");
Map<String, Object> serialized = new LinkedHashMap<>();
if (vectorSearch != null && StringUtils.hasText(vectorSearch.indexName())) {
serialized.put("index", vectorSearch.indexName());
}
serialized.put("path", getSearchPath());
if (vectorSearch.searchType().equals(SearchType.ENN)) {
serialized.put("exact", true);
}
if (StringUtils.hasText(filter.getQueryString())) {
serialized.put("filter", filter.getQueryString());
}
String limit = limitParameter();
if (StringUtils.hasText(limit)) {
serialized.put("limit", limit);
}
if (StringUtils.hasText(vectorSearch.numCandidates())) {
serialized.put("numCandidates", vectorSearch.numCandidates());
} else if (StringUtils.hasText(limit)) {
serialized.put("numCandidates", limit + " * 20");
}
serialized.put("queryVector", "?" + parameters.getVectorIndex());
return Map.of("pipeline", List.of(new Document("$vectorSearch", serialized)
.toJson(JsonWriterSettings.builder().outputMode(JsonMode.RELAXED).build()).replaceAll("\\\"", "'")));
}
private @Nullable String limitParameter() {
if (parameters.hasLimitParameter()) {
return "?" + parameters.getLimitIndex();
} else if (StringUtils.hasText(vectorSearch.limit())) {
return vectorSearch.limit();
}
return null;
}
public String getSearchPath() {
if (vectorSearch != null && StringUtils.hasText(vectorSearch.path())) {
return vectorSearch.path();
}
Field[] declaredFields = domainType.getDeclaredFields();
for (Field field : declaredFields) {
if (Vector.class.isAssignableFrom(field.getType())) {
return field.getName();
}
}
throw new IllegalArgumentException("No vector search path found for type %s".formatted(domainType));
}
}

224
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/Snippet.java

@ -0,0 +1,224 @@ @@ -0,0 +1,224 @@
/*
* Copyright 2025-present the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.data.mongodb.repository.aot;
import java.util.function.Function;
import org.jspecify.annotations.Nullable;
import org.springframework.core.ResolvableType;
import org.springframework.data.mongodb.repository.aot.Snippet.BuilderStyleBuilder.BuilderStyleMethodArgumentBuilder;
import org.springframework.data.mongodb.repository.aot.Snippet.BuilderStyleVariableBuilder.BuilderStyleVariableBuilderImpl;
import org.springframework.javapoet.CodeBlock;
import org.springframework.javapoet.CodeBlock.Builder;
/**
* @author Christoph Strobl
* @since 5.0
*/
interface Snippet {
CodeBlock code();
default boolean isEmpty() {
return code().isEmpty();
}
default void appendTo(CodeBlock.Builder builder) {
if (!isEmpty()) {
builder.add(code());
}
}
default <T> T as(Function<? super Snippet, T> transformer) {
return transformer.apply(this);
}
default Snippet wrap(String prefix, String suffix) {
return wrap("%s$L%s".formatted(prefix, suffix));
}
default Snippet wrap(CodeBlock prefix, CodeBlock suffix) {
return new Snippet() {
@Override
public CodeBlock code() {
return CodeBlock.builder().add(prefix).add(Snippet.this.code()).add(suffix).build();
}
};
}
default Snippet wrap(String statement) {
return new Snippet() {
@Override
public CodeBlock code() {
return CodeBlock.of(statement, Snippet.this.code());
}
};
}
static Snippet just(CodeBlock codeBlock) {
return new Snippet() {
@Override
public CodeBlock code() {
return codeBlock;
}
};
}
static ContextualSnippetBuilder declare(CodeBlock.Builder builder) {
return new ContextualSnippetBuilder() {
@Override
public VariableBuilder variable(String variableName) {
return VariableSnippet.variable(variableName).targeting(builder);
}
@Override
public VariableBuilder variable(Class<?> type, String variableName) {
return VariableSnippet.variable(type, variableName).targeting(builder);
}
@Override
public VariableBuilder variable(ResolvableType resolvableType, String variableName) {
return VariableSnippet.variable(resolvableType, variableName).targeting(builder);
}
@Override
public BuilderStyleVariableBuilder variableBuilder(String variableName) {
return new BuilderStyleVariableBuilderImpl(builder, null, variableName);
}
@Override
public BuilderStyleVariableBuilder variableBuilder(Class<?> type, String variableName) {
return variableBuilder(ResolvableType.forClass(type), variableName);
}
@Override
public BuilderStyleVariableBuilder variableBuilder(ResolvableType resolvableType, String variableName) {
return new BuilderStyleVariableBuilderImpl(builder, resolvableType, variableName);
}
};
}
interface ContextualSnippetBuilder {
VariableBuilder variable(String variableName);
VariableBuilder variable(Class<?> type, String variableName);
VariableBuilder variable(ResolvableType resolvableType, String variableName);
BuilderStyleVariableBuilder variableBuilder(String variableName);
BuilderStyleVariableBuilder variableBuilder(Class<?> type, String variableName);
BuilderStyleVariableBuilder variableBuilder(ResolvableType resolvableType, String variableName);
}
interface VariableBuilder {
default VariableSnippet as(String declaration, Object... args) {
return of(CodeBlock.of(declaration, args));
}
VariableSnippet of(CodeBlock codeBlock);
}
interface BuilderStyleVariableBuilder {
default BuilderStyleBuilder as(String declaration, Object... args) {
return of(CodeBlock.of(declaration, args));
}
BuilderStyleBuilder of(CodeBlock codeBlock);
class BuilderStyleVariableBuilderImpl
implements BuilderStyleVariableBuilder, BuilderStyleBuilder, BuilderStyleMethodArgumentBuilder {
Builder targetBuilder;
@Nullable ResolvableType type;
String targetVariableName;
@Nullable String targetMethodName;
@Nullable VariableSnippet variableSnippet;
public BuilderStyleVariableBuilderImpl(Builder targetBuilder, @Nullable ResolvableType type,
String targetVariableName) {
this.targetBuilder = targetBuilder;
this.type = type;
this.targetVariableName = targetVariableName;
}
@Override
public BuilderStyleBuilder as(String declaration, Object... args) {
if (type != null) {
this.variableSnippet = Snippet.declare(targetBuilder).variable(type, targetVariableName).as(declaration, args);
} else {
this.variableSnippet = Snippet.declare(targetBuilder).variable(targetVariableName).as(declaration, args);
}
return this;
}
@Override
public BuilderStyleBuilder of(CodeBlock codeBlock) {
if (type != null) {
this.variableSnippet = Snippet.declare(targetBuilder).variable(type, targetVariableName).of(codeBlock);
} else {
this.variableSnippet = Snippet.declare(targetBuilder).variable(targetVariableName).of(codeBlock);
}
return this;
}
@Override
public BuilderStyleMethodArgumentBuilder call(String methodName) {
this.targetMethodName = methodName;
return this;
}
@Override
public BuilderStyleBuilder with(Snippet snippet) {
new BuilderStyleSnippet(targetVariableName, targetMethodName, snippet).appendTo(targetBuilder);
return this;
}
@Override
public VariableSnippet variable() {
return this.variableSnippet;
}
}
}
interface BuilderStyleBuilder {
BuilderStyleMethodArgumentBuilder call(String methodName);
VariableSnippet variable();
interface BuilderStyleMethodArgumentBuilder {
default BuilderStyleBuilder with(String statement, Object... args) {
return with(CodeBlock.of(statement, args));
}
default BuilderStyleBuilder with(CodeBlock codeBlock) {
return with(Snippet.just(codeBlock));
}
BuilderStyleBuilder with(Snippet snippet);
}
}
}

42
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/UpdateBlocks.java

@ -1,19 +1,3 @@ @@ -1,19 +1,3 @@
/*
* Copyright 2025. the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* Copyright 2025 the original author or authors.
*
@ -21,7 +5,7 @@ @@ -21,7 +5,7 @@
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
@ -35,6 +19,7 @@ import java.util.LinkedHashMap; @@ -35,6 +19,7 @@ import java.util.LinkedHashMap;
import java.util.Map;
import org.jspecify.annotations.NullUnmarked;
import org.springframework.core.ResolvableType;
import org.springframework.data.mongodb.core.ExecutableUpdateOperation.ExecutableUpdate;
import org.springframework.data.mongodb.core.MongoOperations;
import org.springframework.data.mongodb.core.query.BasicUpdate;
@ -48,7 +33,6 @@ import org.springframework.util.NumberUtils; @@ -48,7 +33,6 @@ import org.springframework.util.NumberUtils;
/**
* @author Christoph Strobl
* @since 2025/06
*/
class UpdateBlocks {
@ -87,22 +71,26 @@ class UpdateBlocks { @@ -87,22 +71,26 @@ class UpdateBlocks {
String updateReference = updateVariableName;
Class<?> domainType = context.getRepositoryInformation().getDomainType();
builder.addStatement("$1T<$2T> $3L = $4L.update($2T.class)", ExecutableUpdate.class, domainType,
context.localVariable("updater"), mongoOpsRef);
VariableSnippet updater = Snippet.declare(builder)
.variable(ResolvableType.forClassWithGenerics(ExecutableUpdate.class, domainType),
context.localVariable("updater"))
.as("$L.update($T.class)", mongoOpsRef, domainType);
Class<?> returnType = ClassUtils.resolvePrimitiveIfNecessary(queryMethod.getReturnedObjectType());
if (ReflectionUtils.isVoid(returnType)) {
builder.addStatement("$L.matching($L).apply($L).all()", context.localVariable("updater"), queryVariableName,
builder.addStatement("$L.matching($L).apply($L).all()", updater.getVariableName(), queryVariableName,
updateReference);
} else if (ClassUtils.isAssignable(Long.class, returnType)) {
builder.addStatement("return $L.matching($L).apply($L).all().getModifiedCount()",
context.localVariable("updater"), queryVariableName, updateReference);
builder.addStatement("return $L.matching($L).apply($L).all().getModifiedCount()", updater.getVariableName(),
queryVariableName, updateReference);
} else {
builder.addStatement("$T $L = $L.matching($L).apply($L).all().getModifiedCount()", Long.class,
context.localVariable("modifiedCount"), context.localVariable("updater"), queryVariableName,
updateReference);
VariableSnippet modifiedCount = Snippet.declare(builder)
.variable(Long.class, context.localVariable("modifiedCount"))
.as("$L.matching($L).apply($L).all().getModifiedCount()", updater.getVariableName(), queryVariableName,
updateReference);
builder.addStatement("return $T.convertNumberToTargetClass($L, $T.class)", NumberUtils.class,
context.localVariable("modifiedCount"), returnType);
modifiedCount.getVariableName(), returnType);
}
return builder.build();

119
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/VariableSnippet.java

@ -0,0 +1,119 @@ @@ -0,0 +1,119 @@
/*
* Copyright 2025-present the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.data.mongodb.repository.aot;
import org.jspecify.annotations.Nullable;
import org.springframework.core.ResolvableType;
import org.springframework.javapoet.CodeBlock;
import org.springframework.javapoet.CodeBlock.Builder;
import org.springframework.javapoet.TypeName;
/**
* @author Christoph Strobl
* @since 5.0
*/
class VariableSnippet extends ExpressionSnippet {
private final String variableName;
private final @Nullable TypeName typeName;
public VariableSnippet(String variableName, Snippet delegate) {
this((TypeName) null, variableName, delegate);
}
public VariableSnippet(Class<?> typeName, String variableName, Snippet delegate) {
this(TypeName.get(typeName), variableName, delegate);
}
public VariableSnippet(@Nullable TypeName typeName, String variableName, Snippet delegate) {
super(delegate);
this.typeName = typeName;
this.variableName = variableName;
}
static VariableBuilderImp variable(String name) {
return new VariableBuilderImp(null, name);
}
static VariableBuilderImp variable(Class<?> typeName, String name) {
return variable(TypeName.get(typeName), name);
}
static VariableBuilderImp variable(ResolvableType resolvableType, String name) {
return variable(TypeName.get(resolvableType.getType()), name);
}
static VariableBuilderImp variable(TypeName typeName, String name) {
return new VariableBuilderImp(typeName, name);
}
static class VariableBuilderImp implements VariableBuilder {
private @Nullable TypeName typeName;
private String variableName;
private CodeBlock.@Nullable Builder target;
VariableBuilderImp(@Nullable TypeName typeName, String variableName) {
this.typeName = typeName;
this.variableName = variableName;
}
@Override
public VariableSnippet of(CodeBlock codeBlock) {
VariableSnippet variableSnippet = new VariableSnippet(typeName, variableName, Snippet.just(codeBlock));
if (target != null) {
variableSnippet.renderDeclaration(target);
}
return variableSnippet;
}
VariableBuilderImp targeting(@Nullable Builder target) {
this.target = target;
return this;
}
}
@Override
public CodeBlock code() {
return CodeBlock.of("$L", variableName);
}
public String getVariableName() {
return variableName;
}
void renderDeclaration(CodeBlock.Builder builder) {
if (typeName != null) {
builder.addStatement("$T $L = $L", typeName, variableName, super.code());
} else {
builder.addStatement("var $L = $L", variableName, super.code());
}
}
static VariableBlockBuilder create(Snippet snippet) {
return variableName -> create(variableName, snippet);
}
static VariableSnippet create(String variableName, Snippet snippet) {
return new VariableSnippet(variableName, snippet);
}
interface VariableBlockBuilder {
VariableSnippet variableName(String variableName);
}
}

221
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/VectorSearchBocks.java

@ -15,7 +15,6 @@ @@ -15,7 +15,6 @@
*/
package org.springframework.data.mongodb.repository.aot;
import java.lang.reflect.Field;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
@ -25,15 +24,14 @@ import org.springframework.core.annotation.MergedAnnotation; @@ -25,15 +24,14 @@ import org.springframework.core.annotation.MergedAnnotation;
import org.springframework.data.domain.Limit;
import org.springframework.data.domain.ScoringFunction;
import org.springframework.data.domain.Sort;
import org.springframework.data.domain.Vector;
import org.springframework.data.mongodb.core.MongoOperations;
import org.springframework.data.mongodb.core.aggregation.Aggregation;
import org.springframework.data.mongodb.core.aggregation.AggregationOperation;
import org.springframework.data.mongodb.core.aggregation.AggregationOperationContext;
import org.springframework.data.mongodb.core.aggregation.AggregationPipeline;
import org.springframework.data.mongodb.core.aggregation.VectorSearchOperation;
import org.springframework.data.mongodb.core.aggregation.VectorSearchOperation.SearchType;
import org.springframework.data.mongodb.repository.VectorSearch;
import org.springframework.data.mongodb.repository.aot.Snippet.BuilderStyleBuilder;
import org.springframework.data.mongodb.repository.query.MongoQueryExecution.VectorSearchExecution;
import org.springframework.data.mongodb.repository.query.MongoQueryMethod;
import org.springframework.data.repository.aot.generate.AotQueryMethodGenerationContext;
@ -55,11 +53,14 @@ class VectorSearchBocks { @@ -55,11 +53,14 @@ class VectorSearchBocks {
private String searchQueryVariableName;
private StringQuery filter;
private final Map<String, CodeBlock> arguments;
private final String searchPath;
VectorSearchQueryCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) {
VectorSearchQueryCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod,
String searchPath) {
this.context = context;
this.queryMethod = queryMethod;
this.searchPath = searchPath;
this.arguments = new LinkedHashMap<>();
context.getBindableParameterNames().forEach(it -> arguments.put(it, CodeBlock.of(it)));
}
@ -77,113 +78,56 @@ class VectorSearchBocks { @@ -77,113 +78,56 @@ class VectorSearchBocks {
String vectorParameterName = context.getVectorParameterName();
MergedAnnotation<VectorSearch> annotation = context.getAnnotation(VectorSearch.class);
String searchPath = annotation.getString("path");
String indexName = annotation.getString("indexName");
String numCandidates = annotation.getString("numCandidates");
SearchType searchType = annotation.getEnum("searchType", SearchType.class);
String limit = annotation.getString("limit");
if (!StringUtils.hasText(searchPath)) { // FIXME: somehow duplicate logic of AnnotatedQueryFactory
ExpressionSnippet limit = getLimitExpression();
Field[] declaredFields = context.getRepositoryInformation().getDomainType().getDeclaredFields();
for (Field field : declaredFields) {
if (Vector.class.isAssignableFrom(field.getType())) {
searchPath = field.getName();
break;
}
}
if (limit.requiresEvaluation() && !StringUtils.hasText(annotation.getString("numCandidates"))
&& (searchType == VectorSearchOperation.SearchType.ANN
|| searchType == VectorSearchOperation.SearchType.DEFAULT)) {
VariableSnippet variableBlock = limit.as(VariableSnippet::create)
.variableName(context.localVariable("limitToUse"));
variableBlock.renderDeclaration(builder);
limit = variableBlock;
}
String vectorSearchVar = context.localVariable("$vectorSearch");
builder.add("$T $L = $T.vectorSearch($S).path($S).vector($L)", VectorSearchOperation.class, vectorSearchVar,
Aggregation.class, indexName, searchPath, vectorParameterName);
if (StringUtils.hasText(context.getLimitParameterName())) {
builder.add(".limit($L);\n", context.getLimitParameterName());
} else if (filter.isLimited()) {
builder.add(".limit($L);\n", filter.getLimit());
} else if (StringUtils.hasText(limit)) {
if (MongoCodeBlocks.containsPlaceholder(limit) || MongoCodeBlocks.containsExpression(limit)) {
builder.add(".limit(");
builder.add(MongoCodeBlocks.evaluateNumberPotentially(limit, Integer.class, arguments));
builder.add(");\n");
} else {
builder.add(".limit($L);\n", limit);
}
} else {
builder.add(".limit($T.unlimited());\n", Limit.class);
}
BuilderStyleBuilder vectorSearchOperationBuilder = Snippet.declare(builder)
.variableBuilder(VectorSearchOperation.class, context.localVariable("$vectorSearch"))
.as("$T.vectorSearch($S).path($S).vector($L).limit($L)", Aggregation.class, indexName, searchPath,
vectorParameterName, limit.code());
if (!searchType.equals(SearchType.DEFAULT)) {
builder.addStatement("$1L = $1L.searchType($2T.$3L)", vectorSearchVar, SearchType.class, searchType.name());
vectorSearchOperationBuilder.call("searchType").with("$T.$L", SearchType.class, searchType.name());
}
if (StringUtils.hasText(numCandidates)) {
builder.add("$1L = $1L.numCandidates(", vectorSearchVar);
builder.add(MongoCodeBlocks.evaluateNumberPotentially(numCandidates, Integer.class, arguments));
builder.add(");\n");
} else if (searchType == VectorSearchOperation.SearchType.ANN
|| searchType == VectorSearchOperation.SearchType.DEFAULT) {
builder.add(
"// MongoDB: We recommend that you specify a number at least 20 times higher than the number of documents to return\n");
if (StringUtils.hasText(context.getLimitParameterName())) {
builder.addStatement("$1L = $1L.numCandidates($2L.max() * 20)", vectorSearchVar,
context.getLimitParameterName());
} else if (StringUtils.hasText(limit)) {
if (MongoCodeBlocks.containsPlaceholder(limit) || MongoCodeBlocks.containsExpression(limit)) {
builder.add("$1L = $1L.numCandidates((", vectorSearchVar);
builder.add(MongoCodeBlocks.evaluateNumberPotentially(limit, Integer.class, arguments));
builder.add(") * 20);\n");
} else {
builder.addStatement("$1L = $1L.numCandidates($2L * 20)", vectorSearchVar, limit);
}
} else {
builder.addStatement("$1L = $1L.numCandidates($2L)", vectorSearchVar, filter.getLimit() * 20);
}
ExpressionSnippet numCandidates = getNumCandidatesExpression(searchType, limit);
if (!numCandidates.isEmpty()) {
vectorSearchOperationBuilder.call("numCandidates").with(numCandidates);
}
builder.addStatement("$1L = $1L.withSearchScore(\"__score__\")", vectorSearchVar);
if (StringUtils.hasText(context.getScoreParameterName())) {
vectorSearchOperationBuilder.call("withSearchScore").with("\"__score__\"");
String scoreCriteriaVar = context.localVariable("criteria");
builder.addStatement("$1L = $1L.withFilterBySore($2L -> { $2L.gt($3L.getValue()); })", vectorSearchVar,
scoreCriteriaVar, context.getScoreParameterName());
if (StringUtils.hasText(context.getScoreParameterName())) {
vectorSearchOperationBuilder.call("withFilterBySore").with("$1L -> { $1L.gt($2L.getValue()); }",
context.localVariable("criteria"), context.getScoreParameterName());
} else if (StringUtils.hasText(context.getScoreRangeParameterName())) {
builder.addStatement("$1L = $1L.withFilterBySore(scoreBetween($2L.getLowerBound(), $2L.getUpperBound()))",
vectorSearchVar, context.getScoreRangeParameterName());
vectorSearchOperationBuilder.call("withFilterBySore")
.with("scoreBetween($1L.getLowerBound(), $1L.getUpperBound())", context.getScoreRangeParameterName());
}
if (StringUtils.hasText(filter.getQueryString())) {
String filterVar = context.localVariable("filter");
builder.add(MongoCodeBlocks.queryBlockBuilder(context, queryMethod).usingQueryVariableName("filter")
.filter(new QueryInteraction(this.filter, false, false, false)).buildJustTheQuery());
builder.addStatement("$1L = $1L.filter($2L.getQueryObject())", vectorSearchVar, filterVar);
builder.add("\n");
}
VariableSnippet vectorSearchOperation = vectorSearchOperationBuilder.variable();
getFilter(vectorSearchOperation.getVariableName()).appendTo(builder);
VariableSnippet sortStage = getSort().as(VariableSnippet::create).variableName(context.localVariable("$sort"));
sortStage.renderDeclaration(builder);
String sortStageVar = context.localVariable("$sort");
if(filter.isSorted()) {
builder.add("$T $L = (_ctx) -> {\n", AggregationOperation.class, sortStageVar);
builder.indent();
builder.addStatement("$1T _mappedSort = _ctx.getMappedObject($1T.parse($2S), $3T.class)", Document.class, filter.getSortString(), context.getActualReturnType().getType());
builder.addStatement("return new $T($S, _mappedSort.append(\"__score__\", -1))", Document.class, "$sort");
builder.unindent();
builder.add("};");
} else {
builder.addStatement("var $L = $T.sort($T.Direction.DESC, $S)", sortStageVar, Aggregation.class, Sort.class, "__score__");
}
builder.add("\n");
builder.addStatement("$1T $2L = new $1T($3T.of($4L, $5L))", AggregationPipeline.class, searchQueryVariableName,
List.class, vectorSearchVar, sortStageVar);
VariableSnippet aggregationPipeline = Snippet.declare(builder)
.variable(AggregationPipeline.class, searchQueryVariableName).as("new $T($T.of($L, $L))",
AggregationPipeline.class, List.class, vectorSearchOperation.getVariableName(), sortStage.code());
String scoringFunctionVar = context.localVariable("scoringFunction");
builder.add("$1T $2L = ", ScoringFunction.class, scoringFunctionVar);
@ -199,13 +143,108 @@ class VectorSearchBocks { @@ -199,13 +143,108 @@ class VectorSearchBocks {
"return ($5T) new $1T($2L, $3T.class, $2L.getCollectionName($3T.class), $4T.of($5T.class), $6L, $7L).execute(null)",
VectorSearchExecution.class, context.fieldNameOf(MongoOperations.class),
context.getRepositoryInformation().getDomainType(), TypeInformation.class,
queryMethod.getReturnType().getType(), searchQueryVariableName, scoringFunctionVar);
queryMethod.getReturnType().getType(), aggregationPipeline.getVariableName(), scoringFunctionVar);
return builder.build();
}
private ExpressionSnippet getSort() {
if (!filter.isSorted()) {
return new ExpressionSnippet(
CodeBlock.of("$T.sort($T.Direction.DESC, $S)", Aggregation.class, Sort.class, "__score__"));
}
Builder builder = CodeBlock.builder();
builder.add("($T) (_ctx) -> {\n", AggregationOperation.class);
builder.indent();
builder.add("$1T _mappedSort = _ctx.getMappedObject($1T.parse($2S), $3T.class);\n", Document.class,
filter.getSortString(), context.getActualReturnType().getType());
builder.add("return new $T($S, _mappedSort.append(\"__score__\", -1));\n", Document.class, "$sort");
builder.unindent();
builder.add("};");
return new ExpressionSnippet(builder.build());
}
private Snippet getFilter(String vectorSearchVar) {
if (!StringUtils.hasText(filter.getQueryString())) {
return ExpressionSnippet.empty();
}
Builder builder = CodeBlock.builder();
String filterVar = context.localVariable("filter");
builder.add(MongoCodeBlocks.queryBlockBuilder(context, queryMethod).usingQueryVariableName("filter")
.filter(new QueryInteraction(this.filter, false, false, false)).buildJustTheQuery());
builder.addStatement("$1L = $1L.filter($2L.getQueryObject())", vectorSearchVar, filterVar);
builder.add("\n");
return new ExpressionSnippet(builder.build());
}
public VectorSearchQueryCodeBlockBuilder withFilter(StringQuery filter) {
this.filter = filter;
return this;
}
private ExpressionSnippet getNumCandidatesExpression(SearchType searchType, ExpressionSnippet limit) {
MergedAnnotation<VectorSearch> annotation = context.getAnnotation(VectorSearch.class);
String numCandidates = annotation.getString("numCandidates");
if (StringUtils.hasText(numCandidates)) {
if (MongoCodeBlocks.containsPlaceholder(numCandidates) || MongoCodeBlocks.containsExpression(numCandidates)) {
return new ExpressionSnippet(
MongoCodeBlocks.evaluateNumberPotentially(numCandidates, Integer.class, arguments), true);
} else {
return new ExpressionSnippet(CodeBlock.of("$L", numCandidates));
}
}
if (searchType == VectorSearchOperation.SearchType.ANN
|| searchType == VectorSearchOperation.SearchType.DEFAULT) {
Builder builder = CodeBlock.builder();
if (StringUtils.hasText(context.getLimitParameterName())) {
builder.add("$L.max() * 20", context.getLimitParameterName());
} else if (filter.isLimited()) {
builder.add("$L", filter.getLimit() * 20);
} else {
builder.add("$L * 20", limit.code());
}
return new ExpressionSnippet(builder.build());
}
return ExpressionSnippet.empty();
}
private ExpressionSnippet getLimitExpression() {
if (StringUtils.hasText(context.getLimitParameterName())) {
return new ExpressionSnippet(CodeBlock.of("$L", context.getLimitParameterName()));
}
if (filter.isLimited()) {
return new ExpressionSnippet(CodeBlock.of("$L", filter.getLimit()));
}
MergedAnnotation<VectorSearch> annotation = context.getAnnotation(VectorSearch.class);
String limit = annotation.getString("limit");
if (StringUtils.hasText(limit)) {
if (MongoCodeBlocks.containsPlaceholder(limit) || MongoCodeBlocks.containsExpression(limit)) {
return new ExpressionSnippet(MongoCodeBlocks.evaluateNumberPotentially(limit, Integer.class, arguments),
true);
} else {
return new ExpressionSnippet(CodeBlock.of("$L", limit));
}
}
return new ExpressionSnippet(CodeBlock.of("$T.unlimited()", Limit.class));
}
}
}

11
spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/QueryMethodContributionUnitTests.java

@ -332,11 +332,11 @@ public class QueryMethodContributionUnitTests { @@ -332,11 +332,11 @@ public class QueryMethodContributionUnitTests {
Range.class);
assertThat(methodSpec.toString()) //
.containsSubsequence(
"Aggregation.vectorSearch(\"embedding.vector_cos\").path(\"embedding\").vector(vector).limit(",
.containsSubsequence("var limitToUse = ",
"evaluate(\"#{5+5}\", argumentMap(\"lastname\", lastname, \"distance\", distance)")
.containsSubsequence("$vectorSearch.numCandidates(",
"evaluate(\"#{5+5}\", argumentMap(\"lastname\", lastname, \"distance\", distance))) * 20)");
.contains(
"Aggregation.vectorSearch(\"embedding.vector_cos\").path(\"embedding\").vector(vector).limit(limitToUse)")
.contains("$vectorSearch.numCandidates(limitToUse * 20)");
}
@Test
@ -358,7 +358,8 @@ public class QueryMethodContributionUnitTests { @@ -358,7 +358,8 @@ public class QueryMethodContributionUnitTests {
String.class, Vector.class, Range.class);
assertThat(methodSpec.toString()) //
.containsSubsequence("AggregationOperation $sort = (_ctx) -> {", //
.containsSubsequence("var $sort = ", //
"(_ctx) -> {", //
"_mappedSort = _ctx.getMappedObject(", //
"Document.parse(\"{'firstname':{'$numberInt':'1'}}\")", //
"Document(\"$sort\", _mappedSort.append(\"__score__\", -1))");

2
spring-data-mongodb/src/test/resources/logback.xml

@ -20,7 +20,9 @@ @@ -20,7 +20,9 @@
<logger name="org.springframework.data.mongodb.test.util" level="info"/>
<!-- AOT Code Generation -->
<!--
<logger name="org.springframework.data.repository.aot.generate.RepositoryContributor" level="trace" />
-->
<logger name="org.springframework.data.mongodb.core.MongoTemplate" level="debug"/>
<root level="error">

Loading…
Cancel
Save