diff --git a/src/main/java/org/springframework/data/aot/SpringDataBeanFactoryInitializationAotProcessor.java b/src/main/java/org/springframework/data/aot/SpringDataBeanFactoryInitializationAotProcessor.java index 8227a4c75..21a082e7d 100644 --- a/src/main/java/org/springframework/data/aot/SpringDataBeanFactoryInitializationAotProcessor.java +++ b/src/main/java/org/springframework/data/aot/SpringDataBeanFactoryInitializationAotProcessor.java @@ -16,6 +16,7 @@ package org.springframework.data.aot; import java.util.Collections; +import java.util.function.Function; import java.util.function.Supplier; import org.apache.commons.logging.Log; @@ -49,6 +50,15 @@ public class SpringDataBeanFactoryInitializationAotProcessor implements BeanFact private static final Log logger = LogFactory.getLog(BeanFactoryInitializationAotProcessor.class); + private static final Function arrayToListFunction = target -> + ObjectUtils.isArray(target) ? CollectionUtils.arrayToList(target) : target; + + private static final Function asSingletonSetFunction = target -> + !(target instanceof Iterable) ? Collections.singleton(target) : target; + + private static final Function constructorArgumentFunction = + arrayToListFunction.andThen(asSingletonSetFunction); + @Nullable @Override public BeanFactoryInitializationAotContribution processAheadOfTime(ConfigurableListableBeanFactory beanFactory) { @@ -64,33 +74,38 @@ public class SpringDataBeanFactoryInitializationAotProcessor implements BeanFact BeanDefinition beanDefinition = beanFactory.getBeanDefinition(beanName); - if (beanDefinition.getConstructorArgumentValues().isEmpty()) { - return; - } + if (hasConstructorArguments(beanDefinition)) { - ValueHolder argumentValue = beanDefinition.getConstructorArgumentValues().getArgumentValue(0, null, null, null); + ValueHolder argumentValue = beanDefinition.getConstructorArgumentValues() + .getArgumentValue(0, null, null, null); - if (argumentValue.getValue()instanceof Supplier supplier) { + if (argumentValue.getValue() instanceof Supplier supplier) { - if (logger.isDebugEnabled()) { - logger.info(String.format("Replacing ManagedType bean definition %s.", beanName)); - } + if (logger.isDebugEnabled()) { + logger.info(String.format("Replacing ManagedType bean definition %s.", beanName)); + } - Object value = supplier.get(); - if (ObjectUtils.isArray(value)) { - value = CollectionUtils.arrayToList(value); - } - if (!(value instanceof Iterable)) { - value = Collections.singleton(value); - } + Object value = constructorArgumentFunction.apply(supplier.get()); - BeanDefinition beanDefinitionReplacement = BeanDefinitionBuilder.rootBeanDefinition(ManagedTypes.class) - .setFactoryMethod("fromIterable").addConstructorArgValue(value).getBeanDefinition(); + BeanDefinition beanDefinitionReplacement = newManagedTypeBeanDefinition(value); - registry.removeBeanDefinition(beanName); - registry.registerBeanDefinition(beanName, beanDefinitionReplacement); + registry.removeBeanDefinition(beanName); + registry.registerBeanDefinition(beanName, beanDefinitionReplacement); + } } } } } + + private boolean hasConstructorArguments(BeanDefinition beanDefinition) { + return !beanDefinition.getConstructorArgumentValues().isEmpty(); + } + + private BeanDefinition newManagedTypeBeanDefinition(Object constructorArgumentValue) { + + return BeanDefinitionBuilder.rootBeanDefinition(ManagedTypes.class) + .setFactoryMethod("fromIterable") + .addConstructorArgValue(constructorArgumentValue) + .getBeanDefinition(); + } }