diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoAotRepositoryFragmentSupport.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoAotRepositoryFragmentSupport.java index 1537b6c72..df635fcd2 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoAotRepositoryFragmentSupport.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoAotRepositoryFragmentSupport.java @@ -35,7 +35,10 @@ import org.springframework.util.ClassUtils; import org.springframework.util.ObjectUtils; /** + * Support class for MongoDB AOT repository fragments. + * * @author Christoph Strobl + * @since 5.0 */ public class MongoAotRepositoryFragmentSupport { diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java index 8338ffe6b..7afa2a5f5 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java @@ -246,20 +246,24 @@ class MongoCodeBlocks { builder.add("\n"); String updateReference = updateVariableName; - builder.addStatement("$T<$T> updater = $L.update($T.class)", ExecutableUpdate.class, - context.getRepositoryInformation().getDomainType(), mongoOpsRef, - context.getRepositoryInformation().getDomainType()); + Class domainType = context.getRepositoryInformation().getDomainType(); + builder.addStatement("$T<$T> $L = $L.update($T.class)", ExecutableUpdate.class, domainType, + context.localVariable("updater"), mongoOpsRef, domainType); Class returnType = ClassUtils.resolvePrimitiveIfNecessary(queryMethod.getReturnedObjectType()); if (ReflectionUtils.isVoid(returnType)) { - builder.addStatement("updater.matching($L).apply($L).all()", queryVariableName, updateReference); + builder.addStatement("$L.matching($L).apply($L).all()", context.localVariable("updater"), queryVariableName, + updateReference); } else if (ClassUtils.isAssignable(Long.class, returnType)) { - builder.addStatement("return updater.matching($L).apply($L).all().getModifiedCount()", queryVariableName, + builder.addStatement("return $L.matching($L).apply($L).all().getModifiedCount()", + context.localVariable("updater"), queryVariableName, updateReference); } else { - builder.addStatement("$T modifiedCount = updater.matching($L).apply($L).all().getModifiedCount()", Long.class, + builder.addStatement("$T $L = $L.matching($L).apply($L).all().getModifiedCount()", Long.class, + context.localVariable("modifiedCount"), context.localVariable("updater"), queryVariableName, updateReference); - builder.addStatement("return $T.convertNumberToTargetClass(modifiedCount, $T.class)", NumberUtils.class, + builder.addStatement("return $T.convertNumberToTargetClass($L, $T.class)", NumberUtils.class, + context.localVariable("modifiedCount"), returnType); } @@ -314,24 +318,29 @@ class MongoCodeBlocks { Class returnType = ClassUtils.resolvePrimitiveIfNecessary(queryMethod.getReturnedObjectType()); - builder.addStatement("$T results = $L.aggregate($L, $T.class)", AggregationResults.class, mongoOpsRef, + builder.addStatement("$T $L = $L.aggregate($L, $T.class)", AggregationResults.class, + context.localVariable("results"), mongoOpsRef, aggregationVariableName, outputType); if (!queryMethod.isCollectionQuery()) { builder.addStatement( - "return $T.<$T>firstElement(convertSimpleRawResults($T.class, results.getMappedResults()))", - CollectionUtils.class, returnType, returnType); + "return $T.<$T>firstElement(convertSimpleRawResults($T.class, $L.getMappedResults()))", + CollectionUtils.class, returnType, returnType, context.localVariable("results")); } else { - builder.addStatement("return convertSimpleRawResults($T.class, results.getMappedResults())", returnType); + builder.addStatement("return convertSimpleRawResults($T.class, $L.getMappedResults())", returnType, + context.localVariable("results")); } } else { if (queryMethod.isSliceQuery()) { - builder.addStatement("$T results = $L.aggregate($L, $T.class)", AggregationResults.class, mongoOpsRef, + builder.addStatement("$T $L = $L.aggregate($L, $T.class)", AggregationResults.class, + context.localVariable("results"), mongoOpsRef, aggregationVariableName, outputType); - builder.addStatement("boolean hasNext = results.getMappedResults().size() > $L.getPageSize()", - context.getPageableParameterName()); + builder.addStatement("boolean $L = $L.getMappedResults().size() > $L.getPageSize()", + context.localVariable("hasNext"), context.localVariable("results"), context.getPageableParameterName()); builder.addStatement( - "return new $T<>(hasNext ? results.getMappedResults().subList(0, $L.getPageSize()) : results.getMappedResults(), $L, hasNext)", - SliceImpl.class, context.getPageableParameterName(), context.getPageableParameterName()); + "return new $T<>($L ? $L.getMappedResults().subList(0, $L.getPageSize()) : $L.getMappedResults(), $L, $L)", + SliceImpl.class, context.localVariable("hasNext"), context.localVariable("results"), + context.getPageableParameterName(), context.localVariable("results"), context.getPageableParameterName(), + context.localVariable("hasNext")); } else { builder.addStatement("return $L.aggregate($L, $T.class).getMappedResults()", mongoOpsRef, aggregationVariableName, outputType); @@ -368,18 +377,19 @@ class MongoCodeBlocks { Builder builder = CodeBlock.builder(); boolean isProjecting = context.getReturnedType().isProjecting(); + Class domainType = context.getRepositoryInformation().getDomainType(); Object actualReturnType = isProjecting ? context.getActualReturnType().getType() - : context.getRepositoryInformation().getDomainType(); + : domainType; builder.add("\n"); if (isProjecting) { - builder.addStatement("$T<$T> finder = $L.query($T.class).as($T.class)", FindWithQuery.class, actualReturnType, - mongoOpsRef, context.getRepositoryInformation().getDomainType(), actualReturnType); + builder.addStatement("$T<$T> $L = $L.query($T.class).as($T.class)", FindWithQuery.class, actualReturnType, + context.localVariable("finder"), mongoOpsRef, domainType, actualReturnType); } else { - builder.addStatement("$T<$T> finder = $L.query($T.class)", FindWithQuery.class, actualReturnType, mongoOpsRef, - context.getRepositoryInformation().getDomainType()); + builder.addStatement("$T<$T> $L = $L.query($T.class)", FindWithQuery.class, actualReturnType, + context.localVariable("finder"), mongoOpsRef, domainType); } String terminatingMethod; @@ -395,13 +405,14 @@ class MongoCodeBlocks { } if (queryMethod.isPageQuery()) { - builder.addStatement("return new $T(finder, $L).execute($L)", PagedExecution.class, + builder.addStatement("return new $T($L, $L).execute($L)", PagedExecution.class, context.localVariable("finder"), context.getPageableParameterName(), query.name()); } else if (queryMethod.isSliceQuery()) { - builder.addStatement("return new $T(finder, $L).execute($L)", SlicedExecution.class, - context.getPageableParameterName(), query.name()); + builder.addStatement("return new $T($L, $L).execute($L)", SlicedExecution.class, + context.localVariable("finder"), context.getPageableParameterName(), query.name()); } else { - builder.addStatement("return finder.matching($L).$L", query.name(), terminatingMethod); + builder.addStatement("return $L.matching($L).$L", context.localVariable("finder"), query.name(), + terminatingMethod); } return builder.build(); @@ -415,7 +426,7 @@ class MongoCodeBlocks { private final MongoQueryMethod queryMethod; private AggregationInteraction source; - private List arguments; + private final List arguments; private String aggregationVariableName; private boolean pipelineOnly; @@ -449,7 +460,7 @@ class MongoCodeBlocks { CodeBlock.Builder builder = CodeBlock.builder(); builder.add("\n"); - String pipelineName = aggregationVariableName + (pipelineOnly ? "" : "Pipeline"); + String pipelineName = context.localVariable(aggregationVariableName + (pipelineOnly ? "" : "Pipeline")); builder.add(pipeline(pipelineName)); if (!pipelineOnly) { @@ -486,8 +497,7 @@ class MongoCodeBlocks { } Builder builder = CodeBlock.builder(); - String stagesVariableName = "stages"; - builder.add(aggregationStages(stagesVariableName, source.stages(), stageCount, arguments)); + builder.add(aggregationStages(context.localVariable("stages"), source.stages(), stageCount, arguments)); if (mightBeSorted) { builder.add(sortingStage(sortParameter)); @@ -502,7 +512,7 @@ class MongoCodeBlocks { } builder.addStatement("$T $L = createPipeline($L)", AggregationPipeline.class, pipelineVariableName, - stagesVariableName); + context.localVariable("stages")); return builder.build(); } @@ -533,7 +543,8 @@ class MongoCodeBlocks { if (!options.isEmpty()) { Builder optionsBuilder = CodeBlock.builder(); - optionsBuilder.add("$T aggregationOptions = $T.builder()\n", AggregationOptions.class, + optionsBuilder.add("$T $L = $T.builder()\n", AggregationOptions.class, + context.localVariable("aggregationOptions"), AggregationOptions.class); optionsBuilder.indent(); for (CodeBlock optionBlock : options) { @@ -544,67 +555,81 @@ class MongoCodeBlocks { optionsBuilder.unindent(); builder.add(optionsBuilder.build()); - builder.addStatement("$L = $L.withOptions(aggregationOptions)", aggregationVariableName, - aggregationVariableName); + builder.addStatement("$L = $L.withOptions($L)", aggregationVariableName, aggregationVariableName, + context.localVariable("aggregationOptions")); } return builder.build(); } - private static CodeBlock aggregationStages(String stageListVariableName, Iterable stages, int stageCount, + private CodeBlock aggregationStages(String stageListVariableName, Iterable stages, int stageCount, List arguments) { Builder builder = CodeBlock.builder(); builder.addStatement("$T<$T> $L = new $T($L)", List.class, Object.class, stageListVariableName, ArrayList.class, stageCount); int stageCounter = 0; + for (String stage : stages) { - String stageName = "stage_%s".formatted(stageCounter++); + String stageName = context.localVariable("stage_%s".formatted(stageCounter++)); builder.add(renderExpressionToDocument(stage, stageName, arguments)); - builder.addStatement("stages.add($L)", stageName); + builder.addStatement("$L.add($L)", context.localVariable("stages"), stageName); } + return builder.build(); } - private static CodeBlock sortingStage(String sortProvider) { + private CodeBlock sortingStage(String sortProvider) { Builder builder = CodeBlock.builder(); - builder.beginControlFlow("if($L.isSorted())", sortProvider); - builder.addStatement("$T sortDocument = new $T()", Document.class, Document.class); - builder.beginControlFlow("for ($T order : $L)", Order.class, sortProvider); - builder.addStatement("sortDocument.append(order.getProperty(), order.isAscending() ? 1 : -1);"); + + builder.beginControlFlow("if ($L.isSorted())", sortProvider); + builder.addStatement("$T $L = new $T()", Document.class, context.localVariable("sortDocument"), Document.class); + builder.beginControlFlow("for ($T $L : $L)", Order.class, context.localVariable("order"), sortProvider); + builder.addStatement("$L.append($L.getProperty(), $L.isAscending() ? 1 : -1);", + context.localVariable("sortDocument"), context.localVariable("order"), context.localVariable("order")); builder.endControlFlow(); - builder.addStatement("stages.add(new $T($S, sortDocument))", Document.class, "$sort"); + builder.addStatement("stages.add(new $T($S, $L))", Document.class, "$sort", + context.localVariable("sortDocument")); builder.endControlFlow(); + return builder.build(); } - private static CodeBlock pagingStage(String pageableProvider, boolean slice) { + private CodeBlock pagingStage(String pageableProvider, boolean slice) { Builder builder = CodeBlock.builder(); + builder.add(sortingStage(pageableProvider + ".getSort()")); - builder.beginControlFlow("if($L.isPaged())", pageableProvider); - builder.beginControlFlow("if($L.getOffset() > 0)", pageableProvider); - builder.addStatement("stages.add($T.skip($L.getOffset()))", Aggregation.class, pageableProvider); + builder.beginControlFlow("if ($L.isPaged())", pageableProvider); + builder.beginControlFlow("if ($L.getOffset() > 0)", pageableProvider); + builder.addStatement("$L.add($T.skip($L.getOffset()))", context.localVariable("stages"), Aggregation.class, + pageableProvider); builder.endControlFlow(); if (slice) { - builder.addStatement("stages.add($T.limit($L.getPageSize() + 1))", Aggregation.class, pageableProvider); + builder.addStatement("$L.add($T.limit($L.getPageSize() + 1))", context.localVariable("stages"), + Aggregation.class, pageableProvider); } else { - builder.addStatement("stages.add($T.limit($L.getPageSize()))", Aggregation.class, pageableProvider); + builder.addStatement("$L.add($T.limit($L.getPageSize()))", context.localVariable("stages"), Aggregation.class, + pageableProvider); } builder.endControlFlow(); return builder.build(); } - private static CodeBlock limitingStage(String limitProvider) { + private CodeBlock limitingStage(String limitProvider) { Builder builder = CodeBlock.builder(); - builder.beginControlFlow("if($L.isLimited())", limitProvider); - builder.addStatement("stages.add($T.limit($L.max()))", Aggregation.class, limitProvider); + + builder.beginControlFlow("if ($L.isLimited())", limitProvider); + builder.addStatement("$L.add($T.limit($L.max()))", context.localVariable("stages"), Aggregation.class, + limitProvider); builder.endControlFlow(); + return builder.build(); } + } @NullUnmarked @@ -614,7 +639,7 @@ class MongoCodeBlocks { private final MongoQueryMethod queryMethod; private QueryInteraction source; - private List arguments; + private final List arguments; private String queryVariableName; QueryCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) { @@ -697,17 +722,10 @@ class MongoCodeBlocks { builder.addStatement("$T $L = new $T(new $T())", BasicQuery.class, variableName, BasicQuery.class, Document.class); } else if (!containsPlaceholder(source)) { - - String tmpVarName = "%sString".formatted(variableName); - builder.addStatement("String $L = $S", tmpVarName, source); - - builder.addStatement("$T $L = new $T($T.parse($L))", BasicQuery.class, variableName, BasicQuery.class, - Document.class, tmpVarName); + builder.addStatement("$T $L = new $T($T.parse($S))", BasicQuery.class, variableName, BasicQuery.class, + Document.class, source); } else { - - String tmpVarName = "%sString".formatted(variableName); - builder.addStatement("String $L = $S", tmpVarName, source); - builder.addStatement("$T $L = createQuery($L, new $T[]{ $L })", BasicQuery.class, variableName, tmpVarName, + builder.addStatement("$T $L = createQuery($S, new $T[]{ $L })", BasicQuery.class, variableName, source, Object.class, StringUtils.collectionToDelimitedString(arguments, ", ")); } @@ -757,15 +775,9 @@ class MongoCodeBlocks { if (!StringUtils.hasText(source)) { builder.addStatement("$T $L = new $T()", Document.class, variableName, Document.class); } else if (!containsPlaceholder(source)) { - - String tmpVarName = "%sString".formatted(variableName); - builder.addStatement("String $L = $S", tmpVarName, source); - builder.addStatement("$T $L = $T.parse($L)", Document.class, variableName, Document.class, tmpVarName); + builder.addStatement("$T $L = $T.parse($S)", Document.class, variableName, Document.class, source); } else { - - String tmpVarName = "%sString".formatted(variableName); - builder.addStatement("String $L = $S", tmpVarName, source); - builder.addStatement("$T $L = bindParameters($L, new $T[]{ $L })", Document.class, variableName, tmpVarName, + builder.addStatement("$T $L = bindParameters($S, new $T[]{ $L })", Document.class, variableName, source, Object.class, StringUtils.collectionToDelimitedString(arguments, ", ")); } return builder.build(); diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java index def03c797..3ef1f5ac9 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java @@ -15,13 +15,7 @@ */ package org.springframework.data.mongodb.repository.aot; -import static org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.aggregationBlockBuilder; -import static org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.aggregationExecutionBlockBuilder; -import static org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.deleteExecutionBlockBuilder; -import static org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.queryBlockBuilder; -import static org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.queryExecutionBlockBuilder; -import static org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.updateBlockBuilder; -import static org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.updateExecutionBlockBuilder; +import static org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.*; import java.lang.reflect.Method; import java.util.regex.Pattern; @@ -29,16 +23,16 @@ import java.util.regex.Pattern; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.jspecify.annotations.Nullable; + import org.springframework.core.annotation.AnnotatedElementUtils; import org.springframework.data.mongodb.core.MongoOperations; import org.springframework.data.mongodb.core.aggregation.AggregationUpdate; import org.springframework.data.mongodb.core.mapping.MongoMappingContext; import org.springframework.data.mongodb.repository.Query; import org.springframework.data.mongodb.repository.Update; -import org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.QueryCodeBlockBuilder; import org.springframework.data.mongodb.repository.query.MongoQueryMethod; +import org.springframework.data.repository.aot.generate.AotRepositoryClassBuilder; import org.springframework.data.repository.aot.generate.AotRepositoryConstructorBuilder; -import org.springframework.data.repository.aot.generate.AotRepositoryFragmentMetadata; import org.springframework.data.repository.aot.generate.MethodContributor; import org.springframework.data.repository.aot.generate.RepositoryContributor; import org.springframework.data.repository.config.AotRepositoryContext; @@ -48,7 +42,6 @@ import org.springframework.data.repository.query.QueryMethod; import org.springframework.data.repository.query.parser.PartTree; import org.springframework.javapoet.CodeBlock; import org.springframework.javapoet.TypeName; -import org.springframework.javapoet.TypeSpec.Builder; import org.springframework.util.ObjectUtils; import org.springframework.util.StringUtils; @@ -60,7 +53,7 @@ import org.springframework.util.StringUtils; */ public class MongoRepositoryContributor extends RepositoryContributor { - private static final Log logger = LogFactory.getLog(RepositoryContributor.class); + private static final Log logger = LogFactory.getLog(MongoRepositoryContributor.class); private final AotQueryCreator queryCreator; private final MongoMappingContext mappingContext; @@ -73,9 +66,8 @@ public class MongoRepositoryContributor extends RepositoryContributor { } @Override - protected void customizeClass(RepositoryInformation information, AotRepositoryFragmentMetadata metadata, - Builder builder) { - builder.superclass(TypeName.get(MongoAotRepositoryFragmentSupport.class)); + protected void customizeClass(AotRepositoryClassBuilder classBuilder) { + classBuilder.customize(builder -> builder.superclass(TypeName.get(MongoAotRepositoryFragmentSupport.class))); } @Override @@ -85,70 +77,60 @@ public class MongoRepositoryContributor extends RepositoryContributor { constructorBuilder.addParameter("context", TypeName.get(RepositoryFactoryBeanSupport.FragmentCreationContext.class), false); - constructorBuilder.customize((repositoryInformation, builder) -> { + constructorBuilder.customize((builder) -> { builder.addStatement("super(operations, context)"); }); } @Override @SuppressWarnings("NullAway") - protected @Nullable MethodContributor contributeQueryMethod(Method method, - RepositoryInformation repositoryInformation) { + protected @Nullable MethodContributor contributeQueryMethod(Method method) { - MongoQueryMethod queryMethod = new MongoQueryMethod(method, repositoryInformation, getProjectionFactory(), + MongoQueryMethod queryMethod = new MongoQueryMethod(method, getRepositoryInformation(), getProjectionFactory(), mappingContext); - if (backoff(queryMethod)) { - return null; + if (queryMethod.hasAnnotatedAggregation()) { + AggregationInteraction aggregation = new AggregationInteraction(queryMethod.getAnnotatedAggregation()); + return aggregationMethodContributor(queryMethod, aggregation); } - try { - if (queryMethod.hasAnnotatedAggregation()) { - - AggregationInteraction aggregation = new AggregationInteraction(queryMethod.getAnnotatedAggregation()); - return aggregationMethodContributor(queryMethod, aggregation); - } - - QueryInteraction query = createStringQuery(repositoryInformation, queryMethod, - AnnotatedElementUtils.findMergedAnnotation(method, Query.class), method.getParameterCount()); + QueryInteraction query = createStringQuery(getRepositoryInformation(), queryMethod, + AnnotatedElementUtils.findMergedAnnotation(method, Query.class), method.getParameterCount()); - if (queryMethod.hasAnnotatedQuery()) { - if (StringUtils.hasText(queryMethod.getAnnotatedQuery()) - && Pattern.compile("[\\?:][#$]\\{.*\\}").matcher(queryMethod.getAnnotatedQuery()).find()) { + if (queryMethod.hasAnnotatedQuery()) { + if (StringUtils.hasText(queryMethod.getAnnotatedQuery()) + && Pattern.compile("[\\?:][#$]\\{.*\\}").matcher(queryMethod.getAnnotatedQuery()).find()) { - if (logger.isDebugEnabled()) { - logger.debug( - "Skipping AOT generation for [%s]. SpEL expressions are not supported".formatted(method.getName())); - } - return MethodContributor.forQueryMethod(queryMethod).metadataOnly(query); + if (logger.isDebugEnabled()) { + logger.debug( + "Skipping AOT generation for [%s]. SpEL expressions are not supported".formatted(method.getName())); } + return MethodContributor.forQueryMethod(queryMethod).metadataOnly(query); } + } - if (query.isDelete()) { - return deleteMethodContributor(queryMethod, query); - } + if (backoff(queryMethod)) { + return null; + } - if (queryMethod.isModifyingQuery()) { + if (query.isDelete()) { + return deleteMethodContributor(queryMethod, query); + } - Update updateSource = queryMethod.getUpdateSource(); - if (StringUtils.hasText(updateSource.value())) { - UpdateInteraction update = new UpdateInteraction(query, new StringUpdate(updateSource.value())); - return updateMethodContributor(queryMethod, update); - } - if (!ObjectUtils.isEmpty(updateSource.pipeline())) { - AggregationUpdateInteraction update = new AggregationUpdateInteraction(query, updateSource.pipeline()); - return aggregationUpdateMethodContributor(queryMethod, update); - } - } + if (queryMethod.isModifyingQuery()) { - return queryMethodContributor(queryMethod, query); - } catch (RuntimeException codeGenerationError) { - if (logger.isErrorEnabled()) { - logger.error("Failed to generate code for [%s] [%s]".formatted(repositoryInformation.getRepositoryInterface(), - method.getName()), codeGenerationError); + Update updateSource = queryMethod.getUpdateSource(); + if (StringUtils.hasText(updateSource.value())) { + UpdateInteraction update = new UpdateInteraction(query, new StringUpdate(updateSource.value())); + return updateMethodContributor(queryMethod, update); + } + if (!ObjectUtils.isEmpty(updateSource.pipeline())) { + AggregationUpdateInteraction update = new AggregationUpdateInteraction(query, updateSource.pipeline()); + return aggregationUpdateMethodContributor(queryMethod, update); } } - return null; + + return queryMethodContributor(queryMethod, query); } @SuppressWarnings("NullAway") @@ -193,7 +175,6 @@ public class MongoRepositoryContributor extends RepositoryContributor { return MethodContributor.forQueryMethod(queryMethod).withMetadata(aggregation).contribute(context -> { CodeBlock.Builder builder = CodeBlock.builder(); - builder.add(context.codeBlocks().logDebug("invoking [%s]".formatted(context.getMethod().getName()))); builder.add(aggregationBlockBuilder(context, queryMethod).stages(aggregation) .usingAggregationVariableName("aggregation").build()); @@ -209,15 +190,14 @@ public class MongoRepositoryContributor extends RepositoryContributor { return MethodContributor.forQueryMethod(queryMethod).withMetadata(update).contribute(context -> { CodeBlock.Builder builder = CodeBlock.builder(); - builder.add(context.codeBlocks().logDebug("invoking [%s]".formatted(context.getMethod().getName()))); // update filter - String filterVariableName = update.name(); + String filterVariableName = context.localVariable(update.name()); builder.add(queryBlockBuilder(context, queryMethod).filter(update.getFilter()) .usingQueryVariableName(filterVariableName).build()); // update definition - String updateVariableName = "updateDefinition"; + String updateVariableName = context.localVariable("updateDefinition"); builder.add( updateBlockBuilder(context, queryMethod).update(update).usingUpdateVariableName(updateVariableName).build()); @@ -233,10 +213,9 @@ public class MongoRepositoryContributor extends RepositoryContributor { return MethodContributor.forQueryMethod(queryMethod).withMetadata(update).contribute(context -> { CodeBlock.Builder builder = CodeBlock.builder(); - builder.add(context.codeBlocks().logDebug("invoking [%s]".formatted(context.getMethod().getName()))); // update filter - String filterVariableName = update.name(); + String filterVariableName = context.localVariable(update.name()); QueryCodeBlockBuilder queryCodeBlockBuilder = queryBlockBuilder(context, queryMethod).filter(update.getFilter()); builder.add(queryCodeBlockBuilder.usingQueryVariableName(filterVariableName).build()); @@ -245,11 +224,12 @@ public class MongoRepositoryContributor extends RepositoryContributor { builder.add(aggregationBlockBuilder(context, queryMethod).stages(update) .usingAggregationVariableName(updateVariableName).pipelineOnly(true).build()); - builder.addStatement("$T aggregationUpdate = $T.from($L.getOperations())", AggregationUpdate.class, + builder.addStatement("$T $L = $T.from($L.getOperations())", AggregationUpdate.class, + context.localVariable("aggregationUpdate"), AggregationUpdate.class, updateVariableName); builder.add(updateExecutionBlockBuilder(context, queryMethod).withFilter(filterVariableName) - .referencingUpdate("aggregationUpdate").build()); + .referencingUpdate(context.localVariable("aggregationUpdate")).build()); return builder.build(); }); } @@ -260,11 +240,11 @@ public class MongoRepositoryContributor extends RepositoryContributor { return MethodContributor.forQueryMethod(queryMethod).withMetadata(query).contribute(context -> { CodeBlock.Builder builder = CodeBlock.builder(); - builder.add(context.codeBlocks().logDebug("invoking [%s]".formatted(context.getMethod().getName()))); QueryCodeBlockBuilder queryCodeBlockBuilder = queryBlockBuilder(context, queryMethod).filter(query); - builder.add(queryCodeBlockBuilder.usingQueryVariableName(query.name()).build()); - builder.add(deleteExecutionBlockBuilder(context, queryMethod).referencing(query.name()).build()); + String queryVariableName = context.localVariable(query.name()); + builder.add(queryCodeBlockBuilder.usingQueryVariableName(queryVariableName).build()); + builder.add(deleteExecutionBlockBuilder(context, queryMethod).referencing(queryVariableName).build()); return builder.build(); }); } @@ -275,12 +255,12 @@ public class MongoRepositoryContributor extends RepositoryContributor { return MethodContributor.forQueryMethod(queryMethod).withMetadata(query).contribute(context -> { CodeBlock.Builder builder = CodeBlock.builder(); - builder.add(context.codeBlocks().logDebug("invoking [%s]".formatted(context.getMethod().getName()))); QueryCodeBlockBuilder queryCodeBlockBuilder = queryBlockBuilder(context, queryMethod).filter(query); - builder.add(queryCodeBlockBuilder.usingQueryVariableName(query.name()).build()); + builder.add(queryCodeBlockBuilder.usingQueryVariableName(context.localVariable(query.name())).build()); builder.add(queryExecutionBlockBuilder(context, queryMethod).forQuery(query).build()); return builder.build(); }); } + } diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java index 94b58200b..d5c388751 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java @@ -51,11 +51,14 @@ import org.springframework.util.StringUtils; import com.mongodb.client.MongoClient; /** + * Integration tests for the {@link UserRepository} AOT fragment. + * * @author Christoph Strobl + * @author Mark Paluch */ @ExtendWith(MongoClientExtension.class) @SpringJUnitConfig(classes = MongoRepositoryContributorTests.MongoRepositoryContributorConfiguration.class) -public class MongoRepositoryContributorTests { +class MongoRepositoryContributorTests { private static final String DB_NAME = "aot-repo-tests"; diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryMetadataTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryMetadataTests.java index 67017720e..aa069a271 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryMetadataTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryMetadataTests.java @@ -15,9 +15,9 @@ */ package org.springframework.data.mongodb.repository.aot; -import static net.javacrumbs.jsonunit.assertj.JsonAssertions.assertThatJson; -import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.Mockito.mock; +import static net.javacrumbs.jsonunit.assertj.JsonAssertions.*; +import static org.assertj.core.api.Assertions.*; +import static org.mockito.Mockito.*; import example.aot.UserRepository; @@ -27,6 +27,7 @@ import java.nio.charset.StandardCharsets; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; @@ -176,7 +177,7 @@ class MongoRepositoryMetadataTests { assertThat(resource.exists()).isTrue(); String json = resource.getContentAsString(StandardCharsets.UTF_8); - System.out.println(json); + assertThatJson(json).inPath("$.methods[?(@.name == 'existsById')].fragment").isArray().first().isObject() .containsEntry("fragment", "org.springframework.data.mongodb.repository.support.SimpleMongoRepository"); } diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/TestMongoAotRepositoryContext.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/TestMongoAotRepositoryContext.java index 7cc47d956..8100a67a6 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/TestMongoAotRepositoryContext.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/TestMongoAotRepositoryContext.java @@ -38,81 +38,81 @@ import org.springframework.data.repository.core.support.RepositoryComposition; */ class TestMongoAotRepositoryContext implements AotRepositoryContext { - private final StubRepositoryInformation repositoryInformation; - private final Environment environment = new StandardEnvironment(); - - TestMongoAotRepositoryContext(Class repositoryInterface, @Nullable RepositoryComposition composition) { - this.repositoryInformation = new StubRepositoryInformation(repositoryInterface, composition); - } - - @Override - public ConfigurableListableBeanFactory getBeanFactory() { - return null; - } - - @Override - public TypeIntrospector introspectType(String typeName) { - return null; - } - - @Override - public IntrospectedBeanDefinition introspectBeanDefinition(String beanName) { - return null; - } - - @Override - public String getBeanName() { - return "dummyRepository"; - } - - @Override - public String getModuleName() { - return "MongoDB"; + private final StubRepositoryInformation repositoryInformation; + private final Environment environment = new StandardEnvironment(); + + TestMongoAotRepositoryContext(Class repositoryInterface, @Nullable RepositoryComposition composition) { + this.repositoryInformation = new StubRepositoryInformation(repositoryInterface, composition); + } + + @Override + public ConfigurableListableBeanFactory getBeanFactory() { + return null; + } + + @Override + public TypeIntrospector introspectType(String typeName) { + return null; + } + + @Override + public IntrospectedBeanDefinition introspectBeanDefinition(String beanName) { + return null; + } + + @Override + public String getBeanName() { + return "dummyRepository"; + } + + @Override + public String getModuleName() { + return "MongoDB"; + } + + @Override + public Set getBasePackages() { + return Set.of("org.springframework.data.dummy.repository.aot"); + } + + @Override + public Set> getIdentifyingAnnotations() { + return Set.of(Document.class); + } + + @Override + public RepositoryInformation getRepositoryInformation() { + return repositoryInformation; + } + + @Override + public Set> getResolvedAnnotations() { + return Set.of(); + } + + @Override + public Set> getResolvedTypes() { + return Set.of(); + } + + public List getRequiredContextFiles() { + return List.of(classFileForType(repositoryInformation.getRepositoryBaseClass())); + } + + static ClassFile classFileForType(Class type) { + + String name = type.getName(); + ClassPathResource cpr = new ClassPathResource(name.replaceAll("\\.", "/") + ".class"); + + try { + return ClassFile.of(name, cpr.getContentAsByteArray()); + } catch (IOException e) { + throw new IllegalArgumentException("Cannot open [%s].".formatted(cpr.getPath())); } + } - @Override - public Set getBasePackages() { - return Set.of("org.springframework.data.dummy.repository.aot"); - } - - @Override - public Set> getIdentifyingAnnotations() { - return Set.of(Document.class); - } - - @Override - public RepositoryInformation getRepositoryInformation() { - return repositoryInformation; - } - - @Override - public Set> getResolvedAnnotations() { - return Set.of(); - } - - @Override - public Set> getResolvedTypes() { - return Set.of(); - } - - public List getRequiredContextFiles() { - return List.of(classFileForType(repositoryInformation.getRepositoryBaseClass())); - } - - static ClassFile classFileForType(Class type) { - - String name = type.getName(); - ClassPathResource cpr = new ClassPathResource(name.replaceAll("\\.", "/") + ".class"); - - try { - return ClassFile.of(name, cpr.getContentAsByteArray()); - } catch (IOException e) { - throw new IllegalArgumentException("Cannot open [%s].".formatted(cpr.getPath())); - } - } - - @Override - public Environment getEnvironment() { - return environment; - } + @Override + public Environment getEnvironment() { + return environment; + } }