diff --git a/src/main/java/org/springframework/data/mapping/callback/DefaultEntityCallbacks.java b/src/main/java/org/springframework/data/mapping/callback/DefaultEntityCallbacks.java index b844aa474..d469c0a74 100644 --- a/src/main/java/org/springframework/data/mapping/callback/DefaultEntityCallbacks.java +++ b/src/main/java/org/springframework/data/mapping/callback/DefaultEntityCallbacks.java @@ -22,6 +22,7 @@ import java.util.function.BiFunction; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.springframework.beans.factory.BeanFactory; +import org.springframework.context.support.GenericApplicationContext; import org.springframework.core.ResolvableType; import org.springframework.util.Assert; import org.springframework.util.ClassUtils; @@ -57,7 +58,8 @@ class DefaultEntityCallbacks implements EntityCallbacks { * @param beanFactory must not be {@literal null}. */ DefaultEntityCallbacks(BeanFactory beanFactory) { - this.callbackDiscoverer = new EntityCallbackDiscoverer(beanFactory); + this.callbackDiscoverer = new EntityCallbackDiscoverer( + beanFactory instanceof GenericApplicationContext ac ? ac.getBeanFactory() : beanFactory); } @Override @@ -93,8 +95,7 @@ class DefaultEntityCallbacks implements EntityCallbacks { this.callbackDiscoverer.addEntityCallback(callback); } - static class SimpleEntityCallbackInvoker - implements org.springframework.data.mapping.callback.EntityCallbackInvoker { + static class SimpleEntityCallbackInvoker implements org.springframework.data.mapping.callback.EntityCallbackInvoker { @Override public T invokeCallback(EntityCallback callback, T entity, diff --git a/src/main/java/org/springframework/data/mapping/callback/EntityCallbackDiscoverer.java b/src/main/java/org/springframework/data/mapping/callback/EntityCallbackDiscoverer.java index f20ddbd5d..91175e77a 100644 --- a/src/main/java/org/springframework/data/mapping/callback/EntityCallbackDiscoverer.java +++ b/src/main/java/org/springframework/data/mapping/callback/EntityCallbackDiscoverer.java @@ -19,6 +19,7 @@ import java.lang.reflect.Method; import java.lang.reflect.Modifier; import java.util.ArrayList; import java.util.Collection; +import java.util.Comparator; import java.util.LinkedHashSet; import java.util.List; import java.util.Map; @@ -28,10 +29,9 @@ import java.util.function.BiFunction; import org.springframework.aop.framework.AopProxyUtils; import org.springframework.beans.factory.BeanFactory; -import org.springframework.beans.factory.ListableBeanFactory; import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.config.ConfigurableBeanFactory; -import org.springframework.beans.factory.support.BeanDefinitionRegistry; +import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; import org.springframework.core.ResolvableType; import org.springframework.core.annotation.AnnotationAwareOrderComparator; import org.springframework.lang.Nullable; @@ -56,7 +56,6 @@ class EntityCallbackDiscoverer { private final Map, ResolvableType> entityTypeCache = new ConcurrentReferenceHashMap<>(64); @Nullable private ClassLoader beanClassLoader; - @Nullable private BeanFactory beanFactory; private Object retrievalMutex = this.defaultRetriever; @@ -104,12 +103,13 @@ class EntityCallbackDiscoverer { * Return a {@link Collection} of all {@link EntityCallback}s matching the given entity type. Non-matching callbacks * get excluded early. * - * @param entity the entity to be called back for. Allows for excluding non-matching callbacks early, based on - * cached matching information. + * @param entity the entity to be called back for. Allows for excluding non-matching callbacks early, based on cached + * matching information. * @param callbackType the source callback type. * @return a {@link Collection} of {@link EntityCallback}s. * @see EntityCallback */ + @SuppressWarnings({ "unchecked", "rawtypes" }) Collection> getEntityCallbacks(Class entity, ResolvableType callbackType) { Class sourceType = entity; @@ -121,7 +121,7 @@ class EntityCallbackDiscoverer { return (Collection) retriever.getEntityCallbacks(); } - if (this.beanClassLoader == null || ClassUtils.isCacheSafe(entity.getClass(), this.beanClassLoader) + if (this.beanClassLoader == null || ClassUtils.isCacheSafe(entity, this.beanClassLoader) && (sourceType == null || ClassUtils.isCacheSafe(sourceType, this.beanClassLoader))) { // Fully synchronized building and caching of a CallbackRetriever @@ -163,8 +163,8 @@ class EntityCallbackDiscoverer { * @param retriever the {@link CallbackRetriever}, if supposed to populate one (for caching purposes) * @return the pre-filtered list of entity callbacks for the given entity and callback type. */ - private Collection> retrieveEntityCallbacks(ResolvableType entityType, - ResolvableType callbackType, @Nullable CallbackRetriever retriever) { + private Collection> retrieveEntityCallbacks(ResolvableType entityType, ResolvableType callbackType, + @Nullable CallbackRetriever retriever) { List> allCallbacks = new ArrayList<>(); Set> callbacks; @@ -198,16 +198,14 @@ class EntityCallbackDiscoverer { } /** - * Set the {@link BeanFactory} and optionally {@link #setBeanClassLoader(ClassLoader) class loader} if not set. - * Pre-loads {@link EntityCallback} beans by scanning the {@link BeanFactory}. + * Set the {@link BeanFactory} and optionally class loader if not set. Pre-loads {@link EntityCallback} beans by + * scanning the {@link BeanFactory}. * * @param beanFactory must not be {@literal null}. * @see org.springframework.beans.factory.BeanFactoryAware#setBeanFactory(BeanFactory) */ public void setBeanFactory(BeanFactory beanFactory) { - this.beanFactory = beanFactory; - if (beanFactory instanceof ConfigurableBeanFactory cbf) { if (this.beanClassLoader == null) { @@ -228,10 +226,8 @@ class EntityCallbackDiscoverer { ReflectionUtils.doWithMethods(callbackType, methods::add, method -> { - if (!Modifier.isPublic(method.getModifiers()) - || method.getParameterCount() != args.length + 1 - || method.isBridge() - || ReflectionUtils.isObjectMethod(method)) { + if (!Modifier.isPublic(method.getModifiers()) || method.getParameterCount() != args.length + 1 + || method.isBridge() || ReflectionUtils.isObjectMethod(method)) { return false; } @@ -242,9 +238,8 @@ class EntityCallbackDiscoverer { return methods.iterator().next(); } - throw new IllegalStateException( - "%s does not define a callback method accepting %s and %s additional arguments".formatted( - ClassUtils.getShortName(callbackType), ClassUtils.getShortName(entityType), args.length)); + throw new IllegalStateException("%s does not define a callback method accepting %s and %s additional arguments" + .formatted(ClassUtils.getShortName(callbackType), ClassUtils.getShortName(entityType), args.length)); } static BiFunction, T, Object> computeCallbackInvokerFunction(EntityCallback callback, @@ -267,10 +262,10 @@ class EntityCallbackDiscoverer { * Filter a callback early through checking its generically declared entity type before trying to instantiate it. *

* If this method returns {@literal true} for a given callback as a first pass, the callback instance will get - * retrieved and fully evaluated through a {@link #supportsEvent(EntityCallback, ResolvableType, ResolvableType)} - * call afterwards. + * retrieved and fully evaluated through a {@link #supportsEvent(EntityCallback, ResolvableType, ResolvableType)} call + * afterwards. * - * @param callback the callback's type as determined by the BeanFactory. + * @param callbackType the callback's type as determined by the BeanFactory. * @param entityType the entity type to check. * @return whether the given callback should be included in the candidates for the given callback type. */ @@ -286,11 +281,9 @@ class EntityCallbackDiscoverer { * @param callbackType the source type to check against. * @return whether the given callback should be included in the candidates for the given callback type. */ - static boolean supportsEvent(EntityCallback callback, ResolvableType entityType, - ResolvableType callbackType) { + static boolean supportsEvent(EntityCallback callback, ResolvableType entityType, ResolvableType callbackType) { - return callback instanceof EntityCallbackAdapter provider - ? provider.supports(callbackType, entityType) + return callback instanceof EntityCallbackAdapter provider ? provider.supports(callbackType, entityType) : callbackType.isInstance(callback) && supportsEvent(ResolvableType.forInstance(callback), entityType); } @@ -310,13 +303,11 @@ class EntityCallbackDiscoverer { // We need both a ListableBeanFactory and BeanDefinitionRegistry here for advanced inspection. // If we don't get that, use simple inspection. - if (!(beanFactory instanceof ListableBeanFactory && beanFactory instanceof BeanDefinitionRegistry)) { + if (!(beanFactory instanceof ConfigurableListableBeanFactory bf)) { beanFactory.getBeanProvider(EntityCallback.class).stream().forEach(entityCallbacks::add); return; } - var bf = (ListableBeanFactory & BeanDefinitionRegistry) beanFactory; - for (var beanName : bf.getBeanNamesForType(EntityCallback.class)) { EntityCallback bean = EntityCallback.class.cast(bf.getBean(beanName)); @@ -328,7 +319,7 @@ class EntityCallbackDiscoverer { entityCallbacks.add(bean); } else { - BeanDefinition definition = bf.getBeanDefinition(beanName); + BeanDefinition definition = bf.getMergedBeanDefinition(beanName); entityCallbacks.add(new EntityCallbackAdapter<>(bean, definition.getResolvableType())); } } @@ -340,8 +331,8 @@ class EntityCallbackDiscoverer { * * @author Oliver Drotbohm */ - private static record EntityCallbackAdapter(EntityCallback delegate, ResolvableType type) - implements EntityCallback { + private record EntityCallbackAdapter (EntityCallback delegate, + ResolvableType type) implements EntityCallback { boolean supports(ResolvableType callbackType, ResolvableType entityType) { return callbackType.isInstance(delegate) && supportsEvent(type, entityType); @@ -351,15 +342,16 @@ class EntityCallbackDiscoverer { /** * Cache key for {@link EntityCallback}, based on event type and source type. */ - private static record CallbackCacheKey(ResolvableType callbackType, @Nullable Class entityType) - implements Comparable { + private record CallbackCacheKey(ResolvableType callbackType, + @Nullable Class entityType) implements Comparable { + + private static final Comparator COMPARATOR = Comparators. nullsHigh() // + .thenComparing(it -> it.callbackType.toString()) // + .thenComparing(it -> it.entityType.getName()); @Override public int compareTo(CallbackCacheKey other) { - - return Comparators. nullsHigh() - .thenComparing(it -> callbackType.toString()) - .thenComparing(it -> entityType.getName()).compare(this, other); + return COMPARATOR.compare(this, other); } } diff --git a/src/test/java/org/springframework/data/mapping/callback/DefaultEntityCallbacksUnitTests.java b/src/test/java/org/springframework/data/mapping/callback/DefaultEntityCallbacksUnitTests.java index 1c54ee2a1..7eaaaf6ed 100644 --- a/src/test/java/org/springframework/data/mapping/callback/DefaultEntityCallbacksUnitTests.java +++ b/src/test/java/org/springframework/data/mapping/callback/DefaultEntityCallbacksUnitTests.java @@ -170,7 +170,6 @@ class DefaultEntityCallbacksUnitTests { void detectsMultipleCallbacksWithinOneClass() { var ctx = new AnnotationConfigApplicationContext(MultipleCallbacksInOneClassConfig.class); - var callbacks = new DefaultEntityCallbacks(ctx); var personDocument = new PersonDocument(null, "Walter", null);