@ -246,20 +246,24 @@ class MongoCodeBlocks {
@@ -246,20 +246,24 @@ class MongoCodeBlocks {
builder . add ( "\n" ) ;
String updateReference = updateVariableName ;
builder . addStatement ( "$T<$T> updater = $L.update($T.class)" , ExecutableUpdate . class ,
context . getRepositoryInformation ( ) . getDomainType ( ) , mongoOpsRef ,
context . getRepositoryInformation ( ) . getDomainType ( ) ) ;
Class < ? > domainType = context . getRepositoryInformation ( ) . getDomainType ( ) ;
builder . addStatement ( "$T<$T> $L = $L.update($T.class)" , ExecutableUpdate . class , domainType ,
context . localVariable ( "updater" ) , mongoOpsRef , domainType ) ;
Class < ? > returnType = ClassUtils . resolvePrimitiveIfNecessary ( queryMethod . getReturnedObjectType ( ) ) ;
if ( ReflectionUtils . isVoid ( returnType ) ) {
builder . addStatement ( "updater.matching($L).apply($L).all()" , queryVariableName , updateReference ) ;
builder . addStatement ( "$L.matching($L).apply($L).all()" , context . localVariable ( "updater" ) , queryVariableName ,
updateReference ) ;
} else if ( ClassUtils . isAssignable ( Long . class , returnType ) ) {
builder . addStatement ( "return updater.matching($L).apply($L).all().getModifiedCount()" , queryVariableName ,
builder . addStatement ( "return $L.matching($L).apply($L).all().getModifiedCount()" ,
context . localVariable ( "updater" ) , queryVariableName ,
updateReference ) ;
} else {
builder . addStatement ( "$T modifiedCount = updater.matching($L).apply($L).all().getModifiedCount()" , Long . class ,
builder . addStatement ( "$T $L = $L.matching($L).apply($L).all().getModifiedCount()" , Long . class ,
context . localVariable ( "modifiedCount" ) , context . localVariable ( "updater" ) ,
queryVariableName , updateReference ) ;
builder . addStatement ( "return $T.convertNumberToTargetClass(modifiedCount, $T.class)" , NumberUtils . class ,
builder . addStatement ( "return $T.convertNumberToTargetClass($L, $T.class)" , NumberUtils . class ,
context . localVariable ( "modifiedCount" ) ,
returnType ) ;
}
@ -314,24 +318,29 @@ class MongoCodeBlocks {
@@ -314,24 +318,29 @@ class MongoCodeBlocks {
Class < ? > returnType = ClassUtils . resolvePrimitiveIfNecessary ( queryMethod . getReturnedObjectType ( ) ) ;
builder . addStatement ( "$T results = $L.aggregate($L, $T.class)" , AggregationResults . class , mongoOpsRef ,
builder . addStatement ( "$T $L = $L.aggregate($L, $T.class)" , AggregationResults . class ,
context . localVariable ( "results" ) , mongoOpsRef ,
aggregationVariableName , outputType ) ;
if ( ! queryMethod . isCollectionQuery ( ) ) {
builder . addStatement (
"return $T.<$T>firstElement(convertSimpleRawResults($T.class, results .getMappedResults()))" ,
CollectionUtils . class , returnType , returnType ) ;
"return $T.<$T>firstElement(convertSimpleRawResults($T.class, $L .getMappedResults()))" ,
CollectionUtils . class , returnType , returnType , context . localVariable ( "results" ) ) ;
} else {
builder . addStatement ( "return convertSimpleRawResults($T.class, results.getMappedResults())" , returnType ) ;
builder . addStatement ( "return convertSimpleRawResults($T.class, $L.getMappedResults())" , returnType ,
context . localVariable ( "results" ) ) ;
}
} else {
if ( queryMethod . isSliceQuery ( ) ) {
builder . addStatement ( "$T results = $L.aggregate($L, $T.class)" , AggregationResults . class , mongoOpsRef ,
builder . addStatement ( "$T $L = $L.aggregate($L, $T.class)" , AggregationResults . class ,
context . localVariable ( "results" ) , mongoOpsRef ,
aggregationVariableName , outputType ) ;
builder . addStatement ( "boolean hasNext = results .getMappedResults().size() > $L.getPageSize()" ,
context . getPageableParameterName ( ) ) ;
builder . addStatement ( "boolean $L = $L .getMappedResults().size() > $L.getPageSize()" ,
context . localVariable ( "hasNext" ) , context . localVariable ( "results" ) , context . getPageableParameterName ( ) ) ;
builder . addStatement (
"return new $T<>(hasNext ? results.getMappedResults().subList(0, $L.getPageSize()) : results.getMappedResults(), $L, hasNext)" ,
SliceImpl . class , context . getPageableParameterName ( ) , context . getPageableParameterName ( ) ) ;
"return new $T<>($L ? $L.getMappedResults().subList(0, $L.getPageSize()) : $L.getMappedResults(), $L, $L)" ,
SliceImpl . class , context . localVariable ( "hasNext" ) , context . localVariable ( "results" ) ,
context . getPageableParameterName ( ) , context . localVariable ( "results" ) , context . getPageableParameterName ( ) ,
context . localVariable ( "hasNext" ) ) ;
} else {
builder . addStatement ( "return $L.aggregate($L, $T.class).getMappedResults()" , mongoOpsRef ,
aggregationVariableName , outputType ) ;
@ -368,18 +377,19 @@ class MongoCodeBlocks {
@@ -368,18 +377,19 @@ class MongoCodeBlocks {
Builder builder = CodeBlock . builder ( ) ;
boolean isProjecting = context . getReturnedType ( ) . isProjecting ( ) ;
Class < ? > domainType = context . getRepositoryInformation ( ) . getDomainType ( ) ;
Object actualReturnType = isProjecting ? context . getActualReturnType ( ) . getType ( )
: context . getRepositoryInformation ( ) . getDomainType ( ) ;
: domainType ;
builder . add ( "\n" ) ;
if ( isProjecting ) {
builder . addStatement ( "$T<$T> finder = $L.query($T.class).as($T.class)" , FindWithQuery . class , actualReturnType ,
mongoOpsRef , context . getRepositoryInformation ( ) . getDomainType ( ) , actualReturnType ) ;
builder . addStatement ( "$T<$T> $L = $L.query($T.class).as($T.class)" , FindWithQuery . class , actualReturnType ,
context . localVariable ( "finder" ) , mongoOpsRef , domainType , actualReturnType ) ;
} else {
builder . addStatement ( "$T<$T> finder = $L.query($T.class)" , FindWithQuery . class , actualReturnType , mongoOpsRef ,
context . getRepositoryInformation ( ) . getDomainType ( ) ) ;
builder . addStatement ( "$T<$T> $L = $L.query($T.class)" , FindWithQuery . class , actualReturnType ,
context . localVariable ( "finder" ) , mongoOpsRef , domainType ) ;
}
String terminatingMethod ;
@ -395,13 +405,14 @@ class MongoCodeBlocks {
@@ -395,13 +405,14 @@ class MongoCodeBlocks {
}
if ( queryMethod . isPageQuery ( ) ) {
builder . addStatement ( "return new $T(finder , $L).execute($L)" , PagedExecution . class ,
builder . addStatement ( "return new $T($L , $L).execute($L)" , PagedExecution . class , context . localVariable ( "finder" ) ,
context . getPageableParameterName ( ) , query . name ( ) ) ;
} else if ( queryMethod . isSliceQuery ( ) ) {
builder . addStatement ( "return new $T(finder , $L).execute($L)" , SlicedExecution . class ,
context . getPageableParameterName ( ) , query . name ( ) ) ;
builder . addStatement ( "return new $T($L , $L).execute($L)" , SlicedExecution . class ,
context . localVariable ( "finder" ) , context . getPageableParameterName ( ) , query . name ( ) ) ;
} else {
builder . addStatement ( "return finder.matching($L).$L" , query . name ( ) , terminatingMethod ) ;
builder . addStatement ( "return $L.matching($L).$L" , context . localVariable ( "finder" ) , query . name ( ) ,
terminatingMethod ) ;
}
return builder . build ( ) ;
@ -415,7 +426,7 @@ class MongoCodeBlocks {
@@ -415,7 +426,7 @@ class MongoCodeBlocks {
private final MongoQueryMethod queryMethod ;
private AggregationInteraction source ;
private List < String > arguments ;
private final List < String > arguments ;
private String aggregationVariableName ;
private boolean pipelineOnly ;
@ -449,7 +460,7 @@ class MongoCodeBlocks {
@@ -449,7 +460,7 @@ class MongoCodeBlocks {
CodeBlock . Builder builder = CodeBlock . builder ( ) ;
builder . add ( "\n" ) ;
String pipelineName = aggregationVariableName + ( pipelineOnly ? "" : "Pipeline" ) ;
String pipelineName = context . localVariable ( aggregationVariableName + ( pipelineOnly ? "" : "Pipeline" ) ) ;
builder . add ( pipeline ( pipelineName ) ) ;
if ( ! pipelineOnly ) {
@ -486,8 +497,7 @@ class MongoCodeBlocks {
@@ -486,8 +497,7 @@ class MongoCodeBlocks {
}
Builder builder = CodeBlock . builder ( ) ;
String stagesVariableName = "stages" ;
builder . add ( aggregationStages ( stagesVariableName , source . stages ( ) , stageCount , arguments ) ) ;
builder . add ( aggregationStages ( context . localVariable ( "stages" ) , source . stages ( ) , stageCount , arguments ) ) ;
if ( mightBeSorted ) {
builder . add ( sortingStage ( sortParameter ) ) ;
@ -502,7 +512,7 @@ class MongoCodeBlocks {
@@ -502,7 +512,7 @@ class MongoCodeBlocks {
}
builder . addStatement ( "$T $L = createPipeline($L)" , AggregationPipeline . class , pipelineVariableName ,
stagesVariableName ) ;
context . localVariable ( "stages" ) ) ;
return builder . build ( ) ;
}
@ -533,7 +543,8 @@ class MongoCodeBlocks {
@@ -533,7 +543,8 @@ class MongoCodeBlocks {
if ( ! options . isEmpty ( ) ) {
Builder optionsBuilder = CodeBlock . builder ( ) ;
optionsBuilder . add ( "$T aggregationOptions = $T.builder()\n" , AggregationOptions . class ,
optionsBuilder . add ( "$T $L = $T.builder()\n" , AggregationOptions . class ,
context . localVariable ( "aggregationOptions" ) ,
AggregationOptions . class ) ;
optionsBuilder . indent ( ) ;
for ( CodeBlock optionBlock : options ) {
@ -544,67 +555,81 @@ class MongoCodeBlocks {
@@ -544,67 +555,81 @@ class MongoCodeBlocks {
optionsBuilder . unindent ( ) ;
builder . add ( optionsBuilder . build ( ) ) ;
builder . addStatement ( "$L = $L.withOptions(aggregationOptions)" , aggregationVariableName ,
aggregationVariableName ) ;
builder . addStatement ( "$L = $L.withOptions($L)" , aggregationVariableName , aggregationVariableName ,
context . localVariable ( "aggregationOptions" ) ) ;
}
return builder . build ( ) ;
}
private static CodeBlock aggregationStages ( String stageListVariableName , Iterable < String > stages , int stageCount ,
private CodeBlock aggregationStages ( String stageListVariableName , Iterable < String > stages , int stageCount ,
List < String > arguments ) {
Builder builder = CodeBlock . builder ( ) ;
builder . addStatement ( "$T<$T> $L = new $T($L)" , List . class , Object . class , stageListVariableName , ArrayList . class ,
stageCount ) ;
int stageCounter = 0 ;
for ( String stage : stages ) {
String stageName = "stage_%s" . formatted ( stageCounter + + ) ;
String stageName = context . localVariable ( "stage_%s" . formatted ( stageCounter + + ) ) ;
builder . add ( renderExpressionToDocument ( stage , stageName , arguments ) ) ;
builder . addStatement ( "stages.add($L)" , stageName ) ;
builder . addStatement ( "$L.add($L)" , context . localVariable ( "stages" ) , stageName ) ;
}
return builder . build ( ) ;
}
private static CodeBlock sortingStage ( String sortProvider ) {
private CodeBlock sortingStage ( String sortProvider ) {
Builder builder = CodeBlock . builder ( ) ;
builder . beginControlFlow ( "if($L.isSorted())" , sortProvider ) ;
builder . addStatement ( "$T sortDocument = new $T()" , Document . class , Document . class ) ;
builder . beginControlFlow ( "for ($T order : $L)" , Order . class , sortProvider ) ;
builder . addStatement ( "sortDocument.append(order.getProperty(), order.isAscending() ? 1 : -1);" ) ;
builder . beginControlFlow ( "if ($L.isSorted())" , sortProvider ) ;
builder . addStatement ( "$T $L = new $T()" , Document . class , context . localVariable ( "sortDocument" ) , Document . class ) ;
builder . beginControlFlow ( "for ($T $L : $L)" , Order . class , context . localVariable ( "order" ) , sortProvider ) ;
builder . addStatement ( "$L.append($L.getProperty(), $L.isAscending() ? 1 : -1);" ,
context . localVariable ( "sortDocument" ) , context . localVariable ( "order" ) , context . localVariable ( "order" ) ) ;
builder . endControlFlow ( ) ;
builder . addStatement ( "stages.add(new $T($S, sortDocument))" , Document . class , "$sort" ) ;
builder . addStatement ( "stages.add(new $T($S, $L))" , Document . class , "$sort" ,
context . localVariable ( "sortDocument" ) ) ;
builder . endControlFlow ( ) ;
return builder . build ( ) ;
}
private static CodeBlock pagingStage ( String pageableProvider , boolean slice ) {
private CodeBlock pagingStage ( String pageableProvider , boolean slice ) {
Builder builder = CodeBlock . builder ( ) ;
builder . add ( sortingStage ( pageableProvider + ".getSort()" ) ) ;
builder . beginControlFlow ( "if($L.isPaged())" , pageableProvider ) ;
builder . beginControlFlow ( "if($L.getOffset() > 0)" , pageableProvider ) ;
builder . addStatement ( "stages.add($T.skip($L.getOffset()))" , Aggregation . class , pageableProvider ) ;
builder . beginControlFlow ( "if ($L.isPaged())" , pageableProvider ) ;
builder . beginControlFlow ( "if ($L.getOffset() > 0)" , pageableProvider ) ;
builder . addStatement ( "$L.add($T.skip($L.getOffset()))" , context . localVariable ( "stages" ) , Aggregation . class ,
pageableProvider ) ;
builder . endControlFlow ( ) ;
if ( slice ) {
builder . addStatement ( "stages.add($T.limit($L.getPageSize() + 1))" , Aggregation . class , pageableProvider ) ;
builder . addStatement ( "$L.add($T.limit($L.getPageSize() + 1))" , context . localVariable ( "stages" ) ,
Aggregation . class , pageableProvider ) ;
} else {
builder . addStatement ( "stages.add($T.limit($L.getPageSize()))" , Aggregation . class , pageableProvider ) ;
builder . addStatement ( "$L.add($T.limit($L.getPageSize()))" , context . localVariable ( "stages" ) , Aggregation . class ,
pageableProvider ) ;
}
builder . endControlFlow ( ) ;
return builder . build ( ) ;
}
private static CodeBlock limitingStage ( String limitProvider ) {
private CodeBlock limitingStage ( String limitProvider ) {
Builder builder = CodeBlock . builder ( ) ;
builder . beginControlFlow ( "if($L.isLimited())" , limitProvider ) ;
builder . addStatement ( "stages.add($T.limit($L.max()))" , Aggregation . class , limitProvider ) ;
builder . beginControlFlow ( "if ($L.isLimited())" , limitProvider ) ;
builder . addStatement ( "$L.add($T.limit($L.max()))" , context . localVariable ( "stages" ) , Aggregation . class ,
limitProvider ) ;
builder . endControlFlow ( ) ;
return builder . build ( ) ;
}
}
@NullUnmarked
@ -614,7 +639,7 @@ class MongoCodeBlocks {
@@ -614,7 +639,7 @@ class MongoCodeBlocks {
private final MongoQueryMethod queryMethod ;
private QueryInteraction source ;
private List < String > arguments ;
private final List < String > arguments ;
private String queryVariableName ;
QueryCodeBlockBuilder ( AotQueryMethodGenerationContext context , MongoQueryMethod queryMethod ) {
@ -697,17 +722,10 @@ class MongoCodeBlocks {
@@ -697,17 +722,10 @@ class MongoCodeBlocks {
builder . addStatement ( "$T $L = new $T(new $T())" , BasicQuery . class , variableName , BasicQuery . class ,
Document . class ) ;
} else if ( ! containsPlaceholder ( source ) ) {
String tmpVarName = "%sString" . formatted ( variableName ) ;
builder . addStatement ( "String $L = $S" , tmpVarName , source ) ;
builder . addStatement ( "$T $L = new $T($T.parse($L))" , BasicQuery . class , variableName , BasicQuery . class ,
Document . class , tmpVarName ) ;
builder . addStatement ( "$T $L = new $T($T.parse($S))" , BasicQuery . class , variableName , BasicQuery . class ,
Document . class , source ) ;
} else {
String tmpVarName = "%sString" . formatted ( variableName ) ;
builder . addStatement ( "String $L = $S" , tmpVarName , source ) ;
builder . addStatement ( "$T $L = createQuery($L, new $T[]{ $L })" , BasicQuery . class , variableName , tmpVarName ,
builder . addStatement ( "$T $L = createQuery($S, new $T[]{ $L })" , BasicQuery . class , variableName , source ,
Object . class , StringUtils . collectionToDelimitedString ( arguments , ", " ) ) ;
}
@ -757,15 +775,9 @@ class MongoCodeBlocks {
@@ -757,15 +775,9 @@ class MongoCodeBlocks {
if ( ! StringUtils . hasText ( source ) ) {
builder . addStatement ( "$T $L = new $T()" , Document . class , variableName , Document . class ) ;
} else if ( ! containsPlaceholder ( source ) ) {
String tmpVarName = "%sString" . formatted ( variableName ) ;
builder . addStatement ( "String $L = $S" , tmpVarName , source ) ;
builder . addStatement ( "$T $L = $T.parse($L)" , Document . class , variableName , Document . class , tmpVarName ) ;
builder . addStatement ( "$T $L = $T.parse($S)" , Document . class , variableName , Document . class , source ) ;
} else {
String tmpVarName = "%sString" . formatted ( variableName ) ;
builder . addStatement ( "String $L = $S" , tmpVarName , source ) ;
builder . addStatement ( "$T $L = bindParameters($L, new $T[]{ $L })" , Document . class , variableName , tmpVarName ,
builder . addStatement ( "$T $L = bindParameters($S, new $T[]{ $L })" , Document . class , variableName , source ,
Object . class , StringUtils . collectionToDelimitedString ( arguments , ", " ) ) ;
}
return builder . build ( ) ;