@ -29,9 +29,7 @@ import javax.lang.model.element.Modifier;
@@ -29,9 +29,7 @@ import javax.lang.model.element.Modifier;
import org.apache.commons.logging.Log ;
import org.apache.commons.logging.LogFactory ;
import org.jspecify.annotations.Nullable ;
import org.springframework.aot.generate.Generated ;
import org.springframework.aot.generate.GeneratedTypeReference ;
import org.springframework.aot.hint.TypeReference ;
import org.springframework.core.ResolvableType ;
import org.springframework.data.projection.ProjectionFactory ;
import org.springframework.data.repository.aot.generate.AotRepositoryFragmentMetadata.ConstructorArgument ;
@ -42,7 +40,6 @@ import org.springframework.data.repository.query.QueryMethod;
@@ -42,7 +40,6 @@ import org.springframework.data.repository.query.QueryMethod;
import org.springframework.data.util.Lazy ;
import org.springframework.javapoet.ClassName ;
import org.springframework.javapoet.FieldSpec ;
import org.springframework.javapoet.JavaFile ;
import org.springframework.javapoet.MethodSpec ;
import org.springframework.javapoet.TypeSpec ;
import org.springframework.util.Assert ;
@ -90,89 +87,53 @@ class AotRepositoryCreator {
@@ -90,89 +87,53 @@ class AotRepositoryCreator {
}
/ * *
* Configure a { @link AotRepositoryConstructorBuilder } customizer .
* Get the { @link ClassName } for the AOT repository fragment .
*
* @param classCustomizer must not be { @literal null } .
* @return { @code this } .
* @return the { @link ClassName } for the AOT repository fragment .
* /
AotRepositoryCreator customizeClass ( Consumer < AotRepositoryClassBuilder > classCustomizer ) {
this . classCustomizer = classCustomizer ;
return this ;
ClassName getClassName ( ) {
return ClassName . get ( packageName ( ) ,
"%sImpl" . formatted ( repositoryInformation . getRepositoryInterface ( ) . getSimpleName ( ) ) ) ;
}
/ * *
* Configure a { @link AotRepositoryConstructorBuilder } customizer .
*
* @param constructorCustomizer must not be { @literal null } .
* @return { @code this } .
* /
@SuppressWarnings ( "NullAway" )
AotRepositoryCreator customizeConstructor ( Consumer < AotRepositoryConstructorBuilder > constructorCustomizer ) {
String packageName ( ) {
return repositoryInformation . getRepositoryInterface ( ) . getPackageName ( ) ;
}
if ( constructorBuilder ! = null ) {
constructorBuilder . dispose ( ) ;
Map < String , ResolvableType > getAutowireFields ( ) {
Map < String , ResolvableType > autowireFields = new LinkedHashMap < > (
generationMetadata . getConstructorArguments ( ) . size ( ) ) ;
for ( Map . Entry < String , ConstructorArgument > entry : generationMetadata . getConstructorArguments ( ) . entrySet ( ) ) {
autowireFields . put ( entry . getKey ( ) , entry . getValue ( ) . parameterType ( ) ) ;
}
return autowireFields ;
}
RepositoryConstructorBuilder constructorBuilder = new RepositoryConstructorBuilder ( generationMetadata ) ;
constructorCustomizer . accept ( constructorBuilder ) ;
this . constructorBuilder = constructorBuilder ;
return this ;
RepositoryInformation getRepositoryInformation ( ) {
return repositoryInformation ;
}
AotRepositoryCreator resolveQueryMethods ( ) {
return resolveQueryMethods ( new MethodContributorFactory ( ) {
@Override
public @Nullable MethodContributor < ? extends QueryMethod > create ( Method method ) {
return null ;
}
} ) ;
ProjectionFactory getProjectionFactory ( ) {
return projectionFactory ;
}
/ * *
* Configure a { @link MethodContributor } factory .
* Create the AOT repository fragment and add constructors , methods and fields to the given { @link TypeSpec . Builder } .
*
* @param methodContributorFactory must not be { @literal null } .
* @return { @code this } .
* @param target the target { @link TypeSpec . Builder } to which the AOT repository fragment will be added .
* @return an
* /
AotRepositoryCreator resolveQueryMethods ( @Nullable MethodContributorFactory methodContributorFactory ) {
Arrays . stream ( repositoryInformation . getRepositoryInterface ( ) . getMethods ( ) )
. sorted ( Comparator . < Method , String > comparing ( it - > {
return it . getDeclaringClass ( ) . getName ( ) ;
} ) . thenComparing ( Method : : getName ) . thenComparing ( Method : : getParameterCount ) . thenComparing ( Method : : toString ) )
. forEach ( method - > {
RepositoryComposition repositoryComposition = repositoryInformation . getRepositoryComposition ( ) ;
try {
resolveQueryMethod ( method , methodContributorFactory , repositoryComposition , generationMetadata ) ;
} catch ( RuntimeException e ) {
if ( logger . isErrorEnabled ( ) ) {
logger . error ( "Failed to contribute Repository method [%s.%s]"
. formatted ( repositoryInformation . getRepositoryInterface ( ) . getName ( ) , method . getName ( ) ) , e ) ;
}
}
} ) ;
return this ;
}
AotBundle create ( ) {
return create ( repositoryImplementationTypeName ( ) ) ;
}
AotBundle create ( String targetTypeName ) {
return create ( TypeSpec . classBuilder ( ClassName . bestGuess ( targetTypeName ) ) . addAnnotation ( Generated . class ) ) ;
}
AotBundle create ( TypeSpec . Builder builder ) {
AotBundle create ( TypeSpec . Builder target ) {
List < AotRepositoryMethod > methodMetadata = new ArrayList < > ( ) ;
builder . addModifiers ( Modifier . PUBLIC ) //
target . addModifiers ( Modifier . PUBLIC ) //
. addJavadoc ( "AOT generated $L repository implementation for {@link $T}.\n" , moduleName ,
repositoryInformation . getRepositoryInterface ( ) ) ;
// create the constructor
builder . addMethod ( buildConstructor ( ) ) ;
target . addMethod ( buildConstructor ( ) ) ;
generationMetadata . getMethods ( ) . values ( ) . forEach ( localMethod - > {
@ -182,7 +143,7 @@ class AotRepositoryCreator {
@@ -182,7 +143,7 @@ class AotRepositoryCreator {
MethodSpec methodSpec = methodContributor . contribute ( context ) ;
if ( methodSpec ! = null ) {
builder . addMethod ( methodSpec ) ;
target . addMethod ( methodSpec ) ;
}
// TODO: decouple json from method building and get rid of methodMetadata here?
@ -209,24 +170,19 @@ class AotRepositoryCreator {
@@ -209,24 +170,19 @@ class AotRepositoryCreator {
// write fields at the end so we make sure to capture things added by methods
generationMetadata . getFields ( ) . values ( ) . stream ( )
. map ( field - > FieldSpec . builder ( field . fieldType ( ) . getType ( ) , field . fieldName ( ) , field . modifiers ( ) ) . build ( ) )
. forEach ( builder : : addField ) ;
. forEach ( target : : addField ) ;
// finally customize the file itself
this . classCustomizer . accept ( customizer - > {
Assert . notNull ( customizer , "ClassCustomizer must not be null" ) ;
customizer . customize ( builder ) ;
customizer . customize ( target ) ;
} ) ;
return new AotBundle ( repositoryInformation . getRepositoryInterface ( ) ,
Lazy . of ( ( ) - > JavaFile . builder ( packageName ( ) , builder . build ( ) ) . build ( ) ) ,
Lazy . of ( ( ) - > getAotRepositoryMetadata ( methodMetadata ) ) ) ;
}
String repositoryImplementationTypeName ( ) {
return "%s.%s" . formatted ( packageName ( ) , typeName ( ) ) ;
}
private MethodSpec buildConstructor ( ) {
return constructorBuilder ! = null ? constructorBuilder . buildConstructor ( )
: MethodSpec . constructorBuilder ( ) . addModifiers ( Modifier . PUBLIC ) . build ( ) ;
@ -244,63 +200,117 @@ class AotRepositoryCreator {
@@ -244,63 +200,117 @@ class AotRepositoryCreator {
repositoryType , methodMetadata ) ;
}
private void resolveQueryMethod ( Method method , @Nullable MethodContributorFactory contributorFactory ,
RepositoryComposition repositoryComposition , AotRepositoryFragmentMetadata metadata ) {
/ * *
* Configure a { @link AotRepositoryConstructorBuilder } customizer .
*
* @param classCustomizer must not be { @literal null } .
* @return { @code this } .
* /
AotRepositoryCreator customizeClass ( Consumer < AotRepositoryClassBuilder > classCustomizer ) {
this . classCustomizer = classCustomizer ;
return this ;
}
/ * *
* Configure a { @link AotRepositoryConstructorBuilder } customizer .
*
* @param constructorCustomizer must not be { @literal null } .
* @return { @code this } .
* /
@SuppressWarnings ( "NullAway" )
AotRepositoryCreator customizeConstructor ( Consumer < AotRepositoryConstructorBuilder > constructorCustomizer ) {
if ( constructorBuilder ! = null ) {
constructorBuilder . dispose ( ) ;
}
RepositoryConstructorBuilder constructorBuilder = new RepositoryConstructorBuilder ( generationMetadata ) ;
constructorCustomizer . accept ( constructorBuilder ) ;
this . constructorBuilder = constructorBuilder ;
return this ;
}
/ * *
* Contribute repository methods using { @link MethodContributor } factory .
*
* @param methodContributorFactory must not be { @literal null } .
* /
void contributeMethods ( @Nullable MethodContributorFactory methodContributorFactory ) {
Arrays . stream ( repositoryInformation . getRepositoryInterface ( ) . getMethods ( ) )
. sorted ( Comparator . < Method , String > comparing ( it - > it . getDeclaringClass ( ) . getName ( ) ) //
. thenComparing ( Method : : getName ) //
. thenComparing ( Method : : getParameterCount ) //
. thenComparing ( Method : : toString ) )
. forEach ( method - > {
try {
contributeMethod ( method , methodContributorFactory ) ;
} catch ( RuntimeException e ) {
if ( logger . isErrorEnabled ( ) ) {
logger . error ( "Failed to contribute Repository method [%s.%s]"
. formatted ( repositoryInformation . getRepositoryInterface ( ) . getName ( ) , method . getName ( ) ) , e ) ;
}
}
} ) ;
}
private void contributeMethod ( Method method , @Nullable MethodContributorFactory contributorFactory ) {
if ( repositoryInformation . isCustomMethod ( method )
| | ( repositoryInformation . isBaseClassMethod ( method ) & & ! repositoryInformation . isQueryMethod ( method ) ) ) {
RepositoryComposition repositoryComposition = repositoryInformation . getRepositoryComposition ( ) ;
RepositoryFragment < ? > fragment = repositoryComposition . findFragment ( method ) ;
if ( fragment ! = null ) {
metadata . addDelegateMethod ( method , fragment ) ;
generationM etadata. addDelegateMethod ( method , fragment ) ;
return ;
}
}
if ( method . isBridge ( ) | | method . isDefault ( ) | | java . lang . reflect . Modifier . isStatic ( method . getModifiers ( ) ) ) {
if ( logger . isTraceEnabled ( ) ) {
logger . trace ( "Skipping %s method [%s.%s] contribution" . formatted (
( method . isBridge ( ) ? "bridge" : method . isDefault ( ) ? "default" : "static" ) ,
repositoryInformation . getRepositoryInterface ( ) . getName ( ) , method . getName ( ) ) ) ;
}
return ;
}
if ( repositoryInformation . isQueryMethod ( method ) & & contributorFactory ! = null ) {
if ( ! repositoryInformation . isQueryMethod ( method ) ) {
MethodContributor < ? extends QueryMethod > contributor = contributorFactory . create ( method ) ;
if ( contributor ! = null ) {
if ( contributor . contributesMethodSpec ( ) & & ! repositoryInformation . isReactiveRepository ( ) ) {
metadata . addRepositoryMethod ( method , contributor ) ;
} else {
metadata . addDelegateMethod ( method , contributor ) ;
}
if ( logger . isTraceEnabled ( ) ) {
logger . trace ( "Skipping method [%s.%s] contribution, not a query method"
. formatted ( repositoryInformation . getRepositoryInterface ( ) . getName ( ) , method . getName ( ) ) ) ;
}
return ;
}
}
public String packageName ( ) {
return repositoryInformation . getRepositoryInterface ( ) . getPackageName ( ) ;
}
if ( contributorFactory = = null ) {
public String typeName ( ) {
return "%sImpl" . formatted ( repositoryInformation . getRepositoryInterface ( ) . getSimpleName ( ) ) ;
}
if ( logger . isTraceEnabled ( ) ) {
logger . trace ( "Skipping method [%s.%s] contribution, no MethodContributorFactory available"
. formatted ( repositoryInformation . getRepositoryInterface ( ) . getName ( ) , method . getName ( ) ) ) ;
}
return ;
}
public Map < String , ResolvableType > getAutowireFields ( ) {
MethodContributor < ? extends QueryMethod > contributor = contributorFactory . create ( method ) ;
Map < String , ResolvableType > autowireFields = new LinkedHashMap < > (
generationMetadata . getConstructorArguments ( ) . size ( ) ) ;
for ( Map . Entry < String , ConstructorArgument > entry : generationMetadata . getConstructorArguments ( ) . entrySet ( ) ) {
autowireFields . put ( entry . getKey ( ) , entry . getValue ( ) . parameterType ( ) ) ;
if ( contributor = = null ) {
if ( logger . isTraceEnabled ( ) ) {
logger . trace ( "Skipping method [%s.%s] contribution, no MethodContributor available"
. formatted ( repositoryInformation . getRepositoryInterface ( ) . getName ( ) , method . getName ( ) ) ) ;
}
}
return autowireFields ;
}
public RepositoryInformation getRepositoryInformation ( ) {
return repositoryInformation ;
}
public ProjectionFactory getProjectionFactory ( ) {
return projectionFactory ;
if ( contributor . contributesMethodSpec ( ) & & ! repositoryInformation . isReactiveRepository ( ) ) {
generationMetadata . addRepositoryMethod ( method , contributor ) ;
} else {
generationMetadata . addDelegateMethod ( method , contributor ) ;
}
}
/ * *
@ -336,24 +346,11 @@ class AotRepositoryCreator {
@@ -336,24 +346,11 @@ class AotRepositoryCreator {
}
record AotBundle ( Class < ? > sourceRepository , Lazy < JavaFile > javaFile , Lazy < AotRepositoryMetadata > metadata ) {
record AotBundle ( Class < ? > sourceRepository , Lazy < AotRepositoryMetadata > metadata ) {
String repositoryJsonFileName ( ) {
return sourceRepository . getName ( ) . replace ( '.' , '/' ) + ".json" ;
}
TypeReference generatedRepositoryTypeName ( ) {
JavaFile file = javaFile . get ( ) ;
return GeneratedTypeReference . of ( ClassName . get ( file . packageName ( ) , file . typeSpec ( ) . name ( ) ) ) ;
}
String generatedCode ( ) {
return javaFile ( ) . get ( ) . toString ( ) ;
}
String generatedMetadata ( ) {
return metadata ( ) . get ( ) . toJson ( ) . toString ( 2 ) ;
}
}
}