Browse Source

Use `LocalVariableNameFactory` to avoid parameter name clashes.

Closes #4965
pull/4976/head
Mark Paluch 7 months ago
parent
commit
d850b9ef06
No known key found for this signature in database
GPG Key ID: 55BC6374BAA9D973
  1. 3
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoAotRepositoryFragmentSupport.java
  2. 152
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java
  3. 120
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java
  4. 5
      spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java
  5. 9
      spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryMetadataTests.java
  6. 150
      spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/TestMongoAotRepositoryContext.java

3
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoAotRepositoryFragmentSupport.java

@ -35,7 +35,10 @@ import org.springframework.util.ClassUtils; @@ -35,7 +35,10 @@ import org.springframework.util.ClassUtils;
import org.springframework.util.ObjectUtils;
/**
* Support class for MongoDB AOT repository fragments.
*
* @author Christoph Strobl
* @since 5.0
*/
public class MongoAotRepositoryFragmentSupport {

152
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java

@ -246,20 +246,24 @@ class MongoCodeBlocks { @@ -246,20 +246,24 @@ class MongoCodeBlocks {
builder.add("\n");
String updateReference = updateVariableName;
builder.addStatement("$T<$T> updater = $L.update($T.class)", ExecutableUpdate.class,
context.getRepositoryInformation().getDomainType(), mongoOpsRef,
context.getRepositoryInformation().getDomainType());
Class<?> domainType = context.getRepositoryInformation().getDomainType();
builder.addStatement("$T<$T> $L = $L.update($T.class)", ExecutableUpdate.class, domainType,
context.localVariable("updater"), mongoOpsRef, domainType);
Class<?> returnType = ClassUtils.resolvePrimitiveIfNecessary(queryMethod.getReturnedObjectType());
if (ReflectionUtils.isVoid(returnType)) {
builder.addStatement("updater.matching($L).apply($L).all()", queryVariableName, updateReference);
builder.addStatement("$L.matching($L).apply($L).all()", context.localVariable("updater"), queryVariableName,
updateReference);
} else if (ClassUtils.isAssignable(Long.class, returnType)) {
builder.addStatement("return updater.matching($L).apply($L).all().getModifiedCount()", queryVariableName,
builder.addStatement("return $L.matching($L).apply($L).all().getModifiedCount()",
context.localVariable("updater"), queryVariableName,
updateReference);
} else {
builder.addStatement("$T modifiedCount = updater.matching($L).apply($L).all().getModifiedCount()", Long.class,
builder.addStatement("$T $L = $L.matching($L).apply($L).all().getModifiedCount()", Long.class,
context.localVariable("modifiedCount"), context.localVariable("updater"),
queryVariableName, updateReference);
builder.addStatement("return $T.convertNumberToTargetClass(modifiedCount, $T.class)", NumberUtils.class,
builder.addStatement("return $T.convertNumberToTargetClass($L, $T.class)", NumberUtils.class,
context.localVariable("modifiedCount"),
returnType);
}
@ -314,24 +318,29 @@ class MongoCodeBlocks { @@ -314,24 +318,29 @@ class MongoCodeBlocks {
Class<?> returnType = ClassUtils.resolvePrimitiveIfNecessary(queryMethod.getReturnedObjectType());
builder.addStatement("$T results = $L.aggregate($L, $T.class)", AggregationResults.class, mongoOpsRef,
builder.addStatement("$T $L = $L.aggregate($L, $T.class)", AggregationResults.class,
context.localVariable("results"), mongoOpsRef,
aggregationVariableName, outputType);
if (!queryMethod.isCollectionQuery()) {
builder.addStatement(
"return $T.<$T>firstElement(convertSimpleRawResults($T.class, results.getMappedResults()))",
CollectionUtils.class, returnType, returnType);
"return $T.<$T>firstElement(convertSimpleRawResults($T.class, $L.getMappedResults()))",
CollectionUtils.class, returnType, returnType, context.localVariable("results"));
} else {
builder.addStatement("return convertSimpleRawResults($T.class, results.getMappedResults())", returnType);
builder.addStatement("return convertSimpleRawResults($T.class, $L.getMappedResults())", returnType,
context.localVariable("results"));
}
} else {
if (queryMethod.isSliceQuery()) {
builder.addStatement("$T results = $L.aggregate($L, $T.class)", AggregationResults.class, mongoOpsRef,
builder.addStatement("$T $L = $L.aggregate($L, $T.class)", AggregationResults.class,
context.localVariable("results"), mongoOpsRef,
aggregationVariableName, outputType);
builder.addStatement("boolean hasNext = results.getMappedResults().size() > $L.getPageSize()",
context.getPageableParameterName());
builder.addStatement("boolean $L = $L.getMappedResults().size() > $L.getPageSize()",
context.localVariable("hasNext"), context.localVariable("results"), context.getPageableParameterName());
builder.addStatement(
"return new $T<>(hasNext ? results.getMappedResults().subList(0, $L.getPageSize()) : results.getMappedResults(), $L, hasNext)",
SliceImpl.class, context.getPageableParameterName(), context.getPageableParameterName());
"return new $T<>($L ? $L.getMappedResults().subList(0, $L.getPageSize()) : $L.getMappedResults(), $L, $L)",
SliceImpl.class, context.localVariable("hasNext"), context.localVariable("results"),
context.getPageableParameterName(), context.localVariable("results"), context.getPageableParameterName(),
context.localVariable("hasNext"));
} else {
builder.addStatement("return $L.aggregate($L, $T.class).getMappedResults()", mongoOpsRef,
aggregationVariableName, outputType);
@ -368,18 +377,19 @@ class MongoCodeBlocks { @@ -368,18 +377,19 @@ class MongoCodeBlocks {
Builder builder = CodeBlock.builder();
boolean isProjecting = context.getReturnedType().isProjecting();
Class<?> domainType = context.getRepositoryInformation().getDomainType();
Object actualReturnType = isProjecting ? context.getActualReturnType().getType()
: context.getRepositoryInformation().getDomainType();
: domainType;
builder.add("\n");
if (isProjecting) {
builder.addStatement("$T<$T> finder = $L.query($T.class).as($T.class)", FindWithQuery.class, actualReturnType,
mongoOpsRef, context.getRepositoryInformation().getDomainType(), actualReturnType);
builder.addStatement("$T<$T> $L = $L.query($T.class).as($T.class)", FindWithQuery.class, actualReturnType,
context.localVariable("finder"), mongoOpsRef, domainType, actualReturnType);
} else {
builder.addStatement("$T<$T> finder = $L.query($T.class)", FindWithQuery.class, actualReturnType, mongoOpsRef,
context.getRepositoryInformation().getDomainType());
builder.addStatement("$T<$T> $L = $L.query($T.class)", FindWithQuery.class, actualReturnType,
context.localVariable("finder"), mongoOpsRef, domainType);
}
String terminatingMethod;
@ -395,13 +405,14 @@ class MongoCodeBlocks { @@ -395,13 +405,14 @@ class MongoCodeBlocks {
}
if (queryMethod.isPageQuery()) {
builder.addStatement("return new $T(finder, $L).execute($L)", PagedExecution.class,
builder.addStatement("return new $T($L, $L).execute($L)", PagedExecution.class, context.localVariable("finder"),
context.getPageableParameterName(), query.name());
} else if (queryMethod.isSliceQuery()) {
builder.addStatement("return new $T(finder, $L).execute($L)", SlicedExecution.class,
context.getPageableParameterName(), query.name());
builder.addStatement("return new $T($L, $L).execute($L)", SlicedExecution.class,
context.localVariable("finder"), context.getPageableParameterName(), query.name());
} else {
builder.addStatement("return finder.matching($L).$L", query.name(), terminatingMethod);
builder.addStatement("return $L.matching($L).$L", context.localVariable("finder"), query.name(),
terminatingMethod);
}
return builder.build();
@ -415,7 +426,7 @@ class MongoCodeBlocks { @@ -415,7 +426,7 @@ class MongoCodeBlocks {
private final MongoQueryMethod queryMethod;
private AggregationInteraction source;
private List<String> arguments;
private final List<String> arguments;
private String aggregationVariableName;
private boolean pipelineOnly;
@ -449,7 +460,7 @@ class MongoCodeBlocks { @@ -449,7 +460,7 @@ class MongoCodeBlocks {
CodeBlock.Builder builder = CodeBlock.builder();
builder.add("\n");
String pipelineName = aggregationVariableName + (pipelineOnly ? "" : "Pipeline");
String pipelineName = context.localVariable(aggregationVariableName + (pipelineOnly ? "" : "Pipeline"));
builder.add(pipeline(pipelineName));
if (!pipelineOnly) {
@ -486,8 +497,7 @@ class MongoCodeBlocks { @@ -486,8 +497,7 @@ class MongoCodeBlocks {
}
Builder builder = CodeBlock.builder();
String stagesVariableName = "stages";
builder.add(aggregationStages(stagesVariableName, source.stages(), stageCount, arguments));
builder.add(aggregationStages(context.localVariable("stages"), source.stages(), stageCount, arguments));
if (mightBeSorted) {
builder.add(sortingStage(sortParameter));
@ -502,7 +512,7 @@ class MongoCodeBlocks { @@ -502,7 +512,7 @@ class MongoCodeBlocks {
}
builder.addStatement("$T $L = createPipeline($L)", AggregationPipeline.class, pipelineVariableName,
stagesVariableName);
context.localVariable("stages"));
return builder.build();
}
@ -533,7 +543,8 @@ class MongoCodeBlocks { @@ -533,7 +543,8 @@ class MongoCodeBlocks {
if (!options.isEmpty()) {
Builder optionsBuilder = CodeBlock.builder();
optionsBuilder.add("$T aggregationOptions = $T.builder()\n", AggregationOptions.class,
optionsBuilder.add("$T $L = $T.builder()\n", AggregationOptions.class,
context.localVariable("aggregationOptions"),
AggregationOptions.class);
optionsBuilder.indent();
for (CodeBlock optionBlock : options) {
@ -544,67 +555,81 @@ class MongoCodeBlocks { @@ -544,67 +555,81 @@ class MongoCodeBlocks {
optionsBuilder.unindent();
builder.add(optionsBuilder.build());
builder.addStatement("$L = $L.withOptions(aggregationOptions)", aggregationVariableName,
aggregationVariableName);
builder.addStatement("$L = $L.withOptions($L)", aggregationVariableName, aggregationVariableName,
context.localVariable("aggregationOptions"));
}
return builder.build();
}
private static CodeBlock aggregationStages(String stageListVariableName, Iterable<String> stages, int stageCount,
private CodeBlock aggregationStages(String stageListVariableName, Iterable<String> stages, int stageCount,
List<String> arguments) {
Builder builder = CodeBlock.builder();
builder.addStatement("$T<$T> $L = new $T($L)", List.class, Object.class, stageListVariableName, ArrayList.class,
stageCount);
int stageCounter = 0;
for (String stage : stages) {
String stageName = "stage_%s".formatted(stageCounter++);
String stageName = context.localVariable("stage_%s".formatted(stageCounter++));
builder.add(renderExpressionToDocument(stage, stageName, arguments));
builder.addStatement("stages.add($L)", stageName);
builder.addStatement("$L.add($L)", context.localVariable("stages"), stageName);
}
return builder.build();
}
private static CodeBlock sortingStage(String sortProvider) {
private CodeBlock sortingStage(String sortProvider) {
Builder builder = CodeBlock.builder();
builder.beginControlFlow("if($L.isSorted())", sortProvider);
builder.addStatement("$T sortDocument = new $T()", Document.class, Document.class);
builder.beginControlFlow("for ($T order : $L)", Order.class, sortProvider);
builder.addStatement("sortDocument.append(order.getProperty(), order.isAscending() ? 1 : -1);");
builder.beginControlFlow("if ($L.isSorted())", sortProvider);
builder.addStatement("$T $L = new $T()", Document.class, context.localVariable("sortDocument"), Document.class);
builder.beginControlFlow("for ($T $L : $L)", Order.class, context.localVariable("order"), sortProvider);
builder.addStatement("$L.append($L.getProperty(), $L.isAscending() ? 1 : -1);",
context.localVariable("sortDocument"), context.localVariable("order"), context.localVariable("order"));
builder.endControlFlow();
builder.addStatement("stages.add(new $T($S, sortDocument))", Document.class, "$sort");
builder.addStatement("stages.add(new $T($S, $L))", Document.class, "$sort",
context.localVariable("sortDocument"));
builder.endControlFlow();
return builder.build();
}
private static CodeBlock pagingStage(String pageableProvider, boolean slice) {
private CodeBlock pagingStage(String pageableProvider, boolean slice) {
Builder builder = CodeBlock.builder();
builder.add(sortingStage(pageableProvider + ".getSort()"));
builder.beginControlFlow("if($L.isPaged())", pageableProvider);
builder.beginControlFlow("if($L.getOffset() > 0)", pageableProvider);
builder.addStatement("stages.add($T.skip($L.getOffset()))", Aggregation.class, pageableProvider);
builder.beginControlFlow("if ($L.isPaged())", pageableProvider);
builder.beginControlFlow("if ($L.getOffset() > 0)", pageableProvider);
builder.addStatement("$L.add($T.skip($L.getOffset()))", context.localVariable("stages"), Aggregation.class,
pageableProvider);
builder.endControlFlow();
if (slice) {
builder.addStatement("stages.add($T.limit($L.getPageSize() + 1))", Aggregation.class, pageableProvider);
builder.addStatement("$L.add($T.limit($L.getPageSize() + 1))", context.localVariable("stages"),
Aggregation.class, pageableProvider);
} else {
builder.addStatement("stages.add($T.limit($L.getPageSize()))", Aggregation.class, pageableProvider);
builder.addStatement("$L.add($T.limit($L.getPageSize()))", context.localVariable("stages"), Aggregation.class,
pageableProvider);
}
builder.endControlFlow();
return builder.build();
}
private static CodeBlock limitingStage(String limitProvider) {
private CodeBlock limitingStage(String limitProvider) {
Builder builder = CodeBlock.builder();
builder.beginControlFlow("if($L.isLimited())", limitProvider);
builder.addStatement("stages.add($T.limit($L.max()))", Aggregation.class, limitProvider);
builder.beginControlFlow("if ($L.isLimited())", limitProvider);
builder.addStatement("$L.add($T.limit($L.max()))", context.localVariable("stages"), Aggregation.class,
limitProvider);
builder.endControlFlow();
return builder.build();
}
}
@NullUnmarked
@ -614,7 +639,7 @@ class MongoCodeBlocks { @@ -614,7 +639,7 @@ class MongoCodeBlocks {
private final MongoQueryMethod queryMethod;
private QueryInteraction source;
private List<String> arguments;
private final List<String> arguments;
private String queryVariableName;
QueryCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) {
@ -697,17 +722,10 @@ class MongoCodeBlocks { @@ -697,17 +722,10 @@ class MongoCodeBlocks {
builder.addStatement("$T $L = new $T(new $T())", BasicQuery.class, variableName, BasicQuery.class,
Document.class);
} else if (!containsPlaceholder(source)) {
String tmpVarName = "%sString".formatted(variableName);
builder.addStatement("String $L = $S", tmpVarName, source);
builder.addStatement("$T $L = new $T($T.parse($L))", BasicQuery.class, variableName, BasicQuery.class,
Document.class, tmpVarName);
builder.addStatement("$T $L = new $T($T.parse($S))", BasicQuery.class, variableName, BasicQuery.class,
Document.class, source);
} else {
String tmpVarName = "%sString".formatted(variableName);
builder.addStatement("String $L = $S", tmpVarName, source);
builder.addStatement("$T $L = createQuery($L, new $T[]{ $L })", BasicQuery.class, variableName, tmpVarName,
builder.addStatement("$T $L = createQuery($S, new $T[]{ $L })", BasicQuery.class, variableName, source,
Object.class, StringUtils.collectionToDelimitedString(arguments, ", "));
}
@ -757,15 +775,9 @@ class MongoCodeBlocks { @@ -757,15 +775,9 @@ class MongoCodeBlocks {
if (!StringUtils.hasText(source)) {
builder.addStatement("$T $L = new $T()", Document.class, variableName, Document.class);
} else if (!containsPlaceholder(source)) {
String tmpVarName = "%sString".formatted(variableName);
builder.addStatement("String $L = $S", tmpVarName, source);
builder.addStatement("$T $L = $T.parse($L)", Document.class, variableName, Document.class, tmpVarName);
builder.addStatement("$T $L = $T.parse($S)", Document.class, variableName, Document.class, source);
} else {
String tmpVarName = "%sString".formatted(variableName);
builder.addStatement("String $L = $S", tmpVarName, source);
builder.addStatement("$T $L = bindParameters($L, new $T[]{ $L })", Document.class, variableName, tmpVarName,
builder.addStatement("$T $L = bindParameters($S, new $T[]{ $L })", Document.class, variableName, source,
Object.class, StringUtils.collectionToDelimitedString(arguments, ", "));
}
return builder.build();

120
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java

@ -15,13 +15,7 @@ @@ -15,13 +15,7 @@
*/
package org.springframework.data.mongodb.repository.aot;
import static org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.aggregationBlockBuilder;
import static org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.aggregationExecutionBlockBuilder;
import static org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.deleteExecutionBlockBuilder;
import static org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.queryBlockBuilder;
import static org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.queryExecutionBlockBuilder;
import static org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.updateBlockBuilder;
import static org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.updateExecutionBlockBuilder;
import static org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.*;
import java.lang.reflect.Method;
import java.util.regex.Pattern;
@ -29,16 +23,16 @@ import java.util.regex.Pattern; @@ -29,16 +23,16 @@ import java.util.regex.Pattern;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.jspecify.annotations.Nullable;
import org.springframework.core.annotation.AnnotatedElementUtils;
import org.springframework.data.mongodb.core.MongoOperations;
import org.springframework.data.mongodb.core.aggregation.AggregationUpdate;
import org.springframework.data.mongodb.core.mapping.MongoMappingContext;
import org.springframework.data.mongodb.repository.Query;
import org.springframework.data.mongodb.repository.Update;
import org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.QueryCodeBlockBuilder;
import org.springframework.data.mongodb.repository.query.MongoQueryMethod;
import org.springframework.data.repository.aot.generate.AotRepositoryClassBuilder;
import org.springframework.data.repository.aot.generate.AotRepositoryConstructorBuilder;
import org.springframework.data.repository.aot.generate.AotRepositoryFragmentMetadata;
import org.springframework.data.repository.aot.generate.MethodContributor;
import org.springframework.data.repository.aot.generate.RepositoryContributor;
import org.springframework.data.repository.config.AotRepositoryContext;
@ -48,7 +42,6 @@ import org.springframework.data.repository.query.QueryMethod; @@ -48,7 +42,6 @@ import org.springframework.data.repository.query.QueryMethod;
import org.springframework.data.repository.query.parser.PartTree;
import org.springframework.javapoet.CodeBlock;
import org.springframework.javapoet.TypeName;
import org.springframework.javapoet.TypeSpec.Builder;
import org.springframework.util.ObjectUtils;
import org.springframework.util.StringUtils;
@ -60,7 +53,7 @@ import org.springframework.util.StringUtils; @@ -60,7 +53,7 @@ import org.springframework.util.StringUtils;
*/
public class MongoRepositoryContributor extends RepositoryContributor {
private static final Log logger = LogFactory.getLog(RepositoryContributor.class);
private static final Log logger = LogFactory.getLog(MongoRepositoryContributor.class);
private final AotQueryCreator queryCreator;
private final MongoMappingContext mappingContext;
@ -73,9 +66,8 @@ public class MongoRepositoryContributor extends RepositoryContributor { @@ -73,9 +66,8 @@ public class MongoRepositoryContributor extends RepositoryContributor {
}
@Override
protected void customizeClass(RepositoryInformation information, AotRepositoryFragmentMetadata metadata,
Builder builder) {
builder.superclass(TypeName.get(MongoAotRepositoryFragmentSupport.class));
protected void customizeClass(AotRepositoryClassBuilder classBuilder) {
classBuilder.customize(builder -> builder.superclass(TypeName.get(MongoAotRepositoryFragmentSupport.class)));
}
@Override
@ -85,70 +77,60 @@ public class MongoRepositoryContributor extends RepositoryContributor { @@ -85,70 +77,60 @@ public class MongoRepositoryContributor extends RepositoryContributor {
constructorBuilder.addParameter("context", TypeName.get(RepositoryFactoryBeanSupport.FragmentCreationContext.class),
false);
constructorBuilder.customize((repositoryInformation, builder) -> {
constructorBuilder.customize((builder) -> {
builder.addStatement("super(operations, context)");
});
}
@Override
@SuppressWarnings("NullAway")
protected @Nullable MethodContributor<? extends QueryMethod> contributeQueryMethod(Method method,
RepositoryInformation repositoryInformation) {
protected @Nullable MethodContributor<? extends QueryMethod> contributeQueryMethod(Method method) {
MongoQueryMethod queryMethod = new MongoQueryMethod(method, repositoryInformation, getProjectionFactory(),
MongoQueryMethod queryMethod = new MongoQueryMethod(method, getRepositoryInformation(), getProjectionFactory(),
mappingContext);
if (backoff(queryMethod)) {
return null;
if (queryMethod.hasAnnotatedAggregation()) {
AggregationInteraction aggregation = new AggregationInteraction(queryMethod.getAnnotatedAggregation());
return aggregationMethodContributor(queryMethod, aggregation);
}
try {
if (queryMethod.hasAnnotatedAggregation()) {
AggregationInteraction aggregation = new AggregationInteraction(queryMethod.getAnnotatedAggregation());
return aggregationMethodContributor(queryMethod, aggregation);
}
QueryInteraction query = createStringQuery(repositoryInformation, queryMethod,
AnnotatedElementUtils.findMergedAnnotation(method, Query.class), method.getParameterCount());
QueryInteraction query = createStringQuery(getRepositoryInformation(), queryMethod,
AnnotatedElementUtils.findMergedAnnotation(method, Query.class), method.getParameterCount());
if (queryMethod.hasAnnotatedQuery()) {
if (StringUtils.hasText(queryMethod.getAnnotatedQuery())
&& Pattern.compile("[\\?:][#$]\\{.*\\}").matcher(queryMethod.getAnnotatedQuery()).find()) {
if (queryMethod.hasAnnotatedQuery()) {
if (StringUtils.hasText(queryMethod.getAnnotatedQuery())
&& Pattern.compile("[\\?:][#$]\\{.*\\}").matcher(queryMethod.getAnnotatedQuery()).find()) {
if (logger.isDebugEnabled()) {
logger.debug(
"Skipping AOT generation for [%s]. SpEL expressions are not supported".formatted(method.getName()));
}
return MethodContributor.forQueryMethod(queryMethod).metadataOnly(query);
if (logger.isDebugEnabled()) {
logger.debug(
"Skipping AOT generation for [%s]. SpEL expressions are not supported".formatted(method.getName()));
}
return MethodContributor.forQueryMethod(queryMethod).metadataOnly(query);
}
}
if (query.isDelete()) {
return deleteMethodContributor(queryMethod, query);
}
if (backoff(queryMethod)) {
return null;
}
if (queryMethod.isModifyingQuery()) {
if (query.isDelete()) {
return deleteMethodContributor(queryMethod, query);
}
Update updateSource = queryMethod.getUpdateSource();
if (StringUtils.hasText(updateSource.value())) {
UpdateInteraction update = new UpdateInteraction(query, new StringUpdate(updateSource.value()));
return updateMethodContributor(queryMethod, update);
}
if (!ObjectUtils.isEmpty(updateSource.pipeline())) {
AggregationUpdateInteraction update = new AggregationUpdateInteraction(query, updateSource.pipeline());
return aggregationUpdateMethodContributor(queryMethod, update);
}
}
if (queryMethod.isModifyingQuery()) {
return queryMethodContributor(queryMethod, query);
} catch (RuntimeException codeGenerationError) {
if (logger.isErrorEnabled()) {
logger.error("Failed to generate code for [%s] [%s]".formatted(repositoryInformation.getRepositoryInterface(),
method.getName()), codeGenerationError);
Update updateSource = queryMethod.getUpdateSource();
if (StringUtils.hasText(updateSource.value())) {
UpdateInteraction update = new UpdateInteraction(query, new StringUpdate(updateSource.value()));
return updateMethodContributor(queryMethod, update);
}
if (!ObjectUtils.isEmpty(updateSource.pipeline())) {
AggregationUpdateInteraction update = new AggregationUpdateInteraction(query, updateSource.pipeline());
return aggregationUpdateMethodContributor(queryMethod, update);
}
}
return null;
return queryMethodContributor(queryMethod, query);
}
@SuppressWarnings("NullAway")
@ -193,7 +175,6 @@ public class MongoRepositoryContributor extends RepositoryContributor { @@ -193,7 +175,6 @@ public class MongoRepositoryContributor extends RepositoryContributor {
return MethodContributor.forQueryMethod(queryMethod).withMetadata(aggregation).contribute(context -> {
CodeBlock.Builder builder = CodeBlock.builder();
builder.add(context.codeBlocks().logDebug("invoking [%s]".formatted(context.getMethod().getName())));
builder.add(aggregationBlockBuilder(context, queryMethod).stages(aggregation)
.usingAggregationVariableName("aggregation").build());
@ -209,15 +190,14 @@ public class MongoRepositoryContributor extends RepositoryContributor { @@ -209,15 +190,14 @@ public class MongoRepositoryContributor extends RepositoryContributor {
return MethodContributor.forQueryMethod(queryMethod).withMetadata(update).contribute(context -> {
CodeBlock.Builder builder = CodeBlock.builder();
builder.add(context.codeBlocks().logDebug("invoking [%s]".formatted(context.getMethod().getName())));
// update filter
String filterVariableName = update.name();
String filterVariableName = context.localVariable(update.name());
builder.add(queryBlockBuilder(context, queryMethod).filter(update.getFilter())
.usingQueryVariableName(filterVariableName).build());
// update definition
String updateVariableName = "updateDefinition";
String updateVariableName = context.localVariable("updateDefinition");
builder.add(
updateBlockBuilder(context, queryMethod).update(update).usingUpdateVariableName(updateVariableName).build());
@ -233,10 +213,9 @@ public class MongoRepositoryContributor extends RepositoryContributor { @@ -233,10 +213,9 @@ public class MongoRepositoryContributor extends RepositoryContributor {
return MethodContributor.forQueryMethod(queryMethod).withMetadata(update).contribute(context -> {
CodeBlock.Builder builder = CodeBlock.builder();
builder.add(context.codeBlocks().logDebug("invoking [%s]".formatted(context.getMethod().getName())));
// update filter
String filterVariableName = update.name();
String filterVariableName = context.localVariable(update.name());
QueryCodeBlockBuilder queryCodeBlockBuilder = queryBlockBuilder(context, queryMethod).filter(update.getFilter());
builder.add(queryCodeBlockBuilder.usingQueryVariableName(filterVariableName).build());
@ -245,11 +224,12 @@ public class MongoRepositoryContributor extends RepositoryContributor { @@ -245,11 +224,12 @@ public class MongoRepositoryContributor extends RepositoryContributor {
builder.add(aggregationBlockBuilder(context, queryMethod).stages(update)
.usingAggregationVariableName(updateVariableName).pipelineOnly(true).build());
builder.addStatement("$T aggregationUpdate = $T.from($L.getOperations())", AggregationUpdate.class,
builder.addStatement("$T $L = $T.from($L.getOperations())", AggregationUpdate.class,
context.localVariable("aggregationUpdate"),
AggregationUpdate.class, updateVariableName);
builder.add(updateExecutionBlockBuilder(context, queryMethod).withFilter(filterVariableName)
.referencingUpdate("aggregationUpdate").build());
.referencingUpdate(context.localVariable("aggregationUpdate")).build());
return builder.build();
});
}
@ -260,11 +240,11 @@ public class MongoRepositoryContributor extends RepositoryContributor { @@ -260,11 +240,11 @@ public class MongoRepositoryContributor extends RepositoryContributor {
return MethodContributor.forQueryMethod(queryMethod).withMetadata(query).contribute(context -> {
CodeBlock.Builder builder = CodeBlock.builder();
builder.add(context.codeBlocks().logDebug("invoking [%s]".formatted(context.getMethod().getName())));
QueryCodeBlockBuilder queryCodeBlockBuilder = queryBlockBuilder(context, queryMethod).filter(query);
builder.add(queryCodeBlockBuilder.usingQueryVariableName(query.name()).build());
builder.add(deleteExecutionBlockBuilder(context, queryMethod).referencing(query.name()).build());
String queryVariableName = context.localVariable(query.name());
builder.add(queryCodeBlockBuilder.usingQueryVariableName(queryVariableName).build());
builder.add(deleteExecutionBlockBuilder(context, queryMethod).referencing(queryVariableName).build());
return builder.build();
});
}
@ -275,12 +255,12 @@ public class MongoRepositoryContributor extends RepositoryContributor { @@ -275,12 +255,12 @@ public class MongoRepositoryContributor extends RepositoryContributor {
return MethodContributor.forQueryMethod(queryMethod).withMetadata(query).contribute(context -> {
CodeBlock.Builder builder = CodeBlock.builder();
builder.add(context.codeBlocks().logDebug("invoking [%s]".formatted(context.getMethod().getName())));
QueryCodeBlockBuilder queryCodeBlockBuilder = queryBlockBuilder(context, queryMethod).filter(query);
builder.add(queryCodeBlockBuilder.usingQueryVariableName(query.name()).build());
builder.add(queryCodeBlockBuilder.usingQueryVariableName(context.localVariable(query.name())).build());
builder.add(queryExecutionBlockBuilder(context, queryMethod).forQuery(query).build());
return builder.build();
});
}
}

5
spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java

@ -51,11 +51,14 @@ import org.springframework.util.StringUtils; @@ -51,11 +51,14 @@ import org.springframework.util.StringUtils;
import com.mongodb.client.MongoClient;
/**
* Integration tests for the {@link UserRepository} AOT fragment.
*
* @author Christoph Strobl
* @author Mark Paluch
*/
@ExtendWith(MongoClientExtension.class)
@SpringJUnitConfig(classes = MongoRepositoryContributorTests.MongoRepositoryContributorConfiguration.class)
public class MongoRepositoryContributorTests {
class MongoRepositoryContributorTests {
private static final String DB_NAME = "aot-repo-tests";

9
spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryMetadataTests.java

@ -15,9 +15,9 @@ @@ -15,9 +15,9 @@
*/
package org.springframework.data.mongodb.repository.aot;
import static net.javacrumbs.jsonunit.assertj.JsonAssertions.assertThatJson;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.mock;
import static net.javacrumbs.jsonunit.assertj.JsonAssertions.*;
import static org.assertj.core.api.Assertions.*;
import static org.mockito.Mockito.*;
import example.aot.UserRepository;
@ -27,6 +27,7 @@ import java.nio.charset.StandardCharsets; @@ -27,6 +27,7 @@ import java.nio.charset.StandardCharsets;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
@ -176,7 +177,7 @@ class MongoRepositoryMetadataTests { @@ -176,7 +177,7 @@ class MongoRepositoryMetadataTests {
assertThat(resource.exists()).isTrue();
String json = resource.getContentAsString(StandardCharsets.UTF_8);
System.out.println(json);
assertThatJson(json).inPath("$.methods[?(@.name == 'existsById')].fragment").isArray().first().isObject()
.containsEntry("fragment", "org.springframework.data.mongodb.repository.support.SimpleMongoRepository");
}

150
spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/TestMongoAotRepositoryContext.java

@ -38,81 +38,81 @@ import org.springframework.data.repository.core.support.RepositoryComposition; @@ -38,81 +38,81 @@ import org.springframework.data.repository.core.support.RepositoryComposition;
*/
class TestMongoAotRepositoryContext implements AotRepositoryContext {
private final StubRepositoryInformation repositoryInformation;
private final Environment environment = new StandardEnvironment();
TestMongoAotRepositoryContext(Class<?> repositoryInterface, @Nullable RepositoryComposition composition) {
this.repositoryInformation = new StubRepositoryInformation(repositoryInterface, composition);
}
@Override
public ConfigurableListableBeanFactory getBeanFactory() {
return null;
}
@Override
public TypeIntrospector introspectType(String typeName) {
return null;
}
@Override
public IntrospectedBeanDefinition introspectBeanDefinition(String beanName) {
return null;
}
@Override
public String getBeanName() {
return "dummyRepository";
}
@Override
public String getModuleName() {
return "MongoDB";
private final StubRepositoryInformation repositoryInformation;
private final Environment environment = new StandardEnvironment();
TestMongoAotRepositoryContext(Class<?> repositoryInterface, @Nullable RepositoryComposition composition) {
this.repositoryInformation = new StubRepositoryInformation(repositoryInterface, composition);
}
@Override
public ConfigurableListableBeanFactory getBeanFactory() {
return null;
}
@Override
public TypeIntrospector introspectType(String typeName) {
return null;
}
@Override
public IntrospectedBeanDefinition introspectBeanDefinition(String beanName) {
return null;
}
@Override
public String getBeanName() {
return "dummyRepository";
}
@Override
public String getModuleName() {
return "MongoDB";
}
@Override
public Set<String> getBasePackages() {
return Set.of("org.springframework.data.dummy.repository.aot");
}
@Override
public Set<Class<? extends Annotation>> getIdentifyingAnnotations() {
return Set.of(Document.class);
}
@Override
public RepositoryInformation getRepositoryInformation() {
return repositoryInformation;
}
@Override
public Set<MergedAnnotation<Annotation>> getResolvedAnnotations() {
return Set.of();
}
@Override
public Set<Class<?>> getResolvedTypes() {
return Set.of();
}
public List<ClassFile> getRequiredContextFiles() {
return List.of(classFileForType(repositoryInformation.getRepositoryBaseClass()));
}
static ClassFile classFileForType(Class<?> type) {
String name = type.getName();
ClassPathResource cpr = new ClassPathResource(name.replaceAll("\\.", "/") + ".class");
try {
return ClassFile.of(name, cpr.getContentAsByteArray());
} catch (IOException e) {
throw new IllegalArgumentException("Cannot open [%s].".formatted(cpr.getPath()));
}
}
@Override
public Set<String> getBasePackages() {
return Set.of("org.springframework.data.dummy.repository.aot");
}
@Override
public Set<Class<? extends Annotation>> getIdentifyingAnnotations() {
return Set.of(Document.class);
}
@Override
public RepositoryInformation getRepositoryInformation() {
return repositoryInformation;
}
@Override
public Set<MergedAnnotation<Annotation>> getResolvedAnnotations() {
return Set.of();
}
@Override
public Set<Class<?>> getResolvedTypes() {
return Set.of();
}
public List<ClassFile> getRequiredContextFiles() {
return List.of(classFileForType(repositoryInformation.getRepositoryBaseClass()));
}
static ClassFile classFileForType(Class<?> type) {
String name = type.getName();
ClassPathResource cpr = new ClassPathResource(name.replaceAll("\\.", "/") + ".class");
try {
return ClassFile.of(name, cpr.getContentAsByteArray());
} catch (IOException e) {
throw new IllegalArgumentException("Cannot open [%s].".formatted(cpr.getPath()));
}
}
@Override
public Environment getEnvironment() {
return environment;
}
@Override
public Environment getEnvironment() {
return environment;
}
}

Loading…
Cancel
Save