Browse Source

Use `LocalVariableNameFactory` to avoid parameter name clashes.

Closes #4965
pull/4976/head
Mark Paluch 7 months ago
parent
commit
d850b9ef06
No known key found for this signature in database
GPG Key ID: 55BC6374BAA9D973
  1. 3
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoAotRepositoryFragmentSupport.java
  2. 152
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java
  3. 120
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java
  4. 5
      spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java
  5. 9
      spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryMetadataTests.java
  6. 150
      spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/TestMongoAotRepositoryContext.java

3
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoAotRepositoryFragmentSupport.java

@ -35,7 +35,10 @@ import org.springframework.util.ClassUtils;
import org.springframework.util.ObjectUtils; import org.springframework.util.ObjectUtils;
/** /**
* Support class for MongoDB AOT repository fragments.
*
* @author Christoph Strobl * @author Christoph Strobl
* @since 5.0
*/ */
public class MongoAotRepositoryFragmentSupport { public class MongoAotRepositoryFragmentSupport {

152
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java

@ -246,20 +246,24 @@ class MongoCodeBlocks {
builder.add("\n"); builder.add("\n");
String updateReference = updateVariableName; String updateReference = updateVariableName;
builder.addStatement("$T<$T> updater = $L.update($T.class)", ExecutableUpdate.class, Class<?> domainType = context.getRepositoryInformation().getDomainType();
context.getRepositoryInformation().getDomainType(), mongoOpsRef, builder.addStatement("$T<$T> $L = $L.update($T.class)", ExecutableUpdate.class, domainType,
context.getRepositoryInformation().getDomainType()); context.localVariable("updater"), mongoOpsRef, domainType);
Class<?> returnType = ClassUtils.resolvePrimitiveIfNecessary(queryMethod.getReturnedObjectType()); Class<?> returnType = ClassUtils.resolvePrimitiveIfNecessary(queryMethod.getReturnedObjectType());
if (ReflectionUtils.isVoid(returnType)) { if (ReflectionUtils.isVoid(returnType)) {
builder.addStatement("updater.matching($L).apply($L).all()", queryVariableName, updateReference); builder.addStatement("$L.matching($L).apply($L).all()", context.localVariable("updater"), queryVariableName,
updateReference);
} else if (ClassUtils.isAssignable(Long.class, returnType)) { } else if (ClassUtils.isAssignable(Long.class, returnType)) {
builder.addStatement("return updater.matching($L).apply($L).all().getModifiedCount()", queryVariableName, builder.addStatement("return $L.matching($L).apply($L).all().getModifiedCount()",
context.localVariable("updater"), queryVariableName,
updateReference); updateReference);
} else { } else {
builder.addStatement("$T modifiedCount = updater.matching($L).apply($L).all().getModifiedCount()", Long.class, builder.addStatement("$T $L = $L.matching($L).apply($L).all().getModifiedCount()", Long.class,
context.localVariable("modifiedCount"), context.localVariable("updater"),
queryVariableName, updateReference); queryVariableName, updateReference);
builder.addStatement("return $T.convertNumberToTargetClass(modifiedCount, $T.class)", NumberUtils.class, builder.addStatement("return $T.convertNumberToTargetClass($L, $T.class)", NumberUtils.class,
context.localVariable("modifiedCount"),
returnType); returnType);
} }
@ -314,24 +318,29 @@ class MongoCodeBlocks {
Class<?> returnType = ClassUtils.resolvePrimitiveIfNecessary(queryMethod.getReturnedObjectType()); Class<?> returnType = ClassUtils.resolvePrimitiveIfNecessary(queryMethod.getReturnedObjectType());
builder.addStatement("$T results = $L.aggregate($L, $T.class)", AggregationResults.class, mongoOpsRef, builder.addStatement("$T $L = $L.aggregate($L, $T.class)", AggregationResults.class,
context.localVariable("results"), mongoOpsRef,
aggregationVariableName, outputType); aggregationVariableName, outputType);
if (!queryMethod.isCollectionQuery()) { if (!queryMethod.isCollectionQuery()) {
builder.addStatement( builder.addStatement(
"return $T.<$T>firstElement(convertSimpleRawResults($T.class, results.getMappedResults()))", "return $T.<$T>firstElement(convertSimpleRawResults($T.class, $L.getMappedResults()))",
CollectionUtils.class, returnType, returnType); CollectionUtils.class, returnType, returnType, context.localVariable("results"));
} else { } else {
builder.addStatement("return convertSimpleRawResults($T.class, results.getMappedResults())", returnType); builder.addStatement("return convertSimpleRawResults($T.class, $L.getMappedResults())", returnType,
context.localVariable("results"));
} }
} else { } else {
if (queryMethod.isSliceQuery()) { if (queryMethod.isSliceQuery()) {
builder.addStatement("$T results = $L.aggregate($L, $T.class)", AggregationResults.class, mongoOpsRef, builder.addStatement("$T $L = $L.aggregate($L, $T.class)", AggregationResults.class,
context.localVariable("results"), mongoOpsRef,
aggregationVariableName, outputType); aggregationVariableName, outputType);
builder.addStatement("boolean hasNext = results.getMappedResults().size() > $L.getPageSize()", builder.addStatement("boolean $L = $L.getMappedResults().size() > $L.getPageSize()",
context.getPageableParameterName()); context.localVariable("hasNext"), context.localVariable("results"), context.getPageableParameterName());
builder.addStatement( builder.addStatement(
"return new $T<>(hasNext ? results.getMappedResults().subList(0, $L.getPageSize()) : results.getMappedResults(), $L, hasNext)", "return new $T<>($L ? $L.getMappedResults().subList(0, $L.getPageSize()) : $L.getMappedResults(), $L, $L)",
SliceImpl.class, context.getPageableParameterName(), context.getPageableParameterName()); SliceImpl.class, context.localVariable("hasNext"), context.localVariable("results"),
context.getPageableParameterName(), context.localVariable("results"), context.getPageableParameterName(),
context.localVariable("hasNext"));
} else { } else {
builder.addStatement("return $L.aggregate($L, $T.class).getMappedResults()", mongoOpsRef, builder.addStatement("return $L.aggregate($L, $T.class).getMappedResults()", mongoOpsRef,
aggregationVariableName, outputType); aggregationVariableName, outputType);
@ -368,18 +377,19 @@ class MongoCodeBlocks {
Builder builder = CodeBlock.builder(); Builder builder = CodeBlock.builder();
boolean isProjecting = context.getReturnedType().isProjecting(); boolean isProjecting = context.getReturnedType().isProjecting();
Class<?> domainType = context.getRepositoryInformation().getDomainType();
Object actualReturnType = isProjecting ? context.getActualReturnType().getType() Object actualReturnType = isProjecting ? context.getActualReturnType().getType()
: context.getRepositoryInformation().getDomainType(); : domainType;
builder.add("\n"); builder.add("\n");
if (isProjecting) { if (isProjecting) {
builder.addStatement("$T<$T> finder = $L.query($T.class).as($T.class)", FindWithQuery.class, actualReturnType, builder.addStatement("$T<$T> $L = $L.query($T.class).as($T.class)", FindWithQuery.class, actualReturnType,
mongoOpsRef, context.getRepositoryInformation().getDomainType(), actualReturnType); context.localVariable("finder"), mongoOpsRef, domainType, actualReturnType);
} else { } else {
builder.addStatement("$T<$T> finder = $L.query($T.class)", FindWithQuery.class, actualReturnType, mongoOpsRef, builder.addStatement("$T<$T> $L = $L.query($T.class)", FindWithQuery.class, actualReturnType,
context.getRepositoryInformation().getDomainType()); context.localVariable("finder"), mongoOpsRef, domainType);
} }
String terminatingMethod; String terminatingMethod;
@ -395,13 +405,14 @@ class MongoCodeBlocks {
} }
if (queryMethod.isPageQuery()) { if (queryMethod.isPageQuery()) {
builder.addStatement("return new $T(finder, $L).execute($L)", PagedExecution.class, builder.addStatement("return new $T($L, $L).execute($L)", PagedExecution.class, context.localVariable("finder"),
context.getPageableParameterName(), query.name()); context.getPageableParameterName(), query.name());
} else if (queryMethod.isSliceQuery()) { } else if (queryMethod.isSliceQuery()) {
builder.addStatement("return new $T(finder, $L).execute($L)", SlicedExecution.class, builder.addStatement("return new $T($L, $L).execute($L)", SlicedExecution.class,
context.getPageableParameterName(), query.name()); context.localVariable("finder"), context.getPageableParameterName(), query.name());
} else { } else {
builder.addStatement("return finder.matching($L).$L", query.name(), terminatingMethod); builder.addStatement("return $L.matching($L).$L", context.localVariable("finder"), query.name(),
terminatingMethod);
} }
return builder.build(); return builder.build();
@ -415,7 +426,7 @@ class MongoCodeBlocks {
private final MongoQueryMethod queryMethod; private final MongoQueryMethod queryMethod;
private AggregationInteraction source; private AggregationInteraction source;
private List<String> arguments; private final List<String> arguments;
private String aggregationVariableName; private String aggregationVariableName;
private boolean pipelineOnly; private boolean pipelineOnly;
@ -449,7 +460,7 @@ class MongoCodeBlocks {
CodeBlock.Builder builder = CodeBlock.builder(); CodeBlock.Builder builder = CodeBlock.builder();
builder.add("\n"); builder.add("\n");
String pipelineName = aggregationVariableName + (pipelineOnly ? "" : "Pipeline"); String pipelineName = context.localVariable(aggregationVariableName + (pipelineOnly ? "" : "Pipeline"));
builder.add(pipeline(pipelineName)); builder.add(pipeline(pipelineName));
if (!pipelineOnly) { if (!pipelineOnly) {
@ -486,8 +497,7 @@ class MongoCodeBlocks {
} }
Builder builder = CodeBlock.builder(); Builder builder = CodeBlock.builder();
String stagesVariableName = "stages"; builder.add(aggregationStages(context.localVariable("stages"), source.stages(), stageCount, arguments));
builder.add(aggregationStages(stagesVariableName, source.stages(), stageCount, arguments));
if (mightBeSorted) { if (mightBeSorted) {
builder.add(sortingStage(sortParameter)); builder.add(sortingStage(sortParameter));
@ -502,7 +512,7 @@ class MongoCodeBlocks {
} }
builder.addStatement("$T $L = createPipeline($L)", AggregationPipeline.class, pipelineVariableName, builder.addStatement("$T $L = createPipeline($L)", AggregationPipeline.class, pipelineVariableName,
stagesVariableName); context.localVariable("stages"));
return builder.build(); return builder.build();
} }
@ -533,7 +543,8 @@ class MongoCodeBlocks {
if (!options.isEmpty()) { if (!options.isEmpty()) {
Builder optionsBuilder = CodeBlock.builder(); Builder optionsBuilder = CodeBlock.builder();
optionsBuilder.add("$T aggregationOptions = $T.builder()\n", AggregationOptions.class, optionsBuilder.add("$T $L = $T.builder()\n", AggregationOptions.class,
context.localVariable("aggregationOptions"),
AggregationOptions.class); AggregationOptions.class);
optionsBuilder.indent(); optionsBuilder.indent();
for (CodeBlock optionBlock : options) { for (CodeBlock optionBlock : options) {
@ -544,67 +555,81 @@ class MongoCodeBlocks {
optionsBuilder.unindent(); optionsBuilder.unindent();
builder.add(optionsBuilder.build()); builder.add(optionsBuilder.build());
builder.addStatement("$L = $L.withOptions(aggregationOptions)", aggregationVariableName, builder.addStatement("$L = $L.withOptions($L)", aggregationVariableName, aggregationVariableName,
aggregationVariableName); context.localVariable("aggregationOptions"));
} }
return builder.build(); return builder.build();
} }
private static CodeBlock aggregationStages(String stageListVariableName, Iterable<String> stages, int stageCount, private CodeBlock aggregationStages(String stageListVariableName, Iterable<String> stages, int stageCount,
List<String> arguments) { List<String> arguments) {
Builder builder = CodeBlock.builder(); Builder builder = CodeBlock.builder();
builder.addStatement("$T<$T> $L = new $T($L)", List.class, Object.class, stageListVariableName, ArrayList.class, builder.addStatement("$T<$T> $L = new $T($L)", List.class, Object.class, stageListVariableName, ArrayList.class,
stageCount); stageCount);
int stageCounter = 0; int stageCounter = 0;
for (String stage : stages) { for (String stage : stages) {
String stageName = "stage_%s".formatted(stageCounter++); String stageName = context.localVariable("stage_%s".formatted(stageCounter++));
builder.add(renderExpressionToDocument(stage, stageName, arguments)); builder.add(renderExpressionToDocument(stage, stageName, arguments));
builder.addStatement("stages.add($L)", stageName); builder.addStatement("$L.add($L)", context.localVariable("stages"), stageName);
} }
return builder.build(); return builder.build();
} }
private static CodeBlock sortingStage(String sortProvider) { private CodeBlock sortingStage(String sortProvider) {
Builder builder = CodeBlock.builder(); Builder builder = CodeBlock.builder();
builder.beginControlFlow("if($L.isSorted())", sortProvider);
builder.addStatement("$T sortDocument = new $T()", Document.class, Document.class); builder.beginControlFlow("if ($L.isSorted())", sortProvider);
builder.beginControlFlow("for ($T order : $L)", Order.class, sortProvider); builder.addStatement("$T $L = new $T()", Document.class, context.localVariable("sortDocument"), Document.class);
builder.addStatement("sortDocument.append(order.getProperty(), order.isAscending() ? 1 : -1);"); builder.beginControlFlow("for ($T $L : $L)", Order.class, context.localVariable("order"), sortProvider);
builder.addStatement("$L.append($L.getProperty(), $L.isAscending() ? 1 : -1);",
context.localVariable("sortDocument"), context.localVariable("order"), context.localVariable("order"));
builder.endControlFlow(); builder.endControlFlow();
builder.addStatement("stages.add(new $T($S, sortDocument))", Document.class, "$sort"); builder.addStatement("stages.add(new $T($S, $L))", Document.class, "$sort",
context.localVariable("sortDocument"));
builder.endControlFlow(); builder.endControlFlow();
return builder.build(); return builder.build();
} }
private static CodeBlock pagingStage(String pageableProvider, boolean slice) { private CodeBlock pagingStage(String pageableProvider, boolean slice) {
Builder builder = CodeBlock.builder(); Builder builder = CodeBlock.builder();
builder.add(sortingStage(pageableProvider + ".getSort()")); builder.add(sortingStage(pageableProvider + ".getSort()"));
builder.beginControlFlow("if($L.isPaged())", pageableProvider); builder.beginControlFlow("if ($L.isPaged())", pageableProvider);
builder.beginControlFlow("if($L.getOffset() > 0)", pageableProvider); builder.beginControlFlow("if ($L.getOffset() > 0)", pageableProvider);
builder.addStatement("stages.add($T.skip($L.getOffset()))", Aggregation.class, pageableProvider); builder.addStatement("$L.add($T.skip($L.getOffset()))", context.localVariable("stages"), Aggregation.class,
pageableProvider);
builder.endControlFlow(); builder.endControlFlow();
if (slice) { if (slice) {
builder.addStatement("stages.add($T.limit($L.getPageSize() + 1))", Aggregation.class, pageableProvider); builder.addStatement("$L.add($T.limit($L.getPageSize() + 1))", context.localVariable("stages"),
Aggregation.class, pageableProvider);
} else { } else {
builder.addStatement("stages.add($T.limit($L.getPageSize()))", Aggregation.class, pageableProvider); builder.addStatement("$L.add($T.limit($L.getPageSize()))", context.localVariable("stages"), Aggregation.class,
pageableProvider);
} }
builder.endControlFlow(); builder.endControlFlow();
return builder.build(); return builder.build();
} }
private static CodeBlock limitingStage(String limitProvider) { private CodeBlock limitingStage(String limitProvider) {
Builder builder = CodeBlock.builder(); Builder builder = CodeBlock.builder();
builder.beginControlFlow("if($L.isLimited())", limitProvider);
builder.addStatement("stages.add($T.limit($L.max()))", Aggregation.class, limitProvider); builder.beginControlFlow("if ($L.isLimited())", limitProvider);
builder.addStatement("$L.add($T.limit($L.max()))", context.localVariable("stages"), Aggregation.class,
limitProvider);
builder.endControlFlow(); builder.endControlFlow();
return builder.build(); return builder.build();
} }
} }
@NullUnmarked @NullUnmarked
@ -614,7 +639,7 @@ class MongoCodeBlocks {
private final MongoQueryMethod queryMethod; private final MongoQueryMethod queryMethod;
private QueryInteraction source; private QueryInteraction source;
private List<String> arguments; private final List<String> arguments;
private String queryVariableName; private String queryVariableName;
QueryCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) { QueryCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) {
@ -697,17 +722,10 @@ class MongoCodeBlocks {
builder.addStatement("$T $L = new $T(new $T())", BasicQuery.class, variableName, BasicQuery.class, builder.addStatement("$T $L = new $T(new $T())", BasicQuery.class, variableName, BasicQuery.class,
Document.class); Document.class);
} else if (!containsPlaceholder(source)) { } else if (!containsPlaceholder(source)) {
builder.addStatement("$T $L = new $T($T.parse($S))", BasicQuery.class, variableName, BasicQuery.class,
String tmpVarName = "%sString".formatted(variableName); Document.class, source);
builder.addStatement("String $L = $S", tmpVarName, source);
builder.addStatement("$T $L = new $T($T.parse($L))", BasicQuery.class, variableName, BasicQuery.class,
Document.class, tmpVarName);
} else { } else {
builder.addStatement("$T $L = createQuery($S, new $T[]{ $L })", BasicQuery.class, variableName, source,
String tmpVarName = "%sString".formatted(variableName);
builder.addStatement("String $L = $S", tmpVarName, source);
builder.addStatement("$T $L = createQuery($L, new $T[]{ $L })", BasicQuery.class, variableName, tmpVarName,
Object.class, StringUtils.collectionToDelimitedString(arguments, ", ")); Object.class, StringUtils.collectionToDelimitedString(arguments, ", "));
} }
@ -757,15 +775,9 @@ class MongoCodeBlocks {
if (!StringUtils.hasText(source)) { if (!StringUtils.hasText(source)) {
builder.addStatement("$T $L = new $T()", Document.class, variableName, Document.class); builder.addStatement("$T $L = new $T()", Document.class, variableName, Document.class);
} else if (!containsPlaceholder(source)) { } else if (!containsPlaceholder(source)) {
builder.addStatement("$T $L = $T.parse($S)", Document.class, variableName, Document.class, source);
String tmpVarName = "%sString".formatted(variableName);
builder.addStatement("String $L = $S", tmpVarName, source);
builder.addStatement("$T $L = $T.parse($L)", Document.class, variableName, Document.class, tmpVarName);
} else { } else {
builder.addStatement("$T $L = bindParameters($S, new $T[]{ $L })", Document.class, variableName, source,
String tmpVarName = "%sString".formatted(variableName);
builder.addStatement("String $L = $S", tmpVarName, source);
builder.addStatement("$T $L = bindParameters($L, new $T[]{ $L })", Document.class, variableName, tmpVarName,
Object.class, StringUtils.collectionToDelimitedString(arguments, ", ")); Object.class, StringUtils.collectionToDelimitedString(arguments, ", "));
} }
return builder.build(); return builder.build();

120
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java

@ -15,13 +15,7 @@
*/ */
package org.springframework.data.mongodb.repository.aot; package org.springframework.data.mongodb.repository.aot;
import static org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.aggregationBlockBuilder; import static org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.*;
import static org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.aggregationExecutionBlockBuilder;
import static org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.deleteExecutionBlockBuilder;
import static org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.queryBlockBuilder;
import static org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.queryExecutionBlockBuilder;
import static org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.updateBlockBuilder;
import static org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.updateExecutionBlockBuilder;
import java.lang.reflect.Method; import java.lang.reflect.Method;
import java.util.regex.Pattern; import java.util.regex.Pattern;
@ -29,16 +23,16 @@ import java.util.regex.Pattern;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
import org.jspecify.annotations.Nullable; import org.jspecify.annotations.Nullable;
import org.springframework.core.annotation.AnnotatedElementUtils; import org.springframework.core.annotation.AnnotatedElementUtils;
import org.springframework.data.mongodb.core.MongoOperations; import org.springframework.data.mongodb.core.MongoOperations;
import org.springframework.data.mongodb.core.aggregation.AggregationUpdate; import org.springframework.data.mongodb.core.aggregation.AggregationUpdate;
import org.springframework.data.mongodb.core.mapping.MongoMappingContext; import org.springframework.data.mongodb.core.mapping.MongoMappingContext;
import org.springframework.data.mongodb.repository.Query; import org.springframework.data.mongodb.repository.Query;
import org.springframework.data.mongodb.repository.Update; import org.springframework.data.mongodb.repository.Update;
import org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.QueryCodeBlockBuilder;
import org.springframework.data.mongodb.repository.query.MongoQueryMethod; import org.springframework.data.mongodb.repository.query.MongoQueryMethod;
import org.springframework.data.repository.aot.generate.AotRepositoryClassBuilder;
import org.springframework.data.repository.aot.generate.AotRepositoryConstructorBuilder; import org.springframework.data.repository.aot.generate.AotRepositoryConstructorBuilder;
import org.springframework.data.repository.aot.generate.AotRepositoryFragmentMetadata;
import org.springframework.data.repository.aot.generate.MethodContributor; import org.springframework.data.repository.aot.generate.MethodContributor;
import org.springframework.data.repository.aot.generate.RepositoryContributor; import org.springframework.data.repository.aot.generate.RepositoryContributor;
import org.springframework.data.repository.config.AotRepositoryContext; import org.springframework.data.repository.config.AotRepositoryContext;
@ -48,7 +42,6 @@ import org.springframework.data.repository.query.QueryMethod;
import org.springframework.data.repository.query.parser.PartTree; import org.springframework.data.repository.query.parser.PartTree;
import org.springframework.javapoet.CodeBlock; import org.springframework.javapoet.CodeBlock;
import org.springframework.javapoet.TypeName; import org.springframework.javapoet.TypeName;
import org.springframework.javapoet.TypeSpec.Builder;
import org.springframework.util.ObjectUtils; import org.springframework.util.ObjectUtils;
import org.springframework.util.StringUtils; import org.springframework.util.StringUtils;
@ -60,7 +53,7 @@ import org.springframework.util.StringUtils;
*/ */
public class MongoRepositoryContributor extends RepositoryContributor { public class MongoRepositoryContributor extends RepositoryContributor {
private static final Log logger = LogFactory.getLog(RepositoryContributor.class); private static final Log logger = LogFactory.getLog(MongoRepositoryContributor.class);
private final AotQueryCreator queryCreator; private final AotQueryCreator queryCreator;
private final MongoMappingContext mappingContext; private final MongoMappingContext mappingContext;
@ -73,9 +66,8 @@ public class MongoRepositoryContributor extends RepositoryContributor {
} }
@Override @Override
protected void customizeClass(RepositoryInformation information, AotRepositoryFragmentMetadata metadata, protected void customizeClass(AotRepositoryClassBuilder classBuilder) {
Builder builder) { classBuilder.customize(builder -> builder.superclass(TypeName.get(MongoAotRepositoryFragmentSupport.class)));
builder.superclass(TypeName.get(MongoAotRepositoryFragmentSupport.class));
} }
@Override @Override
@ -85,70 +77,60 @@ public class MongoRepositoryContributor extends RepositoryContributor {
constructorBuilder.addParameter("context", TypeName.get(RepositoryFactoryBeanSupport.FragmentCreationContext.class), constructorBuilder.addParameter("context", TypeName.get(RepositoryFactoryBeanSupport.FragmentCreationContext.class),
false); false);
constructorBuilder.customize((repositoryInformation, builder) -> { constructorBuilder.customize((builder) -> {
builder.addStatement("super(operations, context)"); builder.addStatement("super(operations, context)");
}); });
} }
@Override @Override
@SuppressWarnings("NullAway") @SuppressWarnings("NullAway")
protected @Nullable MethodContributor<? extends QueryMethod> contributeQueryMethod(Method method, protected @Nullable MethodContributor<? extends QueryMethod> contributeQueryMethod(Method method) {
RepositoryInformation repositoryInformation) {
MongoQueryMethod queryMethod = new MongoQueryMethod(method, repositoryInformation, getProjectionFactory(), MongoQueryMethod queryMethod = new MongoQueryMethod(method, getRepositoryInformation(), getProjectionFactory(),
mappingContext); mappingContext);
if (backoff(queryMethod)) { if (queryMethod.hasAnnotatedAggregation()) {
return null; AggregationInteraction aggregation = new AggregationInteraction(queryMethod.getAnnotatedAggregation());
return aggregationMethodContributor(queryMethod, aggregation);
} }
try { QueryInteraction query = createStringQuery(getRepositoryInformation(), queryMethod,
if (queryMethod.hasAnnotatedAggregation()) { AnnotatedElementUtils.findMergedAnnotation(method, Query.class), method.getParameterCount());
AggregationInteraction aggregation = new AggregationInteraction(queryMethod.getAnnotatedAggregation());
return aggregationMethodContributor(queryMethod, aggregation);
}
QueryInteraction query = createStringQuery(repositoryInformation, queryMethod,
AnnotatedElementUtils.findMergedAnnotation(method, Query.class), method.getParameterCount());
if (queryMethod.hasAnnotatedQuery()) { if (queryMethod.hasAnnotatedQuery()) {
if (StringUtils.hasText(queryMethod.getAnnotatedQuery()) if (StringUtils.hasText(queryMethod.getAnnotatedQuery())
&& Pattern.compile("[\\?:][#$]\\{.*\\}").matcher(queryMethod.getAnnotatedQuery()).find()) { && Pattern.compile("[\\?:][#$]\\{.*\\}").matcher(queryMethod.getAnnotatedQuery()).find()) {
if (logger.isDebugEnabled()) { if (logger.isDebugEnabled()) {
logger.debug( logger.debug(
"Skipping AOT generation for [%s]. SpEL expressions are not supported".formatted(method.getName())); "Skipping AOT generation for [%s]. SpEL expressions are not supported".formatted(method.getName()));
}
return MethodContributor.forQueryMethod(queryMethod).metadataOnly(query);
} }
return MethodContributor.forQueryMethod(queryMethod).metadataOnly(query);
} }
}
if (query.isDelete()) { if (backoff(queryMethod)) {
return deleteMethodContributor(queryMethod, query); return null;
} }
if (queryMethod.isModifyingQuery()) { if (query.isDelete()) {
return deleteMethodContributor(queryMethod, query);
}
Update updateSource = queryMethod.getUpdateSource(); if (queryMethod.isModifyingQuery()) {
if (StringUtils.hasText(updateSource.value())) {
UpdateInteraction update = new UpdateInteraction(query, new StringUpdate(updateSource.value()));
return updateMethodContributor(queryMethod, update);
}
if (!ObjectUtils.isEmpty(updateSource.pipeline())) {
AggregationUpdateInteraction update = new AggregationUpdateInteraction(query, updateSource.pipeline());
return aggregationUpdateMethodContributor(queryMethod, update);
}
}
return queryMethodContributor(queryMethod, query); Update updateSource = queryMethod.getUpdateSource();
} catch (RuntimeException codeGenerationError) { if (StringUtils.hasText(updateSource.value())) {
if (logger.isErrorEnabled()) { UpdateInteraction update = new UpdateInteraction(query, new StringUpdate(updateSource.value()));
logger.error("Failed to generate code for [%s] [%s]".formatted(repositoryInformation.getRepositoryInterface(), return updateMethodContributor(queryMethod, update);
method.getName()), codeGenerationError); }
if (!ObjectUtils.isEmpty(updateSource.pipeline())) {
AggregationUpdateInteraction update = new AggregationUpdateInteraction(query, updateSource.pipeline());
return aggregationUpdateMethodContributor(queryMethod, update);
} }
} }
return null;
return queryMethodContributor(queryMethod, query);
} }
@SuppressWarnings("NullAway") @SuppressWarnings("NullAway")
@ -193,7 +175,6 @@ public class MongoRepositoryContributor extends RepositoryContributor {
return MethodContributor.forQueryMethod(queryMethod).withMetadata(aggregation).contribute(context -> { return MethodContributor.forQueryMethod(queryMethod).withMetadata(aggregation).contribute(context -> {
CodeBlock.Builder builder = CodeBlock.builder(); CodeBlock.Builder builder = CodeBlock.builder();
builder.add(context.codeBlocks().logDebug("invoking [%s]".formatted(context.getMethod().getName())));
builder.add(aggregationBlockBuilder(context, queryMethod).stages(aggregation) builder.add(aggregationBlockBuilder(context, queryMethod).stages(aggregation)
.usingAggregationVariableName("aggregation").build()); .usingAggregationVariableName("aggregation").build());
@ -209,15 +190,14 @@ public class MongoRepositoryContributor extends RepositoryContributor {
return MethodContributor.forQueryMethod(queryMethod).withMetadata(update).contribute(context -> { return MethodContributor.forQueryMethod(queryMethod).withMetadata(update).contribute(context -> {
CodeBlock.Builder builder = CodeBlock.builder(); CodeBlock.Builder builder = CodeBlock.builder();
builder.add(context.codeBlocks().logDebug("invoking [%s]".formatted(context.getMethod().getName())));
// update filter // update filter
String filterVariableName = update.name(); String filterVariableName = context.localVariable(update.name());
builder.add(queryBlockBuilder(context, queryMethod).filter(update.getFilter()) builder.add(queryBlockBuilder(context, queryMethod).filter(update.getFilter())
.usingQueryVariableName(filterVariableName).build()); .usingQueryVariableName(filterVariableName).build());
// update definition // update definition
String updateVariableName = "updateDefinition"; String updateVariableName = context.localVariable("updateDefinition");
builder.add( builder.add(
updateBlockBuilder(context, queryMethod).update(update).usingUpdateVariableName(updateVariableName).build()); updateBlockBuilder(context, queryMethod).update(update).usingUpdateVariableName(updateVariableName).build());
@ -233,10 +213,9 @@ public class MongoRepositoryContributor extends RepositoryContributor {
return MethodContributor.forQueryMethod(queryMethod).withMetadata(update).contribute(context -> { return MethodContributor.forQueryMethod(queryMethod).withMetadata(update).contribute(context -> {
CodeBlock.Builder builder = CodeBlock.builder(); CodeBlock.Builder builder = CodeBlock.builder();
builder.add(context.codeBlocks().logDebug("invoking [%s]".formatted(context.getMethod().getName())));
// update filter // update filter
String filterVariableName = update.name(); String filterVariableName = context.localVariable(update.name());
QueryCodeBlockBuilder queryCodeBlockBuilder = queryBlockBuilder(context, queryMethod).filter(update.getFilter()); QueryCodeBlockBuilder queryCodeBlockBuilder = queryBlockBuilder(context, queryMethod).filter(update.getFilter());
builder.add(queryCodeBlockBuilder.usingQueryVariableName(filterVariableName).build()); builder.add(queryCodeBlockBuilder.usingQueryVariableName(filterVariableName).build());
@ -245,11 +224,12 @@ public class MongoRepositoryContributor extends RepositoryContributor {
builder.add(aggregationBlockBuilder(context, queryMethod).stages(update) builder.add(aggregationBlockBuilder(context, queryMethod).stages(update)
.usingAggregationVariableName(updateVariableName).pipelineOnly(true).build()); .usingAggregationVariableName(updateVariableName).pipelineOnly(true).build());
builder.addStatement("$T aggregationUpdate = $T.from($L.getOperations())", AggregationUpdate.class, builder.addStatement("$T $L = $T.from($L.getOperations())", AggregationUpdate.class,
context.localVariable("aggregationUpdate"),
AggregationUpdate.class, updateVariableName); AggregationUpdate.class, updateVariableName);
builder.add(updateExecutionBlockBuilder(context, queryMethod).withFilter(filterVariableName) builder.add(updateExecutionBlockBuilder(context, queryMethod).withFilter(filterVariableName)
.referencingUpdate("aggregationUpdate").build()); .referencingUpdate(context.localVariable("aggregationUpdate")).build());
return builder.build(); return builder.build();
}); });
} }
@ -260,11 +240,11 @@ public class MongoRepositoryContributor extends RepositoryContributor {
return MethodContributor.forQueryMethod(queryMethod).withMetadata(query).contribute(context -> { return MethodContributor.forQueryMethod(queryMethod).withMetadata(query).contribute(context -> {
CodeBlock.Builder builder = CodeBlock.builder(); CodeBlock.Builder builder = CodeBlock.builder();
builder.add(context.codeBlocks().logDebug("invoking [%s]".formatted(context.getMethod().getName())));
QueryCodeBlockBuilder queryCodeBlockBuilder = queryBlockBuilder(context, queryMethod).filter(query); QueryCodeBlockBuilder queryCodeBlockBuilder = queryBlockBuilder(context, queryMethod).filter(query);
builder.add(queryCodeBlockBuilder.usingQueryVariableName(query.name()).build()); String queryVariableName = context.localVariable(query.name());
builder.add(deleteExecutionBlockBuilder(context, queryMethod).referencing(query.name()).build()); builder.add(queryCodeBlockBuilder.usingQueryVariableName(queryVariableName).build());
builder.add(deleteExecutionBlockBuilder(context, queryMethod).referencing(queryVariableName).build());
return builder.build(); return builder.build();
}); });
} }
@ -275,12 +255,12 @@ public class MongoRepositoryContributor extends RepositoryContributor {
return MethodContributor.forQueryMethod(queryMethod).withMetadata(query).contribute(context -> { return MethodContributor.forQueryMethod(queryMethod).withMetadata(query).contribute(context -> {
CodeBlock.Builder builder = CodeBlock.builder(); CodeBlock.Builder builder = CodeBlock.builder();
builder.add(context.codeBlocks().logDebug("invoking [%s]".formatted(context.getMethod().getName())));
QueryCodeBlockBuilder queryCodeBlockBuilder = queryBlockBuilder(context, queryMethod).filter(query); QueryCodeBlockBuilder queryCodeBlockBuilder = queryBlockBuilder(context, queryMethod).filter(query);
builder.add(queryCodeBlockBuilder.usingQueryVariableName(query.name()).build()); builder.add(queryCodeBlockBuilder.usingQueryVariableName(context.localVariable(query.name())).build());
builder.add(queryExecutionBlockBuilder(context, queryMethod).forQuery(query).build()); builder.add(queryExecutionBlockBuilder(context, queryMethod).forQuery(query).build());
return builder.build(); return builder.build();
}); });
} }
} }

5
spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java

@ -51,11 +51,14 @@ import org.springframework.util.StringUtils;
import com.mongodb.client.MongoClient; import com.mongodb.client.MongoClient;
/** /**
* Integration tests for the {@link UserRepository} AOT fragment.
*
* @author Christoph Strobl * @author Christoph Strobl
* @author Mark Paluch
*/ */
@ExtendWith(MongoClientExtension.class) @ExtendWith(MongoClientExtension.class)
@SpringJUnitConfig(classes = MongoRepositoryContributorTests.MongoRepositoryContributorConfiguration.class) @SpringJUnitConfig(classes = MongoRepositoryContributorTests.MongoRepositoryContributorConfiguration.class)
public class MongoRepositoryContributorTests { class MongoRepositoryContributorTests {
private static final String DB_NAME = "aot-repo-tests"; private static final String DB_NAME = "aot-repo-tests";

9
spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryMetadataTests.java

@ -15,9 +15,9 @@
*/ */
package org.springframework.data.mongodb.repository.aot; package org.springframework.data.mongodb.repository.aot;
import static net.javacrumbs.jsonunit.assertj.JsonAssertions.assertThatJson; import static net.javacrumbs.jsonunit.assertj.JsonAssertions.*;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.*;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.*;
import example.aot.UserRepository; import example.aot.UserRepository;
@ -27,6 +27,7 @@ import java.nio.charset.StandardCharsets;
import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Configuration;
@ -176,7 +177,7 @@ class MongoRepositoryMetadataTests {
assertThat(resource.exists()).isTrue(); assertThat(resource.exists()).isTrue();
String json = resource.getContentAsString(StandardCharsets.UTF_8); String json = resource.getContentAsString(StandardCharsets.UTF_8);
System.out.println(json);
assertThatJson(json).inPath("$.methods[?(@.name == 'existsById')].fragment").isArray().first().isObject() assertThatJson(json).inPath("$.methods[?(@.name == 'existsById')].fragment").isArray().first().isObject()
.containsEntry("fragment", "org.springframework.data.mongodb.repository.support.SimpleMongoRepository"); .containsEntry("fragment", "org.springframework.data.mongodb.repository.support.SimpleMongoRepository");
} }

150
spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/TestMongoAotRepositoryContext.java

@ -38,81 +38,81 @@ import org.springframework.data.repository.core.support.RepositoryComposition;
*/ */
class TestMongoAotRepositoryContext implements AotRepositoryContext { class TestMongoAotRepositoryContext implements AotRepositoryContext {
private final StubRepositoryInformation repositoryInformation; private final StubRepositoryInformation repositoryInformation;
private final Environment environment = new StandardEnvironment(); private final Environment environment = new StandardEnvironment();
TestMongoAotRepositoryContext(Class<?> repositoryInterface, @Nullable RepositoryComposition composition) { TestMongoAotRepositoryContext(Class<?> repositoryInterface, @Nullable RepositoryComposition composition) {
this.repositoryInformation = new StubRepositoryInformation(repositoryInterface, composition); this.repositoryInformation = new StubRepositoryInformation(repositoryInterface, composition);
} }
@Override @Override
public ConfigurableListableBeanFactory getBeanFactory() { public ConfigurableListableBeanFactory getBeanFactory() {
return null; return null;
} }
@Override @Override
public TypeIntrospector introspectType(String typeName) { public TypeIntrospector introspectType(String typeName) {
return null; return null;
} }
@Override @Override
public IntrospectedBeanDefinition introspectBeanDefinition(String beanName) { public IntrospectedBeanDefinition introspectBeanDefinition(String beanName) {
return null; return null;
} }
@Override @Override
public String getBeanName() { public String getBeanName() {
return "dummyRepository"; return "dummyRepository";
} }
@Override @Override
public String getModuleName() { public String getModuleName() {
return "MongoDB"; return "MongoDB";
}
@Override
public Set<String> getBasePackages() {
return Set.of("org.springframework.data.dummy.repository.aot");
}
@Override
public Set<Class<? extends Annotation>> getIdentifyingAnnotations() {
return Set.of(Document.class);
}
@Override
public RepositoryInformation getRepositoryInformation() {
return repositoryInformation;
}
@Override
public Set<MergedAnnotation<Annotation>> getResolvedAnnotations() {
return Set.of();
}
@Override
public Set<Class<?>> getResolvedTypes() {
return Set.of();
}
public List<ClassFile> getRequiredContextFiles() {
return List.of(classFileForType(repositoryInformation.getRepositoryBaseClass()));
}
static ClassFile classFileForType(Class<?> type) {
String name = type.getName();
ClassPathResource cpr = new ClassPathResource(name.replaceAll("\\.", "/") + ".class");
try {
return ClassFile.of(name, cpr.getContentAsByteArray());
} catch (IOException e) {
throw new IllegalArgumentException("Cannot open [%s].".formatted(cpr.getPath()));
} }
}
@Override @Override
public Set<String> getBasePackages() { public Environment getEnvironment() {
return Set.of("org.springframework.data.dummy.repository.aot"); return environment;
} }
@Override
public Set<Class<? extends Annotation>> getIdentifyingAnnotations() {
return Set.of(Document.class);
}
@Override
public RepositoryInformation getRepositoryInformation() {
return repositoryInformation;
}
@Override
public Set<MergedAnnotation<Annotation>> getResolvedAnnotations() {
return Set.of();
}
@Override
public Set<Class<?>> getResolvedTypes() {
return Set.of();
}
public List<ClassFile> getRequiredContextFiles() {
return List.of(classFileForType(repositoryInformation.getRepositoryBaseClass()));
}
static ClassFile classFileForType(Class<?> type) {
String name = type.getName();
ClassPathResource cpr = new ClassPathResource(name.replaceAll("\\.", "/") + ".class");
try {
return ClassFile.of(name, cpr.getContentAsByteArray());
} catch (IOException e) {
throw new IllegalArgumentException("Cannot open [%s].".formatted(cpr.getPath()));
}
}
@Override
public Environment getEnvironment() {
return environment;
}
} }

Loading…
Cancel
Save