@ -19,6 +19,7 @@ import java.util.ArrayList;
@@ -19,6 +19,7 @@ import java.util.ArrayList;
import java.util.List ;
import java.util.Optional ;
import java.util.regex.Pattern ;
import java.util.stream.Stream ;
import org.bson.Document ;
import org.jspecify.annotations.NullUnmarked ;
@ -49,7 +50,6 @@ import org.springframework.data.mongodb.repository.query.MongoQueryExecution.Sli
@@ -49,7 +50,6 @@ import org.springframework.data.mongodb.repository.query.MongoQueryExecution.Sli
import org.springframework.data.mongodb.repository.query.MongoQueryMethod ;
import org.springframework.data.repository.aot.generate.AotQueryMethodGenerationContext ;
import org.springframework.data.util.ReflectionUtils ;
import org.springframework.javapoet.ClassName ;
import org.springframework.javapoet.CodeBlock ;
import org.springframework.javapoet.CodeBlock.Builder ;
import org.springframework.javapoet.TypeName ;
@ -182,17 +182,15 @@ class MongoCodeBlocks {
@@ -182,17 +182,15 @@ class MongoCodeBlocks {
String mongoOpsRef = context . fieldNameOf ( MongoOperations . class ) ;
Builder builder = CodeBlock . builder ( ) ;
Class < ? > domainType = context . getRepositoryInformation ( ) . getDomainType ( ) ;
boolean isProjecting = context . getActualReturnType ( ) ! = null
& & ! ObjectUtils . nullSafeEquals ( TypeName . get ( context . getRepositoryInformation ( ) . getDomainType ( ) ) ,
context . getActualReturnType ( ) ) ;
& & ! ObjectUtils . nullSafeEquals ( TypeName . get ( domainType ) , context . getActualReturnType ( ) ) ;
Object actualReturnType = isProjecting ? context . getActualReturnType ( ) . getType ( )
: context . getRepositoryInformation ( ) . getDomainType ( ) ;
Object actualReturnType = isProjecting ? context . getActualReturnType ( ) . getType ( ) : domainType ;
builder . add ( "\n" ) ;
builder . addStatement ( "$T<$T> remover = $L.remove($T.class)" , ExecutableRemove . class ,
context . getRepositoryInformation ( ) . getDomainType ( ) , mongoOpsRef ,
context . getRepositoryInformation ( ) . getDomainType ( ) ) ;
builder . addStatement ( "$T<$T> $L = $L.remove($T.class)" , ExecutableRemove . class , domainType ,
context . localVariable ( "remover" ) , mongoOpsRef , domainType ) ;
DeleteExecution . Type type = DeleteExecution . Type . FIND_AND_REMOVE_ALL ;
if ( ! queryMethod . isCollectionQuery ( ) ) {
@ -204,11 +202,20 @@ class MongoCodeBlocks {
@@ -204,11 +202,20 @@ class MongoCodeBlocks {
}
actualReturnType = ClassUtils . isPrimitiveOrWrapper ( context . getMethod ( ) . getReturnType ( ) )
? Class Name. get ( context . getMethod ( ) . getReturnType ( ) )
? Type Name. get ( context . getMethod ( ) . getReturnType ( ) )
: queryMethod . isCollectionQuery ( ) ? context . getReturnTypeName ( ) : actualReturnType ;
builder . addStatement ( "return ($T) new $T(remover, $T.$L).execute($L)" , actualReturnType , DeleteExecution . class ,
DeleteExecution . Type . class , type . name ( ) , queryVariableName ) ;
if ( ClassUtils . isVoidType ( context . getMethod ( ) . getReturnType ( ) ) ) {
builder . addStatement ( "new $T($L, $T.$L).execute($L)" , DeleteExecution . class , context . localVariable ( "remover" ) ,
DeleteExecution . Type . class , type . name ( ) , queryVariableName ) ;
} else if ( context . getMethod ( ) . getReturnType ( ) = = Optional . class ) {
builder . addStatement ( "return $T.ofNullable(($T) new $T($L, $T.$L).execute($L))" , Optional . class ,
actualReturnType , DeleteExecution . class , context . localVariable ( "remover" ) , DeleteExecution . Type . class ,
type . name ( ) , queryVariableName ) ;
} else {
builder . addStatement ( "return ($T) new $T($L, $T.$L).execute($L)" , actualReturnType , DeleteExecution . class ,
context . localVariable ( "remover" ) , DeleteExecution . Type . class , type . name ( ) , queryVariableName ) ;
}
return builder . build ( ) ;
}
@ -318,14 +325,25 @@ class MongoCodeBlocks {
@@ -318,14 +325,25 @@ class MongoCodeBlocks {
Class < ? > returnType = ClassUtils . resolvePrimitiveIfNecessary ( queryMethod . getReturnedObjectType ( ) ) ;
builder . addStatement ( "$T $L = $L.aggregate($L, $T.class)" , AggregationResults . class ,
context . localVariable ( "results" ) , mongoOpsRef , aggregationVariableName , outputType ) ;
if ( ! queryMethod . isCollectionQuery ( ) ) {
builder . addStatement ( "return $T.<$T>firstElement(convertSimpleRawResults($T.class, $L.getMappedResults()))" ,
CollectionUtils . class , returnType , returnType , context . localVariable ( "results" ) ) ;
if ( queryMethod . isStreamQuery ( ) ) {
builder . addStatement ( "$T<$T> $L = $L.aggregateStream($L, $T.class)" , Stream . class , Document . class ,
context . localVariable ( "results" ) , mongoOpsRef , aggregationVariableName , outputType ) ;
builder . addStatement ( "return $L.map(it -> ($T) convertSimpleRawResult($T.class, it))" ,
context . localVariable ( "results" ) , returnType , returnType ) ;
} else {
builder . addStatement ( "return convertSimpleRawResults($T.class, $L.getMappedResults())" , returnType ,
context . localVariable ( "results" ) ) ;
builder . addStatement ( "$T $L = $L.aggregate($L, $T.class)" , AggregationResults . class ,
context . localVariable ( "results" ) , mongoOpsRef , aggregationVariableName , outputType ) ;
if ( ! queryMethod . isCollectionQuery ( ) ) {
builder . addStatement ( "return $T.<$T>firstElement(convertSimpleRawResults($T.class, $L.getMappedResults()))" ,
CollectionUtils . class , returnType , returnType , context . localVariable ( "results" ) ) ;
} else {
builder . addStatement ( "return convertSimpleRawResults($T.class, $L.getMappedResults())" , returnType ,
context . localVariable ( "results" ) ) ;
}
}
} else {
if ( queryMethod . isSliceQuery ( ) ) {
@ -339,8 +357,15 @@ class MongoCodeBlocks {
@@ -339,8 +357,15 @@ class MongoCodeBlocks {
context . getPageableParameterName ( ) , context . localVariable ( "results" ) , context . getPageableParameterName ( ) ,
context . localVariable ( "hasNext" ) ) ;
} else {
builder . addStatement ( "return $L.aggregate($L, $T.class).getMappedResults()" , mongoOpsRef ,
aggregationVariableName , outputType ) ;
if ( queryMethod . isStreamQuery ( ) ) {
builder . addStatement ( "return $L.aggregateStream($L, $T.class)" , mongoOpsRef , aggregationVariableName ,
outputType ) ;
} else {
builder . addStatement ( "return $L.aggregate($L, $T.class).getMappedResults()" , mongoOpsRef ,
aggregationVariableName , outputType ) ;
}
}
}
@ -420,8 +445,16 @@ class MongoCodeBlocks {
@@ -420,8 +445,16 @@ class MongoCodeBlocks {
builder . addStatement ( "return $L.matching($L).scroll($L)" , context . localVariable ( "finder" ) , query . name ( ) ,
scrollPositionParameterName ) ;
} else {
builder . addStatement ( "return $L.matching($L).$L" , context . localVariable ( "finder" ) , query . name ( ) ,
terminatingMethod ) ;
if ( query . isCount ( ) & & ! ClassUtils . isAssignable ( Long . class , context . getActualReturnType ( ) . getRawClass ( ) ) ) {
Class < ? > returnType = ClassUtils . resolvePrimitiveIfNecessary ( queryMethod . getReturnedObjectType ( ) ) ;
builder . addStatement ( "return $T.convertNumberToTargetClass($L.matching($L).$L, $T.class)" , NumberUtils . class ,
context . localVariable ( "finder" ) , query . name ( ) , terminatingMethod , returnType ) ;
} else {
builder . addStatement ( "return $L.matching($L).$L" , context . localVariable ( "finder" ) , query . name ( ) ,
terminatingMethod ) ;
}
}
return builder . build ( ) ;