diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/AggregationBlocks.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/AggregationBlocks.java index 37f24cd84..e40bd518c 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/AggregationBlocks.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/AggregationBlocks.java @@ -19,11 +19,11 @@ import java.util.ArrayList; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; -import java.util.stream.Collectors; import java.util.stream.Stream; import org.bson.Document; import org.jspecify.annotations.NullUnmarked; +import org.springframework.core.ResolvableType; import org.springframework.core.annotation.MergedAnnotation; import org.springframework.data.domain.SliceImpl; import org.springframework.data.domain.Sort.Order; @@ -52,312 +52,323 @@ import org.springframework.util.StringUtils; */ class AggregationBlocks { - @NullUnmarked - static class AggregationExecutionCodeBlockBuilder { - - private final AotQueryMethodGenerationContext context; - private final MongoQueryMethod queryMethod; - private String aggregationVariableName; + @NullUnmarked + static class AggregationExecutionCodeBlockBuilder { - AggregationExecutionCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) { + private final AotQueryMethodGenerationContext context; + private final MongoQueryMethod queryMethod; + private String aggregationVariableName; - this.context = context; - this.queryMethod = queryMethod; - } + AggregationExecutionCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) { - AggregationExecutionCodeBlockBuilder referencing(String aggregationVariableName) { + this.context = context; + this.queryMethod = queryMethod; + } - this.aggregationVariableName = aggregationVariableName; - return this; - } + AggregationExecutionCodeBlockBuilder referencing(String aggregationVariableName) { - CodeBlock build() { + this.aggregationVariableName = aggregationVariableName; + return this; + } - String mongoOpsRef = context.fieldNameOf(MongoOperations.class); - Builder builder = CodeBlock.builder(); + CodeBlock build() { - builder.add("\n"); + String mongoOpsRef = context.fieldNameOf(MongoOperations.class); + Builder builder = CodeBlock.builder(); - Class outputType = queryMethod.getReturnedObjectType(); - if (MongoSimpleTypes.HOLDER.isSimpleType(outputType)) { - outputType = Document.class; - } else if (ClassUtils.isAssignable(AggregationResults.class, outputType)) { - outputType = queryMethod.getReturnType().getComponentType().getType(); - } + builder.add("\n"); - if (ReflectionUtils.isVoid(queryMethod.getReturnedObjectType())) { - builder.addStatement("$L.aggregate($L, $T.class)", mongoOpsRef, aggregationVariableName, outputType); - return builder.build(); - } + Class outputType = queryMethod.getReturnedObjectType(); + if (MongoSimpleTypes.HOLDER.isSimpleType(outputType)) { + outputType = Document.class; + } else if (ClassUtils.isAssignable(AggregationResults.class, outputType)) { + outputType = queryMethod.getReturnType().getComponentType().getType(); + } - if (ClassUtils.isAssignable(AggregationResults.class, context.getMethod().getReturnType())) { - builder.addStatement("return $L.aggregate($L, $T.class)", mongoOpsRef, aggregationVariableName, outputType); - return builder.build(); - } + if (ReflectionUtils.isVoid(queryMethod.getReturnedObjectType())) { + builder.addStatement("$L.aggregate($L, $T.class)", mongoOpsRef, aggregationVariableName, outputType); + return builder.build(); + } - if (outputType == Document.class) { + if (ClassUtils.isAssignable(AggregationResults.class, context.getMethod().getReturnType())) { + builder.addStatement("return $L.aggregate($L, $T.class)", mongoOpsRef, aggregationVariableName, outputType); + return builder.build(); + } - Class returnType = ClassUtils.resolvePrimitiveIfNecessary(queryMethod.getReturnedObjectType()); + if (outputType == Document.class) { - if (queryMethod.isStreamQuery()) { + Class returnType = ClassUtils.resolvePrimitiveIfNecessary(queryMethod.getReturnedObjectType()); - builder.addStatement("$T<$T> $L = $L.aggregateStream($L, $T.class)", Stream.class, Document.class, - context.localVariable("results"), mongoOpsRef, aggregationVariableName, outputType); + if (queryMethod.isStreamQuery()) { - builder.addStatement("return $1L.map(it -> ($2T) convertSimpleRawResult($2T.class, it))", - context.localVariable("results"), returnType); - } else { + VariableSnippet results = Snippet.declare(builder) + .variable(ResolvableType.forClassWithGenerics(Stream.class, Document.class), + context.localVariable("results")) + .as("$L.aggregateStream($L, $T.class)", mongoOpsRef, aggregationVariableName, outputType); - builder.addStatement("$T $L = $L.aggregate($L, $T.class)", AggregationResults.class, - context.localVariable("results"), mongoOpsRef, aggregationVariableName, outputType); + builder.addStatement("return $1L.map(it -> ($2T) convertSimpleRawResult($2T.class, it))", + results.getVariableName(), returnType); + } else { - if (!queryMethod.isCollectionQuery()) { - builder.addStatement( - "return $1T.<$2T>firstElement(convertSimpleRawResults($2T.class, $3L.getMappedResults()))", - CollectionUtils.class, returnType, context.localVariable("results")); - } else { - builder.addStatement("return convertSimpleRawResults($T.class, $L.getMappedResults())", returnType, - context.localVariable("results")); - } - } - } else { - if (queryMethod.isSliceQuery()) { - builder.addStatement("$T $L = $L.aggregate($L, $T.class)", AggregationResults.class, - context.localVariable("results"), mongoOpsRef, aggregationVariableName, outputType); - builder.addStatement("boolean $L = $L.getMappedResults().size() > $L.getPageSize()", - context.localVariable("hasNext"), context.localVariable("results"), context.getPageableParameterName()); - builder.addStatement( - "return new $1T<>($2L ? $3L.getMappedResults().subList(0, $4L.getPageSize()) : $3L.getMappedResults(), $4L, $2L)", - SliceImpl.class, context.localVariable("hasNext"), context.localVariable("results"), - context.getPageableParameterName()); - } else { + VariableSnippet results = Snippet.declare(builder) + .variable(AggregationResults.class, context.localVariable("results")) + .as("$L.aggregate($L, $T.class)", mongoOpsRef, aggregationVariableName, outputType); - if (queryMethod.isStreamQuery()) { - builder.addStatement("return $L.aggregateStream($L, $T.class)", mongoOpsRef, aggregationVariableName, - outputType); - } else { + if (!queryMethod.isCollectionQuery()) { + builder.addStatement( + "return $1T.<$2T>firstElement(convertSimpleRawResults($2T.class, $3L.getMappedResults()))", + CollectionUtils.class, returnType, results.getVariableName()); + } else { + builder.addStatement("return convertSimpleRawResults($T.class, $L.getMappedResults())", returnType, + results.getVariableName()); + } + } + } else { + if (queryMethod.isSliceQuery()) { - builder.addStatement("return $L.aggregate($L, $T.class).getMappedResults()", mongoOpsRef, - aggregationVariableName, outputType); - } - } - } + VariableSnippet results = Snippet.declare(builder) + .variable(AggregationResults.class, context.localVariable("results")) + .as("$L.aggregate($L, $T.class)", mongoOpsRef, aggregationVariableName, outputType); - return builder.build(); - } - } + VariableSnippet hasNext = Snippet.declare(builder).variable("hasNext").as( + "$L.getMappedResults().size() > $L.getPageSize()", results.getVariableName(), + context.getPageableParameterName()); - @NullUnmarked - static class AggregationCodeBlockBuilder { + builder.addStatement( + "return new $1T<>($2L ? $3L.getMappedResults().subList(0, $4L.getPageSize()) : $3L.getMappedResults(), $4L, $2L)", + SliceImpl.class, hasNext.getVariableName(), results.getVariableName(), + context.getPageableParameterName()); + } else { - private final AotQueryMethodGenerationContext context; - private final MongoQueryMethod queryMethod; - private final Map arguments; + if (queryMethod.isStreamQuery()) { + builder.addStatement("return $L.aggregateStream($L, $T.class)", mongoOpsRef, aggregationVariableName, + outputType); + } else { - private AggregationInteraction source; + builder.addStatement("return $L.aggregate($L, $T.class).getMappedResults()", mongoOpsRef, + aggregationVariableName, outputType); + } + } + } - private String aggregationVariableName; - private boolean pipelineOnly; + return builder.build(); + } + } - AggregationCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) { + @NullUnmarked + static class AggregationCodeBlockBuilder { - this.context = context; - this.arguments = new LinkedHashMap<>(); - context.getBindableParameterNames().forEach(it -> arguments.put(it, CodeBlock.of(it))); - this.queryMethod = queryMethod; - } + private final AotQueryMethodGenerationContext context; + private final MongoQueryMethod queryMethod; + private final Map arguments; - AggregationCodeBlockBuilder stages(AggregationInteraction aggregation) { - - this.source = aggregation; - return this; - } + private AggregationInteraction source; - AggregationCodeBlockBuilder usingAggregationVariableName(String aggregationVariableName) { - - this.aggregationVariableName = aggregationVariableName; - return this; - } - - AggregationCodeBlockBuilder pipelineOnly(boolean pipelineOnly) { - - this.pipelineOnly = pipelineOnly; - return this; - } - - CodeBlock build() { - - Builder builder = CodeBlock.builder(); - builder.add("\n"); - - String pipelineName = context.localVariable(aggregationVariableName + (pipelineOnly ? "" : "Pipeline")); - builder.add(pipeline(pipelineName)); - - if (!pipelineOnly) { - - builder.addStatement("$1T<$2T> $3L = $4T.newAggregation($2T.class, $5L.getOperations())", - TypedAggregation.class, context.getRepositoryInformation().getDomainType(), aggregationVariableName, - Aggregation.class, pipelineName); - - builder.add(aggregationOptions(aggregationVariableName)); - } - - return builder.build(); - } - - private CodeBlock pipeline(String pipelineVariableName) { - - String sortParameter = context.getSortParameterName(); - String limitParameter = context.getLimitParameterName(); - String pageableParameter = context.getPageableParameterName(); - - boolean mightBeSorted = StringUtils.hasText(sortParameter); - boolean mightBeLimited = StringUtils.hasText(limitParameter); - boolean mightBePaged = StringUtils.hasText(pageableParameter); - - int stageCount = source.stages().size(); - if (mightBeSorted) { - stageCount++; - } - if (mightBeLimited) { - stageCount++; - } - if (mightBePaged) { - stageCount += 3; - } - - Builder builder = CodeBlock.builder(); - builder.add(aggregationStages(context.localVariable("stages"), source.stages(), stageCount, arguments)); + private String aggregationVariableName; + private boolean pipelineOnly; - if (mightBeSorted) { - builder.add(sortingStage(sortParameter)); - } + AggregationCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) { - if (mightBeLimited) { - builder.add(limitingStage(limitParameter)); - } + this.context = context; + this.arguments = new LinkedHashMap<>(); + context.getBindableParameterNames().forEach(it -> arguments.put(it, CodeBlock.of(it))); + this.queryMethod = queryMethod; + } - if (mightBePaged) { - builder.add(pagingStage(pageableParameter, queryMethod.isSliceQuery())); - } - - builder.addStatement("$T $L = createPipeline($L)", AggregationPipeline.class, pipelineVariableName, - context.localVariable("stages")); - return builder.build(); - } - - private CodeBlock aggregationOptions(String aggregationVariableName) { - - Builder builder = CodeBlock.builder(); - List options = new ArrayList<>(5); - if (ReflectionUtils.isVoid(queryMethod.getReturnedObjectType())) { - options.add(CodeBlock.of(".skipOutput()")); - } + AggregationCodeBlockBuilder stages(AggregationInteraction aggregation) { + + this.source = aggregation; + return this; + } - MergedAnnotation hintAnnotation = context.getAnnotation(Hint.class); - String hint = hintAnnotation.isPresent() ? hintAnnotation.getString("value") : null; - if (StringUtils.hasText(hint)) { - options.add(CodeBlock.of(".hint($S)", hint)); - } - - MergedAnnotation readPreferenceAnnotation = context.getAnnotation(ReadPreference.class); - String readPreference = readPreferenceAnnotation.isPresent() ? readPreferenceAnnotation.getString("value") : null; - if (StringUtils.hasText(readPreference)) { - options.add(CodeBlock.of(".readPreference($T.valueOf($S))", com.mongodb.ReadPreference.class, readPreference)); - } + AggregationCodeBlockBuilder usingAggregationVariableName(String aggregationVariableName) { - if (queryMethod.hasAnnotatedCollation()) { - options.add(CodeBlock.of(".collation($T.parse($S))", Collation.class, queryMethod.getAnnotatedCollation())); - } - - if (!options.isEmpty()) { - - Builder optionsBuilder = CodeBlock.builder(); - optionsBuilder.add("$1T $2L = $1T.builder()\n", AggregationOptions.class, - context.localVariable("aggregationOptions")); - optionsBuilder.indent(); - for (CodeBlock optionBlock : options) { - optionsBuilder.add(optionBlock); - optionsBuilder.add("\n"); - } - optionsBuilder.add(".build();\n"); - optionsBuilder.unindent(); - builder.add(optionsBuilder.build()); - - builder.addStatement("$1L = $1L.withOptions($2L)", aggregationVariableName, - context.localVariable("aggregationOptions")); - } - return builder.build(); - } - - private CodeBlock aggregationStages(String stageListVariableName, Iterable stages, int stageCount, - Map arguments) { - - Builder builder = CodeBlock.builder(); - builder.addStatement("$T<$T> $L = new $T($L)", List.class, Object.class, stageListVariableName, ArrayList.class, - stageCount); - int stageCounter = 0; - - for (String stage : stages) { - String stageName = context.localVariable("stage_%s".formatted(stageCounter++)); - builder.add(MongoCodeBlocks.renderExpressionToDocument(stage, stageName, arguments)); - builder.addStatement("$L.add($L)", context.localVariable("stages"), stageName); - } - - return builder.build(); - } - - private CodeBlock sortingStage(String sortProvider) { - - Builder builder = CodeBlock.builder(); - - builder.beginControlFlow("if ($L.isSorted())", sortProvider); - builder.addStatement("$1T $2L = new $1T()", Document.class, context.localVariable("sortDocument")); - builder.beginControlFlow("for ($T $L : $L)", Order.class, context.localVariable("order"), sortProvider); - builder.addStatement("$1L.append($2L.getProperty(), $2L.isAscending() ? 1 : -1);", - context.localVariable("sortDocument"), context.localVariable("order")); - builder.endControlFlow(); - builder.addStatement("stages.add(new $T($S, $L))", Document.class, "$sort", - context.localVariable("sortDocument")); - builder.endControlFlow(); - - return builder.build(); - } - - private CodeBlock pagingStage(String pageableProvider, boolean slice) { - - Builder builder = CodeBlock.builder(); - - builder.add(sortingStage(pageableProvider + ".getSort()")); - - builder.beginControlFlow("if ($L.isPaged())", pageableProvider); - builder.beginControlFlow("if ($L.getOffset() > 0)", pageableProvider); - builder.addStatement("$L.add($T.skip($L.getOffset()))", context.localVariable("stages"), Aggregation.class, - pageableProvider); - builder.endControlFlow(); - if (slice) { - builder.addStatement("$L.add($T.limit($L.getPageSize() + 1))", context.localVariable("stages"), - Aggregation.class, pageableProvider); - } else { - builder.addStatement("$L.add($T.limit($L.getPageSize()))", context.localVariable("stages"), Aggregation.class, - pageableProvider); - } - builder.endControlFlow(); - - return builder.build(); - } - - private CodeBlock limitingStage(String limitProvider) { - - Builder builder = CodeBlock.builder(); - - builder.beginControlFlow("if ($L.isLimited())", limitProvider); - builder.addStatement("$L.add($T.limit($L.max()))", context.localVariable("stages"), Aggregation.class, - limitProvider); - builder.endControlFlow(); - - return builder.build(); - } - - } + this.aggregationVariableName = aggregationVariableName; + return this; + } + + AggregationCodeBlockBuilder pipelineOnly(boolean pipelineOnly) { + + this.pipelineOnly = pipelineOnly; + return this; + } + + CodeBlock build() { + + Builder builder = CodeBlock.builder(); + builder.add("\n"); + + String pipelineName = context.localVariable(aggregationVariableName + (pipelineOnly ? "" : "Pipeline")); + builder.add(pipeline(pipelineName)); + + if (!pipelineOnly) { + + Class domainType = context.getRepositoryInformation().getDomainType(); + Snippet.declare(builder) + .variable(ResolvableType.forClassWithGenerics(TypedAggregation.class, domainType), aggregationVariableName) + .as("$T.newAggregation($T.class, $L.getOperations())", Aggregation.class, domainType, pipelineName); + + builder.add(aggregationOptions(aggregationVariableName)); + } + + return builder.build(); + } + + private CodeBlock pipeline(String pipelineVariableName) { + + String sortParameter = context.getSortParameterName(); + String limitParameter = context.getLimitParameterName(); + String pageableParameter = context.getPageableParameterName(); + + boolean mightBeSorted = StringUtils.hasText(sortParameter); + boolean mightBeLimited = StringUtils.hasText(limitParameter); + boolean mightBePaged = StringUtils.hasText(pageableParameter); + + int stageCount = source.stages().size(); + if (mightBeSorted) { + stageCount++; + } + if (mightBeLimited) { + stageCount++; + } + if (mightBePaged) { + stageCount += 3; + } + + Builder builder = CodeBlock.builder(); + builder.add(aggregationStages(context.localVariable("stages"), source.stages(), stageCount, arguments)); + + if (mightBeSorted) { + builder.add(sortingStage(sortParameter)); + } + + if (mightBeLimited) { + builder.add(limitingStage(limitParameter)); + } + + if (mightBePaged) { + builder.add(pagingStage(pageableParameter, queryMethod.isSliceQuery())); + } + + builder.addStatement("$T $L = createPipeline($L)", AggregationPipeline.class, pipelineVariableName, + context.localVariable("stages")); + return builder.build(); + } + + private CodeBlock aggregationOptions(String aggregationVariableName) { + + Builder builder = CodeBlock.builder(); + List options = new ArrayList<>(5); + if (ReflectionUtils.isVoid(queryMethod.getReturnedObjectType())) { + options.add(CodeBlock.of(".skipOutput()")); + } + + MergedAnnotation hintAnnotation = context.getAnnotation(Hint.class); + String hint = hintAnnotation.isPresent() ? hintAnnotation.getString("value") : null; + if (StringUtils.hasText(hint)) { + options.add(CodeBlock.of(".hint($S)", hint)); + } + + MergedAnnotation readPreferenceAnnotation = context.getAnnotation(ReadPreference.class); + String readPreference = readPreferenceAnnotation.isPresent() ? readPreferenceAnnotation.getString("value") : null; + if (StringUtils.hasText(readPreference)) { + options.add(CodeBlock.of(".readPreference($T.valueOf($S))", com.mongodb.ReadPreference.class, readPreference)); + } + + if (queryMethod.hasAnnotatedCollation()) { + options.add(CodeBlock.of(".collation($T.parse($S))", Collation.class, queryMethod.getAnnotatedCollation())); + } + + if (!options.isEmpty()) { + + Builder optionsBuilder = CodeBlock.builder(); + optionsBuilder.add("$1T $2L = $1T.builder()\n", AggregationOptions.class, + context.localVariable("aggregationOptions")); + optionsBuilder.indent(); + for (CodeBlock optionBlock : options) { + optionsBuilder.add(optionBlock); + optionsBuilder.add("\n"); + } + optionsBuilder.add(".build();\n"); + optionsBuilder.unindent(); + builder.add(optionsBuilder.build()); + + builder.addStatement("$1L = $1L.withOptions($2L)", aggregationVariableName, + context.localVariable("aggregationOptions")); + } + return builder.build(); + } + + private CodeBlock aggregationStages(String stageListVariableName, Iterable stages, int stageCount, + Map arguments) { + + Builder builder = CodeBlock.builder(); + builder.addStatement("$T<$T> $L = new $T($L)", List.class, Object.class, stageListVariableName, ArrayList.class, + stageCount); + int stageCounter = 0; + + for (String stage : stages) { + + VariableSnippet stageSnippet = Snippet.declare(builder) + .variable(Document.class, context.localVariable("stage_%s".formatted(stageCounter++))) + .of(MongoCodeBlocks.asDocument(stage, arguments)); + builder.addStatement("$L.add($L)", stageListVariableName, stageSnippet.getVariableName()); + } + + return builder.build(); + } + + private CodeBlock sortingStage(String sortProvider) { + + Builder builder = CodeBlock.builder(); + + builder.beginControlFlow("if ($L.isSorted())", sortProvider); + builder.addStatement("$1T $2L = new $1T()", Document.class, context.localVariable("sortDocument")); + builder.beginControlFlow("for ($T $L : $L)", Order.class, context.localVariable("order"), sortProvider); + builder.addStatement("$1L.append($2L.getProperty(), $2L.isAscending() ? 1 : -1);", + context.localVariable("sortDocument"), context.localVariable("order")); + builder.endControlFlow(); + builder.addStatement("stages.add(new $T($S, $L))", Document.class, "$sort", + context.localVariable("sortDocument")); + builder.endControlFlow(); + + return builder.build(); + } + + private CodeBlock pagingStage(String pageableProvider, boolean slice) { + + Builder builder = CodeBlock.builder(); + + builder.add(sortingStage(pageableProvider + ".getSort()")); + + builder.beginControlFlow("if ($L.isPaged())", pageableProvider); + builder.beginControlFlow("if ($L.getOffset() > 0)", pageableProvider); + builder.addStatement("$L.add($T.skip($L.getOffset()))", context.localVariable("stages"), Aggregation.class, + pageableProvider); + builder.endControlFlow(); + if (slice) { + builder.addStatement("$L.add($T.limit($L.getPageSize() + 1))", context.localVariable("stages"), + Aggregation.class, pageableProvider); + } else { + builder.addStatement("$L.add($T.limit($L.getPageSize()))", context.localVariable("stages"), Aggregation.class, + pageableProvider); + } + builder.endControlFlow(); + + return builder.build(); + } + + private CodeBlock limitingStage(String limitProvider) { + + Builder builder = CodeBlock.builder(); + + builder.beginControlFlow("if ($L.isLimited())", limitProvider); + builder.addStatement("$L.add($T.limit($L.max()))", context.localVariable("stages"), Aggregation.class, + limitProvider); + builder.endControlFlow(); + + return builder.build(); + } + + } } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/BuilderStyleSnippet.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/BuilderStyleSnippet.java new file mode 100644 index 000000000..42627839e --- /dev/null +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/BuilderStyleSnippet.java @@ -0,0 +1,44 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.mongodb.repository.aot; + +import org.springframework.javapoet.CodeBlock; +import org.springframework.javapoet.CodeBlock.Builder; + +/** + * @author Christoph Strobl + */ +public class BuilderStyleSnippet implements Snippet { + + private final String targetVariableName; + private final String methodName; + private final Snippet argumentValue; + + BuilderStyleSnippet(String targetVariableName, String methodName, Snippet argumentValue) { + this.targetVariableName = targetVariableName; + this.methodName = methodName; + this.argumentValue = argumentValue; + } + + @Override + public CodeBlock code() { + + Builder builder = CodeBlock.builder(); + builder.add("$1L = $1L.$2L($3L);\n", targetVariableName, methodName, argumentValue.code()); + return builder.build(); + } + +} diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/DeleteBlocks.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/DeleteBlocks.java index 1d009f308..74f11a1ae 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/DeleteBlocks.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/DeleteBlocks.java @@ -18,6 +18,7 @@ package org.springframework.data.mongodb.repository.aot; import java.util.Optional; import org.jspecify.annotations.NullUnmarked; +import org.springframework.core.ResolvableType; import org.springframework.data.mongodb.core.ExecutableRemoveOperation.ExecutableRemove; import org.springframework.data.mongodb.core.MongoOperations; import org.springframework.data.mongodb.repository.query.MongoQueryExecution.DeleteExecution; @@ -35,66 +36,68 @@ import org.springframework.util.ObjectUtils; */ class DeleteBlocks { - @NullUnmarked - static class DeleteExecutionCodeBlockBuilder { - - private final AotQueryMethodGenerationContext context; - private final MongoQueryMethod queryMethod; - private String queryVariableName; - - DeleteExecutionCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) { - - this.context = context; - this.queryMethod = queryMethod; - } - - DeleteExecutionCodeBlockBuilder referencing(String queryVariableName) { - - this.queryVariableName = queryVariableName; - return this; - } - - CodeBlock build() { - - String mongoOpsRef = context.fieldNameOf(MongoOperations.class); - Builder builder = CodeBlock.builder(); - - Class domainType = context.getRepositoryInformation().getDomainType(); - boolean isProjecting = context.getActualReturnType() != null - && !ObjectUtils.nullSafeEquals(TypeName.get(domainType), context.getActualReturnType()); - - Object actualReturnType = isProjecting ? context.getActualReturnType().getType() : domainType; - - builder.add("\n"); - builder.addStatement("$1T<$2T> $3L = $4L.remove($2T.class)", ExecutableRemove.class, domainType, - context.localVariable("remover"), mongoOpsRef); - - DeleteExecution.Type type = DeleteExecution.Type.FIND_AND_REMOVE_ALL; - if (!queryMethod.isCollectionQuery()) { - if (!ClassUtils.isPrimitiveOrWrapper(context.getMethod().getReturnType())) { - type = DeleteExecution.Type.FIND_AND_REMOVE_ONE; - } else { - type = DeleteExecution.Type.ALL; - } - } - - actualReturnType = ClassUtils.isPrimitiveOrWrapper(context.getMethod().getReturnType()) - ? TypeName.get(context.getMethod().getReturnType()) - : queryMethod.isCollectionQuery() ? context.getReturnTypeName() : actualReturnType; - - if (ClassUtils.isVoidType(context.getMethod().getReturnType())) { - builder.addStatement("new $T($L, $T.$L).execute($L)", DeleteExecution.class, context.localVariable("remover"), - DeleteExecution.Type.class, type.name(), queryVariableName); - } else if (context.getMethod().getReturnType() == Optional.class) { - builder.addStatement("return $T.ofNullable(($T) new $T($L, $T.$L).execute($L))", Optional.class, - actualReturnType, DeleteExecution.class, context.localVariable("remover"), DeleteExecution.Type.class, - type.name(), queryVariableName); - } else { - builder.addStatement("return ($T) new $T($L, $T.$L).execute($L)", actualReturnType, DeleteExecution.class, - context.localVariable("remover"), DeleteExecution.Type.class, type.name(), queryVariableName); - } - - return builder.build(); - } - } + @NullUnmarked + static class DeleteExecutionCodeBlockBuilder { + + private final AotQueryMethodGenerationContext context; + private final MongoQueryMethod queryMethod; + private String queryVariableName; + + DeleteExecutionCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) { + + this.context = context; + this.queryMethod = queryMethod; + } + + DeleteExecutionCodeBlockBuilder referencing(String queryVariableName) { + + this.queryVariableName = queryVariableName; + return this; + } + + CodeBlock build() { + + String mongoOpsRef = context.fieldNameOf(MongoOperations.class); + Builder builder = CodeBlock.builder(); + + Class domainType = context.getRepositoryInformation().getDomainType(); + boolean isProjecting = context.getActualReturnType() != null + && !ObjectUtils.nullSafeEquals(TypeName.get(domainType), context.getActualReturnType()); + + Object actualReturnType = isProjecting ? context.getActualReturnType().getType() : domainType; + + builder.add("\n"); + VariableSnippet remover = Snippet.declare(builder) + .variable(ResolvableType.forClassWithGenerics(ExecutableRemove.class, domainType), + context.localVariable("remover")) + .as("$L.remove($T.class)", mongoOpsRef, domainType); + + DeleteExecution.Type type = DeleteExecution.Type.FIND_AND_REMOVE_ALL; + if (!queryMethod.isCollectionQuery()) { + if (!ClassUtils.isPrimitiveOrWrapper(context.getMethod().getReturnType())) { + type = DeleteExecution.Type.FIND_AND_REMOVE_ONE; + } else { + type = DeleteExecution.Type.ALL; + } + } + + actualReturnType = ClassUtils.isPrimitiveOrWrapper(context.getMethod().getReturnType()) + ? TypeName.get(context.getMethod().getReturnType()) + : queryMethod.isCollectionQuery() ? context.getReturnTypeName() : actualReturnType; + + if (ClassUtils.isVoidType(context.getMethod().getReturnType())) { + builder.addStatement("new $T($L, $T.$L).execute($L)", DeleteExecution.class, remover.getVariableName(), + DeleteExecution.Type.class, type.name(), queryVariableName); + } else if (context.getMethod().getReturnType() == Optional.class) { + builder.addStatement("return $T.ofNullable(($T) new $T($L, $T.$L).execute($L))", Optional.class, + actualReturnType, DeleteExecution.class, remover.getVariableName(), DeleteExecution.Type.class, type.name(), + queryVariableName); + } else { + builder.addStatement("return ($T) new $T($L, $T.$L).execute($L)", actualReturnType, DeleteExecution.class, + context.localVariable("remover"), DeleteExecution.Type.class, type.name(), queryVariableName); + } + + return builder.build(); + } + } } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/ExpressionSnippet.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/ExpressionSnippet.java new file mode 100644 index 000000000..a803704be --- /dev/null +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/ExpressionSnippet.java @@ -0,0 +1,53 @@ +/* + * Copyright 2025-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.mongodb.repository.aot; + +import org.springframework.javapoet.CodeBlock; + +/** + * @author Christoph Strobl + * @since 5.0 + */ +class ExpressionSnippet implements Snippet { + + private final CodeBlock block; + private final boolean requiresEvaluation; + + public ExpressionSnippet(CodeBlock block) { + this(block, false); + } + + public ExpressionSnippet(Snippet block) { + this(block.code(), block instanceof ExpressionSnippet eb && eb.requiresEvaluation()); + } + + public ExpressionSnippet(CodeBlock block, boolean requiresEvaluation) { + this.block = block; + this.requiresEvaluation = requiresEvaluation; + } + + public static ExpressionSnippet empty() { + return new ExpressionSnippet(CodeBlock.builder().build()); + } + + public boolean requiresEvaluation() { + return requiresEvaluation; + } + + public CodeBlock code() { + return block; + } +} diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/GeoBlocks.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/GeoBlocks.java index b94f55adc..8f2df3e4c 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/GeoBlocks.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/GeoBlocks.java @@ -50,38 +50,37 @@ class GeoBlocks { CodeBlock.Builder builder = CodeBlock.builder(); builder.add("\n"); - String locationParameterName = context.getParameterName(queryMethod.getParameters().getNearIndex()); - - builder.addStatement("$1T $2L = $1T.near($3L)", NearQuery.class, variableName, locationParameterName); + VariableSnippet query = Snippet.declare(builder).variable(NearQuery.class, variableName).as("$T.near($L)", + NearQuery.class, context.getParameterName(queryMethod.getParameters().getNearIndex())); if (queryMethod.getParameters().getRangeIndex() != -1) { - String rangeParametername = context.getParameterName(queryMethod.getParameters().getRangeIndex()); - String minVarName = context.localVariable("min"); - String maxVarName = context.localVariable("max"); + String rangeParameter = context.getParameterName(queryMethod.getParameters().getRangeIndex()); - builder.beginControlFlow("if($L.getLowerBound().isBounded())", rangeParametername); - builder.addStatement("$1T $2L = $3L.getLowerBound().getValue().get()", Distance.class, minVarName, - rangeParametername); - builder.addStatement("$1L.minDistance($2L).in($2L.getMetric())", variableName, minVarName); + builder.beginControlFlow("if($L.getLowerBound().isBounded())", rangeParameter); + VariableSnippet min = Snippet.declare(builder).variable(Distance.class, context.localVariable("min")) + .as("$L.getLowerBound().getValue().get()", rangeParameter); + builder.addStatement("$1L.minDistance($2L).in($2L.getMetric())", query.getVariableName(), + min.getVariableName()); builder.endControlFlow(); - builder.beginControlFlow("if($L.getUpperBound().isBounded())", rangeParametername); - builder.addStatement("$1T $2L = $3L.getUpperBound().getValue().get()", Distance.class, maxVarName, - rangeParametername); - builder.addStatement("$1L.maxDistance($2L).in($2L.getMetric())", variableName, maxVarName); + builder.beginControlFlow("if($L.getUpperBound().isBounded())", rangeParameter); + VariableSnippet max = Snippet.declare(builder).variable(Distance.class, context.localVariable("max")) + .as("$L.getUpperBound().getValue().get()", rangeParameter); + builder.addStatement("$1L.maxDistance($2L).in($2L.getMetric())", query.getVariableName(), + max.getVariableName()); builder.endControlFlow(); } else { - String distanceParametername = context.getParameterName(queryMethod.getParameters().getMaxDistanceIndex()); - builder.addStatement("$1L.maxDistance($2L).in($2L.getMetric())", variableName, distanceParametername); + String distanceParameter = context.getParameterName(queryMethod.getParameters().getMaxDistanceIndex()); + builder.addStatement("$1L.maxDistance($2L).in($2L.getMetric())", query.code(), distanceParameter); } if (context.getPageableParameterName() != null) { - builder.addStatement("$L.with($L)", variableName, context.getPageableParameterName()); + builder.addStatement("$L.with($L)", query.code(), context.getPageableParameterName()); } - MongoCodeBlocks.appendReadPreference(context, builder, variableName); + MongoCodeBlocks.appendReadPreference(context, builder, query.getVariableName()); return builder.build(); } @@ -115,29 +114,29 @@ class GeoBlocks { CodeBlock.Builder builder = CodeBlock.builder(); builder.add("\n"); - String executorVar = context.localVariable("nearFinder"); - builder.addStatement("var $L = $L.query($T.class).near($L)", executorVar, - context.fieldNameOf(MongoOperations.class), context.getRepositoryInformation().getDomainType(), - queryVariableName); + VariableSnippet queryExecutor = Snippet.declare(builder).variable(context.localVariable("nearFinder")).as( + "$L.query($T.class).near($L)", context.fieldNameOf(MongoOperations.class), + context.getRepositoryInformation().getDomainType(), queryVariableName); if (ClassUtils.isAssignable(GeoPage.class, context.getReturnType().getRawClass())) { - String geoResultVar = context.localVariable("geoResult"); - builder.addStatement("var $L = $L.all()", geoResultVar, executorVar); + VariableSnippet geoResult = Snippet.declare(builder).variable(context.localVariable("geoResult")).as("$L.all()", + queryExecutor.getVariableName()); builder.beginControlFlow("if($L.isUnpaged())", context.getPageableParameterName()); - builder.addStatement("return new $T<>($L)", GeoPage.class, geoResultVar); + builder.addStatement("return new $T<>($L)", GeoPage.class, geoResult.getVariableName()); builder.endControlFlow(); - String pageVar = context.localVariable("resultPage"); - builder.addStatement("var $L = $T.getPage($L.getContent(), $L, () -> $L.count())", pageVar, - PageableExecutionUtils.class, geoResultVar, context.getPageableParameterName(), executorVar); - builder.addStatement("return new $T<>($L, $L, $L.getTotalElements())", GeoPage.class, geoResultVar, - context.getPageableParameterName(), pageVar); + VariableSnippet resultPage = Snippet.declare(builder).variable(context.localVariable("resultPage")).as( + "$T.getPage($L.getContent(), $L, () -> $L.count())", PageableExecutionUtils.class, + geoResult.getVariableName(), context.getPageableParameterName(), queryExecutor.getVariableName()); + + builder.addStatement("return new $T<>($L, $L, $L.getTotalElements())", GeoPage.class, + geoResult.getVariableName(), context.getPageableParameterName(), resultPage.getVariableName()); } else if (ClassUtils.isAssignable(GeoResults.class, context.getReturnType().getRawClass())) { - builder.addStatement("return $L.all()", executorVar); + builder.addStatement("return $L.all()", queryExecutor.getVariableName()); } else { - builder.addStatement("return $L.all().getContent()", executorVar); + builder.addStatement("return $L.all().getContent()", queryExecutor.getVariableName()); } return builder.build(); } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java index 4125139bd..3cbb4e856 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java @@ -170,6 +170,25 @@ class MongoCodeBlocks { return new GeoNearExecutionCodeBlockBuilder(context, queryMethod); } + static CodeBlock asDocument(String source, Map arguments) { + + Builder builder = CodeBlock.builder(); + if (!StringUtils.hasText(source)) { + builder.add("new $T()", Document.class); + } else if (!containsPlaceholder(source)) { + builder.add("$T.parse($S)", Document.class, source); + } else { + builder.add("bindParameters($S, ", source); + if (containsNamedPlaceholder(source)) { + builder.add(renderArgumentMap(arguments)); + } else { + builder.add(renderArgumentArray(arguments)); + } + builder.add(");\n"); + } + return builder.build(); + } + static CodeBlock renderExpressionToDocument(@Nullable String source, String variableName, Map arguments) { diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java index 524c5e8f2..7f4f069a8 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java @@ -111,7 +111,10 @@ public class MongoRepositoryContributor extends RepositoryContributor { AnnotatedElementUtils.findMergedAnnotation(method, Query.class), method); if (queryMethod.isSearchQuery() || method.isAnnotationPresent(VectorSearch.class)) { - return searchMethodContributor(queryMethod, new SearchInteraction(query.getQuery())); + + VectorSearch vectorSearch = AnnotatedElementUtils.findMergedAnnotation(method, VectorSearch.class); + return searchMethodContributor(queryMethod, new SearchInteraction(getRepositoryInformation().getDomainType(), + vectorSearch, query.getQuery(), queryMethod.getParameters())); } if (queryMethod.isGeoNearQuery() || (queryMethod.getParameters().getMaxDistanceIndex() != -1 @@ -233,8 +236,9 @@ public class MongoRepositoryContributor extends RepositoryContributor { String variableName = "search"; - builder.add(new VectorSearchBocks.VectorSearchQueryCodeBlockBuilder(context, queryMethod) - .usingVariableName(variableName).withFilter(interaction.getFilter()).build()); + builder.add( + new VectorSearchBocks.VectorSearchQueryCodeBlockBuilder(context, queryMethod, interaction.getSearchPath()) + .usingVariableName(variableName).withFilter(interaction.getFilter()).build()); return builder.build(); }); diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/QueryBlocks.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/QueryBlocks.java index 7ad0c25b1..37f347069 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/QueryBlocks.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/QueryBlocks.java @@ -22,7 +22,6 @@ import java.util.Optional; import org.bson.Document; import org.jspecify.annotations.NullUnmarked; -import org.jspecify.annotations.Nullable; import org.springframework.core.annotation.MergedAnnotation; import org.springframework.data.geo.Box; import org.springframework.data.geo.Circle; @@ -215,9 +214,9 @@ class QueryBlocks { if (StringUtils.hasText(source.getQuery().getFieldsString())) { - builder - .add(MongoCodeBlocks.renderExpressionToDocument(source.getQuery().getFieldsString(), "fields", arguments)); - builder.addStatement("$L.setFieldsObject(fields)", queryVariableName); + VariableSnippet fields = Snippet.declare(builder).variable(Document.class, context.localVariable("fields")) + .of(MongoCodeBlocks.asDocument(source.getQuery().getFieldsString(), arguments)); + builder.addStatement("$L.setFieldsObject($L)", queryVariableName, fields.getVariableName()); } String sortParameter = context.getSortParameterName(); @@ -225,8 +224,9 @@ class QueryBlocks { builder.addStatement("$L.with($L)", queryVariableName, sortParameter); } else if (StringUtils.hasText(source.getQuery().getSortString())) { - builder.add(MongoCodeBlocks.renderExpressionToDocument(source.getQuery().getSortString(), "sort", arguments)); - builder.addStatement("$L.setSortObject(sort)", queryVariableName); + VariableSnippet sort = Snippet.declare(builder).variable(Document.class, context.localVariable("sort")) + .of(MongoCodeBlocks.asDocument(source.getQuery().getSortString(), arguments)); + builder.addStatement("$L.setSortObject($L)", queryVariableName, sort.getVariableName()); } String limitParameter = context.getLimitParameterName(); @@ -273,10 +273,10 @@ class QueryBlocks { if (collationAnnotation.isPresent()) { String collationString = collationAnnotation.getString("value"); - if(StringUtils.hasText(collationString)) { + if (StringUtils.hasText(collationString)) { if (!MongoCodeBlocks.containsPlaceholder(collationString)) { builder.addStatement("$L.collation($T.parse($S))", queryVariableName, - org.springframework.data.mongodb.core.query.Collation.class, collationString); + org.springframework.data.mongodb.core.query.Collation.class, collationString); } else { builder.add("$L.collation(collationOf(evaluate($S, ", queryVariableName, collationString); builder.add(MongoCodeBlocks.renderArgumentMap(arguments)); @@ -292,29 +292,28 @@ class QueryBlocks { Builder builder = CodeBlock.builder(); builder.add("\n"); - builder.add(renderExpressionToQuery(source.getQuery().getQueryString(), queryVariableName)); + + Snippet.declare(builder).variable(BasicQuery.class, this.queryVariableName).of(renderExpressionToQuery()); return builder.build(); } - private CodeBlock renderExpressionToQuery(@Nullable String source, String variableName) { + private CodeBlock renderExpressionToQuery() { - Builder builder = CodeBlock.builder(); + String source = this.source.getQuery().getQueryString(); if (!StringUtils.hasText(source)) { - - builder.addStatement("$1T $2L = new $1T(new $3T())", BasicQuery.class, variableName, Document.class); - } else if (!MongoCodeBlocks.containsPlaceholder(source)) { - builder.addStatement("$1T $2L = new $1T($3T.parse($4S))", BasicQuery.class, variableName, Document.class, - source); + return CodeBlock.of("new $T(new $T())", BasicQuery.class, Document.class); + } + if (!MongoCodeBlocks.containsPlaceholder(source)) { + return CodeBlock.of("new $T($T.parse($S))", BasicQuery.class, Document.class, source); + } + Builder builder = CodeBlock.builder(); + builder.add("createQuery($S, ", source); + if (MongoCodeBlocks.containsNamedPlaceholder(source)) { + builder.add(MongoCodeBlocks.renderArgumentMap(arguments)); } else { - builder.add("$T $L = createQuery($S, ", BasicQuery.class, variableName, source); - if (MongoCodeBlocks.containsNamedPlaceholder(source)) { - builder.add(MongoCodeBlocks.renderArgumentMap(arguments)); - } else { - builder.add(MongoCodeBlocks.renderArgumentArray(arguments)); - } - builder.add(");\n"); + builder.add(MongoCodeBlocks.renderArgumentArray(arguments)); } - + builder.add(")"); return builder.build(); } } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/SearchInteraction.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/SearchInteraction.java index a94ff1082..906061018 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/SearchInteraction.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/SearchInteraction.java @@ -15,26 +15,55 @@ */ package org.springframework.data.mongodb.repository.aot; +import java.lang.reflect.Field; +import java.util.LinkedHashMap; +import java.util.List; import java.util.Map; +import org.bson.Document; +import org.bson.json.JsonMode; +import org.bson.json.JsonWriterSettings; import org.jspecify.annotations.Nullable; +import org.springframework.data.domain.Vector; +import org.springframework.data.mongodb.core.aggregation.VectorSearchOperation.SearchType; +import org.springframework.data.mongodb.repository.VectorSearch; +import org.springframework.data.mongodb.repository.query.MongoParameters; import org.springframework.data.repository.aot.generate.QueryMetadata; +import org.springframework.util.StringUtils; /** * @author Christoph Strobl */ public class SearchInteraction extends MongoInteraction implements QueryMetadata { - StringQuery filter; + private final Class domainType; + private final StringQuery filter; + private final @Nullable VectorSearch vectorSearch; + private final MongoParameters parameters; + + public SearchInteraction(Class domainType, @Nullable VectorSearch vectorSearch, StringQuery filter, + MongoParameters parameters) { + + this.domainType = domainType; + this.vectorSearch = vectorSearch; - public SearchInteraction(StringQuery filter) { this.filter = filter; + this.parameters = parameters; } public StringQuery getFilter() { return filter; } + @Nullable + String getIndexName() { + return vectorSearch != null ? vectorSearch.indexName() : null; + } + + public MongoParameters getParameters() { + return parameters; + } + @Override InteractionType getExecutionType() { return InteractionType.AGGREGATION; @@ -43,6 +72,62 @@ public class SearchInteraction extends MongoInteraction implements QueryMetadata @Override public Map serialize() { - return Map.of("FIXME", "please!"); + Map serialized = new LinkedHashMap<>(); + + if (vectorSearch != null && StringUtils.hasText(vectorSearch.indexName())) { + serialized.put("index", vectorSearch.indexName()); + } + + serialized.put("path", getSearchPath()); + + if (vectorSearch.searchType().equals(SearchType.ENN)) { + serialized.put("exact", true); + } + + if (StringUtils.hasText(filter.getQueryString())) { + serialized.put("filter", filter.getQueryString()); + } + + String limit = limitParameter(); + if (StringUtils.hasText(limit)) { + serialized.put("limit", limit); + } + + if (StringUtils.hasText(vectorSearch.numCandidates())) { + serialized.put("numCandidates", vectorSearch.numCandidates()); + } else if (StringUtils.hasText(limit)) { + serialized.put("numCandidates", limit + " * 20"); + } + + serialized.put("queryVector", "?" + parameters.getVectorIndex()); + + return Map.of("pipeline", List.of(new Document("$vectorSearch", serialized) + .toJson(JsonWriterSettings.builder().outputMode(JsonMode.RELAXED).build()).replaceAll("\\\"", "'"))); + } + + private @Nullable String limitParameter() { + + if (parameters.hasLimitParameter()) { + return "?" + parameters.getLimitIndex(); + } else if (StringUtils.hasText(vectorSearch.limit())) { + return vectorSearch.limit(); + } + return null; + } + + public String getSearchPath() { + + if (vectorSearch != null && StringUtils.hasText(vectorSearch.path())) { + return vectorSearch.path(); + } + + Field[] declaredFields = domainType.getDeclaredFields(); + for (Field field : declaredFields) { + if (Vector.class.isAssignableFrom(field.getType())) { + return field.getName(); + } + } + + throw new IllegalArgumentException("No vector search path found for type %s".formatted(domainType)); } } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/Snippet.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/Snippet.java new file mode 100644 index 000000000..7f62f0cdd --- /dev/null +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/Snippet.java @@ -0,0 +1,224 @@ +/* + * Copyright 2025-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.mongodb.repository.aot; + +import java.util.function.Function; + +import org.jspecify.annotations.Nullable; +import org.springframework.core.ResolvableType; +import org.springframework.data.mongodb.repository.aot.Snippet.BuilderStyleBuilder.BuilderStyleMethodArgumentBuilder; +import org.springframework.data.mongodb.repository.aot.Snippet.BuilderStyleVariableBuilder.BuilderStyleVariableBuilderImpl; +import org.springframework.javapoet.CodeBlock; +import org.springframework.javapoet.CodeBlock.Builder; + +/** + * @author Christoph Strobl + * @since 5.0 + */ +interface Snippet { + + CodeBlock code(); + + default boolean isEmpty() { + return code().isEmpty(); + } + + default void appendTo(CodeBlock.Builder builder) { + if (!isEmpty()) { + builder.add(code()); + } + } + + default T as(Function transformer) { + return transformer.apply(this); + } + + default Snippet wrap(String prefix, String suffix) { + return wrap("%s$L%s".formatted(prefix, suffix)); + } + + default Snippet wrap(CodeBlock prefix, CodeBlock suffix) { + return new Snippet() { + + @Override + public CodeBlock code() { + return CodeBlock.builder().add(prefix).add(Snippet.this.code()).add(suffix).build(); + } + }; + } + + default Snippet wrap(String statement) { + return new Snippet() { + + @Override + public CodeBlock code() { + return CodeBlock.of(statement, Snippet.this.code()); + } + }; + } + + static Snippet just(CodeBlock codeBlock) { + return new Snippet() { + @Override + public CodeBlock code() { + return codeBlock; + } + }; + } + + static ContextualSnippetBuilder declare(CodeBlock.Builder builder) { + + return new ContextualSnippetBuilder() { + + @Override + public VariableBuilder variable(String variableName) { + return VariableSnippet.variable(variableName).targeting(builder); + } + + @Override + public VariableBuilder variable(Class type, String variableName) { + return VariableSnippet.variable(type, variableName).targeting(builder); + } + + @Override + public VariableBuilder variable(ResolvableType resolvableType, String variableName) { + return VariableSnippet.variable(resolvableType, variableName).targeting(builder); + } + + @Override + public BuilderStyleVariableBuilder variableBuilder(String variableName) { + return new BuilderStyleVariableBuilderImpl(builder, null, variableName); + } + + @Override + public BuilderStyleVariableBuilder variableBuilder(Class type, String variableName) { + return variableBuilder(ResolvableType.forClass(type), variableName); + } + + @Override + public BuilderStyleVariableBuilder variableBuilder(ResolvableType resolvableType, String variableName) { + return new BuilderStyleVariableBuilderImpl(builder, resolvableType, variableName); + } + }; + } + + interface ContextualSnippetBuilder { + + VariableBuilder variable(String variableName); + + VariableBuilder variable(Class type, String variableName); + + VariableBuilder variable(ResolvableType resolvableType, String variableName); + + BuilderStyleVariableBuilder variableBuilder(String variableName); + + BuilderStyleVariableBuilder variableBuilder(Class type, String variableName); + + BuilderStyleVariableBuilder variableBuilder(ResolvableType resolvableType, String variableName); + } + + interface VariableBuilder { + + default VariableSnippet as(String declaration, Object... args) { + return of(CodeBlock.of(declaration, args)); + } + + VariableSnippet of(CodeBlock codeBlock); + } + + interface BuilderStyleVariableBuilder { + + default BuilderStyleBuilder as(String declaration, Object... args) { + return of(CodeBlock.of(declaration, args)); + } + + BuilderStyleBuilder of(CodeBlock codeBlock); + + class BuilderStyleVariableBuilderImpl + implements BuilderStyleVariableBuilder, BuilderStyleBuilder, BuilderStyleMethodArgumentBuilder { + + Builder targetBuilder; + @Nullable ResolvableType type; + String targetVariableName; + @Nullable String targetMethodName; + @Nullable VariableSnippet variableSnippet; + + public BuilderStyleVariableBuilderImpl(Builder targetBuilder, @Nullable ResolvableType type, + String targetVariableName) { + this.targetBuilder = targetBuilder; + this.type = type; + this.targetVariableName = targetVariableName; + } + + @Override + public BuilderStyleBuilder as(String declaration, Object... args) { + + if (type != null) { + this.variableSnippet = Snippet.declare(targetBuilder).variable(type, targetVariableName).as(declaration, args); + } else { + this.variableSnippet = Snippet.declare(targetBuilder).variable(targetVariableName).as(declaration, args); + } + return this; + } + + @Override + public BuilderStyleBuilder of(CodeBlock codeBlock) { + if (type != null) { + this.variableSnippet = Snippet.declare(targetBuilder).variable(type, targetVariableName).of(codeBlock); + } else { + this.variableSnippet = Snippet.declare(targetBuilder).variable(targetVariableName).of(codeBlock); + } + return this; + } + + @Override + public BuilderStyleMethodArgumentBuilder call(String methodName) { + this.targetMethodName = methodName; + return this; + } + + @Override + public BuilderStyleBuilder with(Snippet snippet) { + new BuilderStyleSnippet(targetVariableName, targetMethodName, snippet).appendTo(targetBuilder); + return this; + } + + @Override + public VariableSnippet variable() { + return this.variableSnippet; + } + } + } + + interface BuilderStyleBuilder { + + BuilderStyleMethodArgumentBuilder call(String methodName); + VariableSnippet variable(); + + interface BuilderStyleMethodArgumentBuilder { + default BuilderStyleBuilder with(String statement, Object... args) { + return with(CodeBlock.of(statement, args)); + } + + default BuilderStyleBuilder with(CodeBlock codeBlock) { + return with(Snippet.just(codeBlock)); + } + + BuilderStyleBuilder with(Snippet snippet); + } + } + +} diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/UpdateBlocks.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/UpdateBlocks.java index e4061c771..759c2a611 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/UpdateBlocks.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/UpdateBlocks.java @@ -1,19 +1,3 @@ -/* - * Copyright 2025. the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - /* * Copyright 2025 the original author or authors. * @@ -21,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -35,6 +19,7 @@ import java.util.LinkedHashMap; import java.util.Map; import org.jspecify.annotations.NullUnmarked; +import org.springframework.core.ResolvableType; import org.springframework.data.mongodb.core.ExecutableUpdateOperation.ExecutableUpdate; import org.springframework.data.mongodb.core.MongoOperations; import org.springframework.data.mongodb.core.query.BasicUpdate; @@ -48,7 +33,6 @@ import org.springframework.util.NumberUtils; /** * @author Christoph Strobl - * @since 2025/06 */ class UpdateBlocks { @@ -87,22 +71,26 @@ class UpdateBlocks { String updateReference = updateVariableName; Class domainType = context.getRepositoryInformation().getDomainType(); - builder.addStatement("$1T<$2T> $3L = $4L.update($2T.class)", ExecutableUpdate.class, domainType, - context.localVariable("updater"), mongoOpsRef); + VariableSnippet updater = Snippet.declare(builder) + .variable(ResolvableType.forClassWithGenerics(ExecutableUpdate.class, domainType), + context.localVariable("updater")) + .as("$L.update($T.class)", mongoOpsRef, domainType); Class returnType = ClassUtils.resolvePrimitiveIfNecessary(queryMethod.getReturnedObjectType()); if (ReflectionUtils.isVoid(returnType)) { - builder.addStatement("$L.matching($L).apply($L).all()", context.localVariable("updater"), queryVariableName, + builder.addStatement("$L.matching($L).apply($L).all()", updater.getVariableName(), queryVariableName, updateReference); } else if (ClassUtils.isAssignable(Long.class, returnType)) { - builder.addStatement("return $L.matching($L).apply($L).all().getModifiedCount()", - context.localVariable("updater"), queryVariableName, updateReference); + builder.addStatement("return $L.matching($L).apply($L).all().getModifiedCount()", updater.getVariableName(), + queryVariableName, updateReference); } else { - builder.addStatement("$T $L = $L.matching($L).apply($L).all().getModifiedCount()", Long.class, - context.localVariable("modifiedCount"), context.localVariable("updater"), queryVariableName, - updateReference); + + VariableSnippet modifiedCount = Snippet.declare(builder) + .variable(Long.class, context.localVariable("modifiedCount")) + .as("$L.matching($L).apply($L).all().getModifiedCount()", updater.getVariableName(), queryVariableName, + updateReference); builder.addStatement("return $T.convertNumberToTargetClass($L, $T.class)", NumberUtils.class, - context.localVariable("modifiedCount"), returnType); + modifiedCount.getVariableName(), returnType); } return builder.build(); diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/VariableSnippet.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/VariableSnippet.java new file mode 100644 index 000000000..e3d67049a --- /dev/null +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/VariableSnippet.java @@ -0,0 +1,119 @@ +/* + * Copyright 2025-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.mongodb.repository.aot; + +import org.jspecify.annotations.Nullable; +import org.springframework.core.ResolvableType; +import org.springframework.javapoet.CodeBlock; +import org.springframework.javapoet.CodeBlock.Builder; +import org.springframework.javapoet.TypeName; + +/** + * @author Christoph Strobl + * @since 5.0 + */ +class VariableSnippet extends ExpressionSnippet { + + private final String variableName; + private final @Nullable TypeName typeName; + + public VariableSnippet(String variableName, Snippet delegate) { + this((TypeName) null, variableName, delegate); + } + + public VariableSnippet(Class typeName, String variableName, Snippet delegate) { + this(TypeName.get(typeName), variableName, delegate); + } + + public VariableSnippet(@Nullable TypeName typeName, String variableName, Snippet delegate) { + super(delegate); + this.typeName = typeName; + this.variableName = variableName; + } + + static VariableBuilderImp variable(String name) { + return new VariableBuilderImp(null, name); + } + + static VariableBuilderImp variable(Class typeName, String name) { + return variable(TypeName.get(typeName), name); + } + + static VariableBuilderImp variable(ResolvableType resolvableType, String name) { + return variable(TypeName.get(resolvableType.getType()), name); + } + + static VariableBuilderImp variable(TypeName typeName, String name) { + return new VariableBuilderImp(typeName, name); + } + + static class VariableBuilderImp implements VariableBuilder { + + private @Nullable TypeName typeName; + private String variableName; + + private CodeBlock.@Nullable Builder target; + + VariableBuilderImp(@Nullable TypeName typeName, String variableName) { + this.typeName = typeName; + this.variableName = variableName; + } + + @Override + public VariableSnippet of(CodeBlock codeBlock) { + + VariableSnippet variableSnippet = new VariableSnippet(typeName, variableName, Snippet.just(codeBlock)); + if (target != null) { + variableSnippet.renderDeclaration(target); + } + return variableSnippet; + } + + VariableBuilderImp targeting(@Nullable Builder target) { + this.target = target; + return this; + } + } + + @Override + public CodeBlock code() { + return CodeBlock.of("$L", variableName); + } + + public String getVariableName() { + return variableName; + } + + void renderDeclaration(CodeBlock.Builder builder) { + if (typeName != null) { + builder.addStatement("$T $L = $L", typeName, variableName, super.code()); + } else { + builder.addStatement("var $L = $L", variableName, super.code()); + } + } + + static VariableBlockBuilder create(Snippet snippet) { + return variableName -> create(variableName, snippet); + } + + static VariableSnippet create(String variableName, Snippet snippet) { + return new VariableSnippet(variableName, snippet); + } + + interface VariableBlockBuilder { + VariableSnippet variableName(String variableName); + } +} diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/VectorSearchBocks.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/VectorSearchBocks.java index 3efdc080b..c9d74edc6 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/VectorSearchBocks.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/VectorSearchBocks.java @@ -15,7 +15,6 @@ */ package org.springframework.data.mongodb.repository.aot; -import java.lang.reflect.Field; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; @@ -25,15 +24,14 @@ import org.springframework.core.annotation.MergedAnnotation; import org.springframework.data.domain.Limit; import org.springframework.data.domain.ScoringFunction; import org.springframework.data.domain.Sort; -import org.springframework.data.domain.Vector; import org.springframework.data.mongodb.core.MongoOperations; import org.springframework.data.mongodb.core.aggregation.Aggregation; import org.springframework.data.mongodb.core.aggregation.AggregationOperation; -import org.springframework.data.mongodb.core.aggregation.AggregationOperationContext; import org.springframework.data.mongodb.core.aggregation.AggregationPipeline; import org.springframework.data.mongodb.core.aggregation.VectorSearchOperation; import org.springframework.data.mongodb.core.aggregation.VectorSearchOperation.SearchType; import org.springframework.data.mongodb.repository.VectorSearch; +import org.springframework.data.mongodb.repository.aot.Snippet.BuilderStyleBuilder; import org.springframework.data.mongodb.repository.query.MongoQueryExecution.VectorSearchExecution; import org.springframework.data.mongodb.repository.query.MongoQueryMethod; import org.springframework.data.repository.aot.generate.AotQueryMethodGenerationContext; @@ -55,11 +53,14 @@ class VectorSearchBocks { private String searchQueryVariableName; private StringQuery filter; private final Map arguments; + private final String searchPath; - VectorSearchQueryCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) { + VectorSearchQueryCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod, + String searchPath) { this.context = context; this.queryMethod = queryMethod; + this.searchPath = searchPath; this.arguments = new LinkedHashMap<>(); context.getBindableParameterNames().forEach(it -> arguments.put(it, CodeBlock.of(it))); } @@ -77,113 +78,56 @@ class VectorSearchBocks { String vectorParameterName = context.getVectorParameterName(); MergedAnnotation annotation = context.getAnnotation(VectorSearch.class); - String searchPath = annotation.getString("path"); String indexName = annotation.getString("indexName"); - String numCandidates = annotation.getString("numCandidates"); SearchType searchType = annotation.getEnum("searchType", SearchType.class); - String limit = annotation.getString("limit"); - if (!StringUtils.hasText(searchPath)) { // FIXME: somehow duplicate logic of AnnotatedQueryFactory + ExpressionSnippet limit = getLimitExpression(); - Field[] declaredFields = context.getRepositoryInformation().getDomainType().getDeclaredFields(); - for (Field field : declaredFields) { - if (Vector.class.isAssignableFrom(field.getType())) { - searchPath = field.getName(); - break; - } - } + if (limit.requiresEvaluation() && !StringUtils.hasText(annotation.getString("numCandidates")) + && (searchType == VectorSearchOperation.SearchType.ANN + || searchType == VectorSearchOperation.SearchType.DEFAULT)) { + VariableSnippet variableBlock = limit.as(VariableSnippet::create) + .variableName(context.localVariable("limitToUse")); + variableBlock.renderDeclaration(builder); + limit = variableBlock; } - String vectorSearchVar = context.localVariable("$vectorSearch"); - builder.add("$T $L = $T.vectorSearch($S).path($S).vector($L)", VectorSearchOperation.class, vectorSearchVar, - Aggregation.class, indexName, searchPath, vectorParameterName); - - if (StringUtils.hasText(context.getLimitParameterName())) { - builder.add(".limit($L);\n", context.getLimitParameterName()); - } else if (filter.isLimited()) { - builder.add(".limit($L);\n", filter.getLimit()); - } else if (StringUtils.hasText(limit)) { - if (MongoCodeBlocks.containsPlaceholder(limit) || MongoCodeBlocks.containsExpression(limit)) { - builder.add(".limit("); - builder.add(MongoCodeBlocks.evaluateNumberPotentially(limit, Integer.class, arguments)); - builder.add(");\n"); - } else { - builder.add(".limit($L);\n", limit); - } - } else { - builder.add(".limit($T.unlimited());\n", Limit.class); - } + BuilderStyleBuilder vectorSearchOperationBuilder = Snippet.declare(builder) + .variableBuilder(VectorSearchOperation.class, context.localVariable("$vectorSearch")) + .as("$T.vectorSearch($S).path($S).vector($L).limit($L)", Aggregation.class, indexName, searchPath, + vectorParameterName, limit.code()); if (!searchType.equals(SearchType.DEFAULT)) { - builder.addStatement("$1L = $1L.searchType($2T.$3L)", vectorSearchVar, SearchType.class, searchType.name()); + vectorSearchOperationBuilder.call("searchType").with("$T.$L", SearchType.class, searchType.name()); } - if (StringUtils.hasText(numCandidates)) { - builder.add("$1L = $1L.numCandidates(", vectorSearchVar); - builder.add(MongoCodeBlocks.evaluateNumberPotentially(numCandidates, Integer.class, arguments)); - builder.add(");\n"); - } else if (searchType == VectorSearchOperation.SearchType.ANN - || searchType == VectorSearchOperation.SearchType.DEFAULT) { - - builder.add( - "// MongoDB: We recommend that you specify a number at least 20 times higher than the number of documents to return\n"); - if (StringUtils.hasText(context.getLimitParameterName())) { - builder.addStatement("$1L = $1L.numCandidates($2L.max() * 20)", vectorSearchVar, - context.getLimitParameterName()); - } else if (StringUtils.hasText(limit)) { - if (MongoCodeBlocks.containsPlaceholder(limit) || MongoCodeBlocks.containsExpression(limit)) { - - builder.add("$1L = $1L.numCandidates((", vectorSearchVar); - builder.add(MongoCodeBlocks.evaluateNumberPotentially(limit, Integer.class, arguments)); - builder.add(") * 20);\n"); - } else { - builder.addStatement("$1L = $1L.numCandidates($2L * 20)", vectorSearchVar, limit); - } - } else { - builder.addStatement("$1L = $1L.numCandidates($2L)", vectorSearchVar, filter.getLimit() * 20); - } + ExpressionSnippet numCandidates = getNumCandidatesExpression(searchType, limit); + if (!numCandidates.isEmpty()) { + vectorSearchOperationBuilder.call("numCandidates").with(numCandidates); } - builder.addStatement("$1L = $1L.withSearchScore(\"__score__\")", vectorSearchVar); - if (StringUtils.hasText(context.getScoreParameterName())) { + vectorSearchOperationBuilder.call("withSearchScore").with("\"__score__\""); - String scoreCriteriaVar = context.localVariable("criteria"); - builder.addStatement("$1L = $1L.withFilterBySore($2L -> { $2L.gt($3L.getValue()); })", vectorSearchVar, - scoreCriteriaVar, context.getScoreParameterName()); + if (StringUtils.hasText(context.getScoreParameterName())) { + vectorSearchOperationBuilder.call("withFilterBySore").with("$1L -> { $1L.gt($2L.getValue()); }", + context.localVariable("criteria"), context.getScoreParameterName()); } else if (StringUtils.hasText(context.getScoreRangeParameterName())) { - builder.addStatement("$1L = $1L.withFilterBySore(scoreBetween($2L.getLowerBound(), $2L.getUpperBound()))", - vectorSearchVar, context.getScoreRangeParameterName()); + vectorSearchOperationBuilder.call("withFilterBySore") + .with("scoreBetween($1L.getLowerBound(), $1L.getUpperBound())", context.getScoreRangeParameterName()); } - if (StringUtils.hasText(filter.getQueryString())) { - - String filterVar = context.localVariable("filter"); - builder.add(MongoCodeBlocks.queryBlockBuilder(context, queryMethod).usingQueryVariableName("filter") - .filter(new QueryInteraction(this.filter, false, false, false)).buildJustTheQuery()); - builder.addStatement("$1L = $1L.filter($2L.getQueryObject())", vectorSearchVar, filterVar); - builder.add("\n"); - } + VariableSnippet vectorSearchOperation = vectorSearchOperationBuilder.variable(); + getFilter(vectorSearchOperation.getVariableName()).appendTo(builder); + VariableSnippet sortStage = getSort().as(VariableSnippet::create).variableName(context.localVariable("$sort")); + sortStage.renderDeclaration(builder); - String sortStageVar = context.localVariable("$sort"); - if(filter.isSorted()) { - - builder.add("$T $L = (_ctx) -> {\n", AggregationOperation.class, sortStageVar); - builder.indent(); - - builder.addStatement("$1T _mappedSort = _ctx.getMappedObject($1T.parse($2S), $3T.class)", Document.class, filter.getSortString(), context.getActualReturnType().getType()); - builder.addStatement("return new $T($S, _mappedSort.append(\"__score__\", -1))", Document.class, "$sort"); - builder.unindent(); - builder.add("};"); - - } else { - builder.addStatement("var $L = $T.sort($T.Direction.DESC, $S)", sortStageVar, Aggregation.class, Sort.class, "__score__"); - } builder.add("\n"); - builder.addStatement("$1T $2L = new $1T($3T.of($4L, $5L))", AggregationPipeline.class, searchQueryVariableName, - List.class, vectorSearchVar, sortStageVar); + VariableSnippet aggregationPipeline = Snippet.declare(builder) + .variable(AggregationPipeline.class, searchQueryVariableName).as("new $T($T.of($L, $L))", + AggregationPipeline.class, List.class, vectorSearchOperation.getVariableName(), sortStage.code()); String scoringFunctionVar = context.localVariable("scoringFunction"); builder.add("$1T $2L = ", ScoringFunction.class, scoringFunctionVar); @@ -199,13 +143,108 @@ class VectorSearchBocks { "return ($5T) new $1T($2L, $3T.class, $2L.getCollectionName($3T.class), $4T.of($5T.class), $6L, $7L).execute(null)", VectorSearchExecution.class, context.fieldNameOf(MongoOperations.class), context.getRepositoryInformation().getDomainType(), TypeInformation.class, - queryMethod.getReturnType().getType(), searchQueryVariableName, scoringFunctionVar); + queryMethod.getReturnType().getType(), aggregationPipeline.getVariableName(), scoringFunctionVar); return builder.build(); } + private ExpressionSnippet getSort() { + + if (!filter.isSorted()) { + return new ExpressionSnippet( + CodeBlock.of("$T.sort($T.Direction.DESC, $S)", Aggregation.class, Sort.class, "__score__")); + } + + Builder builder = CodeBlock.builder(); + + builder.add("($T) (_ctx) -> {\n", AggregationOperation.class); + builder.indent(); + + builder.add("$1T _mappedSort = _ctx.getMappedObject($1T.parse($2S), $3T.class);\n", Document.class, + filter.getSortString(), context.getActualReturnType().getType()); + builder.add("return new $T($S, _mappedSort.append(\"__score__\", -1));\n", Document.class, "$sort"); + builder.unindent(); + builder.add("};"); + + return new ExpressionSnippet(builder.build()); + } + + private Snippet getFilter(String vectorSearchVar) { + + if (!StringUtils.hasText(filter.getQueryString())) { + return ExpressionSnippet.empty(); + } + + Builder builder = CodeBlock.builder(); + String filterVar = context.localVariable("filter"); + builder.add(MongoCodeBlocks.queryBlockBuilder(context, queryMethod).usingQueryVariableName("filter") + .filter(new QueryInteraction(this.filter, false, false, false)).buildJustTheQuery()); + builder.addStatement("$1L = $1L.filter($2L.getQueryObject())", vectorSearchVar, filterVar); + builder.add("\n"); + + return new ExpressionSnippet(builder.build()); + } + public VectorSearchQueryCodeBlockBuilder withFilter(StringQuery filter) { this.filter = filter; return this; } + + private ExpressionSnippet getNumCandidatesExpression(SearchType searchType, ExpressionSnippet limit) { + + MergedAnnotation annotation = context.getAnnotation(VectorSearch.class); + String numCandidates = annotation.getString("numCandidates"); + + if (StringUtils.hasText(numCandidates)) { + if (MongoCodeBlocks.containsPlaceholder(numCandidates) || MongoCodeBlocks.containsExpression(numCandidates)) { + return new ExpressionSnippet( + MongoCodeBlocks.evaluateNumberPotentially(numCandidates, Integer.class, arguments), true); + } else { + return new ExpressionSnippet(CodeBlock.of("$L", numCandidates)); + } + } + + if (searchType == VectorSearchOperation.SearchType.ANN + || searchType == VectorSearchOperation.SearchType.DEFAULT) { + + Builder builder = CodeBlock.builder(); + + if (StringUtils.hasText(context.getLimitParameterName())) { + builder.add("$L.max() * 20", context.getLimitParameterName()); + } else if (filter.isLimited()) { + builder.add("$L", filter.getLimit() * 20); + } else { + builder.add("$L * 20", limit.code()); + } + + return new ExpressionSnippet(builder.build()); + } + + return ExpressionSnippet.empty(); + } + + private ExpressionSnippet getLimitExpression() { + + if (StringUtils.hasText(context.getLimitParameterName())) { + return new ExpressionSnippet(CodeBlock.of("$L", context.getLimitParameterName())); + } + + if (filter.isLimited()) { + return new ExpressionSnippet(CodeBlock.of("$L", filter.getLimit())); + } + + MergedAnnotation annotation = context.getAnnotation(VectorSearch.class); + String limit = annotation.getString("limit"); + + if (StringUtils.hasText(limit)) { + + if (MongoCodeBlocks.containsPlaceholder(limit) || MongoCodeBlocks.containsExpression(limit)) { + return new ExpressionSnippet(MongoCodeBlocks.evaluateNumberPotentially(limit, Integer.class, arguments), + true); + } else { + return new ExpressionSnippet(CodeBlock.of("$L", limit)); + } + } + return new ExpressionSnippet(CodeBlock.of("$T.unlimited()", Limit.class)); + } } } diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/QueryMethodContributionUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/QueryMethodContributionUnitTests.java index d8de601d4..5395dba2b 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/QueryMethodContributionUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/QueryMethodContributionUnitTests.java @@ -332,11 +332,11 @@ public class QueryMethodContributionUnitTests { Range.class); assertThat(methodSpec.toString()) // - .containsSubsequence( - "Aggregation.vectorSearch(\"embedding.vector_cos\").path(\"embedding\").vector(vector).limit(", + .containsSubsequence("var limitToUse = ", "evaluate(\"#{5+5}\", argumentMap(\"lastname\", lastname, \"distance\", distance)") - .containsSubsequence("$vectorSearch.numCandidates(", - "evaluate(\"#{5+5}\", argumentMap(\"lastname\", lastname, \"distance\", distance))) * 20)"); + .contains( + "Aggregation.vectorSearch(\"embedding.vector_cos\").path(\"embedding\").vector(vector).limit(limitToUse)") + .contains("$vectorSearch.numCandidates(limitToUse * 20)"); } @Test @@ -358,7 +358,8 @@ public class QueryMethodContributionUnitTests { String.class, Vector.class, Range.class); assertThat(methodSpec.toString()) // - .containsSubsequence("AggregationOperation $sort = (_ctx) -> {", // + .containsSubsequence("var $sort = ", // + "(_ctx) -> {", // "_mappedSort = _ctx.getMappedObject(", // "Document.parse(\"{'firstname':{'$numberInt':'1'}}\")", // "Document(\"$sort\", _mappedSort.append(\"__score__\", -1))"); diff --git a/spring-data-mongodb/src/test/resources/logback.xml b/spring-data-mongodb/src/test/resources/logback.xml index d0907937f..94130f8d6 100644 --- a/spring-data-mongodb/src/test/resources/logback.xml +++ b/spring-data-mongodb/src/test/resources/logback.xml @@ -20,7 +20,9 @@ +