@ -16,6 +16,7 @@
@@ -16,6 +16,7 @@
package org.springframework.data.aot ;
import java.util.Collections ;
import java.util.function.Function ;
import java.util.function.Supplier ;
import org.apache.commons.logging.Log ;
@ -49,6 +50,15 @@ public class SpringDataBeanFactoryInitializationAotProcessor implements BeanFact
@@ -49,6 +50,15 @@ public class SpringDataBeanFactoryInitializationAotProcessor implements BeanFact
private static final Log logger = LogFactory . getLog ( BeanFactoryInitializationAotProcessor . class ) ;
private static final Function < Object , Object > arrayToListFunction = target - >
ObjectUtils . isArray ( target ) ? CollectionUtils . arrayToList ( target ) : target ;
private static final Function < Object , Object > asSingletonSetFunction = target - >
! ( target instanceof Iterable < ? > ) ? Collections . singleton ( target ) : target ;
private static final Function < Object , Object > constructorArgumentFunction =
arrayToListFunction . andThen ( asSingletonSetFunction ) ;
@Nullable
@Override
public BeanFactoryInitializationAotContribution processAheadOfTime ( ConfigurableListableBeanFactory beanFactory ) {
@ -64,33 +74,38 @@ public class SpringDataBeanFactoryInitializationAotProcessor implements BeanFact
@@ -64,33 +74,38 @@ public class SpringDataBeanFactoryInitializationAotProcessor implements BeanFact
BeanDefinition beanDefinition = beanFactory . getBeanDefinition ( beanName ) ;
if ( beanDefinition . getConstructorArgumentValues ( ) . isEmpty ( ) ) {
return ;
}
if ( hasConstructorArguments ( beanDefinition ) ) {
ValueHolder argumentValue = beanDefinition . getConstructorArgumentValues ( ) . getArgumentValue ( 0 , null , null , null ) ;
ValueHolder argumentValue = beanDefinition . getConstructorArgumentValues ( )
. getArgumentValue ( 0 , null , null , null ) ;
if ( argumentValue . getValue ( ) instanceof Supplier supplier ) {
if ( argumentValue . getValue ( ) instanceof Supplier supplier ) {
if ( logger . isDebugEnabled ( ) ) {
logger . info ( String . format ( "Replacing ManagedType bean definition %s." , beanName ) ) ;
}
if ( logger . isDebugEnabled ( ) ) {
logger . info ( String . format ( "Replacing ManagedType bean definition %s." , beanName ) ) ;
}
Object value = supplier . get ( ) ;
if ( ObjectUtils . isArray ( value ) ) {
value = CollectionUtils . arrayToList ( value ) ;
}
if ( ! ( value instanceof Iterable < ? > ) ) {
value = Collections . singleton ( value ) ;
}
Object value = constructorArgumentFunction . apply ( supplier . get ( ) ) ;
BeanDefinition beanDefinitionReplacement = BeanDefinitionBuilder . rootBeanDefinition ( ManagedTypes . class )
. setFactoryMethod ( "fromIterable" ) . addConstructorArgValue ( value ) . getBeanDefinition ( ) ;
BeanDefinition beanDefinitionReplacement = newManagedTypeBeanDefinition ( value ) ;
registry . removeBeanDefinition ( beanName ) ;
registry . registerBeanDefinition ( beanName , beanDefinitionReplacement ) ;
registry . removeBeanDefinition ( beanName ) ;
registry . registerBeanDefinition ( beanName , beanDefinitionReplacement ) ;
}
}
}
}
}
private boolean hasConstructorArguments ( BeanDefinition beanDefinition ) {
return ! beanDefinition . getConstructorArgumentValues ( ) . isEmpty ( ) ;
}
private BeanDefinition newManagedTypeBeanDefinition ( Object constructorArgumentValue ) {
return BeanDefinitionBuilder . rootBeanDefinition ( ManagedTypes . class )
. setFactoryMethod ( "fromIterable" )
. addConstructorArgValue ( constructorArgumentValue )
. getBeanDefinition ( ) ;
}
}