From f1b7952ea57e1a737141ea68ec6e28dc30af71b4 Mon Sep 17 00:00:00 2001 From: Mark Paluch Date: Wed, 14 Jun 2023 09:03:39 +0200 Subject: [PATCH] Use merged bean definitions for EntityCallback type lookup. We now use the merged bean definition to resolve the defined EntityCallback type. Previously, we used just the bean definition that might have contained no type hints because of ASM-parsed configuration classes. Closes #2853 --- .../callback/DefaultEntityCallbacks.java | 7 +- .../callback/EntityCallbackDiscoverer.java | 68 ++++++++----------- .../DefaultEntityCallbacksUnitTests.java | 1 - 3 files changed, 34 insertions(+), 42 deletions(-) diff --git a/src/main/java/org/springframework/data/mapping/callback/DefaultEntityCallbacks.java b/src/main/java/org/springframework/data/mapping/callback/DefaultEntityCallbacks.java index b844aa474..d469c0a74 100644 --- a/src/main/java/org/springframework/data/mapping/callback/DefaultEntityCallbacks.java +++ b/src/main/java/org/springframework/data/mapping/callback/DefaultEntityCallbacks.java @@ -22,6 +22,7 @@ import java.util.function.BiFunction; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.springframework.beans.factory.BeanFactory; +import org.springframework.context.support.GenericApplicationContext; import org.springframework.core.ResolvableType; import org.springframework.util.Assert; import org.springframework.util.ClassUtils; @@ -57,7 +58,8 @@ class DefaultEntityCallbacks implements EntityCallbacks { * @param beanFactory must not be {@literal null}. */ DefaultEntityCallbacks(BeanFactory beanFactory) { - this.callbackDiscoverer = new EntityCallbackDiscoverer(beanFactory); + this.callbackDiscoverer = new EntityCallbackDiscoverer( + beanFactory instanceof GenericApplicationContext ac ? ac.getBeanFactory() : beanFactory); } @Override @@ -93,8 +95,7 @@ class DefaultEntityCallbacks implements EntityCallbacks { this.callbackDiscoverer.addEntityCallback(callback); } - static class SimpleEntityCallbackInvoker - implements org.springframework.data.mapping.callback.EntityCallbackInvoker { + static class SimpleEntityCallbackInvoker implements org.springframework.data.mapping.callback.EntityCallbackInvoker { @Override public T invokeCallback(EntityCallback callback, T entity, diff --git a/src/main/java/org/springframework/data/mapping/callback/EntityCallbackDiscoverer.java b/src/main/java/org/springframework/data/mapping/callback/EntityCallbackDiscoverer.java index f20ddbd5d..91175e77a 100644 --- a/src/main/java/org/springframework/data/mapping/callback/EntityCallbackDiscoverer.java +++ b/src/main/java/org/springframework/data/mapping/callback/EntityCallbackDiscoverer.java @@ -19,6 +19,7 @@ import java.lang.reflect.Method; import java.lang.reflect.Modifier; import java.util.ArrayList; import java.util.Collection; +import java.util.Comparator; import java.util.LinkedHashSet; import java.util.List; import java.util.Map; @@ -28,10 +29,9 @@ import java.util.function.BiFunction; import org.springframework.aop.framework.AopProxyUtils; import org.springframework.beans.factory.BeanFactory; -import org.springframework.beans.factory.ListableBeanFactory; import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.config.ConfigurableBeanFactory; -import org.springframework.beans.factory.support.BeanDefinitionRegistry; +import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; import org.springframework.core.ResolvableType; import org.springframework.core.annotation.AnnotationAwareOrderComparator; import org.springframework.lang.Nullable; @@ -56,7 +56,6 @@ class EntityCallbackDiscoverer { private final Map, ResolvableType> entityTypeCache = new ConcurrentReferenceHashMap<>(64); @Nullable private ClassLoader beanClassLoader; - @Nullable private BeanFactory beanFactory; private Object retrievalMutex = this.defaultRetriever; @@ -104,12 +103,13 @@ class EntityCallbackDiscoverer { * Return a {@link Collection} of all {@link EntityCallback}s matching the given entity type. Non-matching callbacks * get excluded early. * - * @param entity the entity to be called back for. Allows for excluding non-matching callbacks early, based on - * cached matching information. + * @param entity the entity to be called back for. Allows for excluding non-matching callbacks early, based on cached + * matching information. * @param callbackType the source callback type. * @return a {@link Collection} of {@link EntityCallback}s. * @see EntityCallback */ + @SuppressWarnings({ "unchecked", "rawtypes" }) Collection> getEntityCallbacks(Class entity, ResolvableType callbackType) { Class sourceType = entity; @@ -121,7 +121,7 @@ class EntityCallbackDiscoverer { return (Collection) retriever.getEntityCallbacks(); } - if (this.beanClassLoader == null || ClassUtils.isCacheSafe(entity.getClass(), this.beanClassLoader) + if (this.beanClassLoader == null || ClassUtils.isCacheSafe(entity, this.beanClassLoader) && (sourceType == null || ClassUtils.isCacheSafe(sourceType, this.beanClassLoader))) { // Fully synchronized building and caching of a CallbackRetriever @@ -163,8 +163,8 @@ class EntityCallbackDiscoverer { * @param retriever the {@link CallbackRetriever}, if supposed to populate one (for caching purposes) * @return the pre-filtered list of entity callbacks for the given entity and callback type. */ - private Collection> retrieveEntityCallbacks(ResolvableType entityType, - ResolvableType callbackType, @Nullable CallbackRetriever retriever) { + private Collection> retrieveEntityCallbacks(ResolvableType entityType, ResolvableType callbackType, + @Nullable CallbackRetriever retriever) { List> allCallbacks = new ArrayList<>(); Set> callbacks; @@ -198,16 +198,14 @@ class EntityCallbackDiscoverer { } /** - * Set the {@link BeanFactory} and optionally {@link #setBeanClassLoader(ClassLoader) class loader} if not set. - * Pre-loads {@link EntityCallback} beans by scanning the {@link BeanFactory}. + * Set the {@link BeanFactory} and optionally class loader if not set. Pre-loads {@link EntityCallback} beans by + * scanning the {@link BeanFactory}. * * @param beanFactory must not be {@literal null}. * @see org.springframework.beans.factory.BeanFactoryAware#setBeanFactory(BeanFactory) */ public void setBeanFactory(BeanFactory beanFactory) { - this.beanFactory = beanFactory; - if (beanFactory instanceof ConfigurableBeanFactory cbf) { if (this.beanClassLoader == null) { @@ -228,10 +226,8 @@ class EntityCallbackDiscoverer { ReflectionUtils.doWithMethods(callbackType, methods::add, method -> { - if (!Modifier.isPublic(method.getModifiers()) - || method.getParameterCount() != args.length + 1 - || method.isBridge() - || ReflectionUtils.isObjectMethod(method)) { + if (!Modifier.isPublic(method.getModifiers()) || method.getParameterCount() != args.length + 1 + || method.isBridge() || ReflectionUtils.isObjectMethod(method)) { return false; } @@ -242,9 +238,8 @@ class EntityCallbackDiscoverer { return methods.iterator().next(); } - throw new IllegalStateException( - "%s does not define a callback method accepting %s and %s additional arguments".formatted( - ClassUtils.getShortName(callbackType), ClassUtils.getShortName(entityType), args.length)); + throw new IllegalStateException("%s does not define a callback method accepting %s and %s additional arguments" + .formatted(ClassUtils.getShortName(callbackType), ClassUtils.getShortName(entityType), args.length)); } static BiFunction, T, Object> computeCallbackInvokerFunction(EntityCallback callback, @@ -267,10 +262,10 @@ class EntityCallbackDiscoverer { * Filter a callback early through checking its generically declared entity type before trying to instantiate it. *

* If this method returns {@literal true} for a given callback as a first pass, the callback instance will get - * retrieved and fully evaluated through a {@link #supportsEvent(EntityCallback, ResolvableType, ResolvableType)} - * call afterwards. + * retrieved and fully evaluated through a {@link #supportsEvent(EntityCallback, ResolvableType, ResolvableType)} call + * afterwards. * - * @param callback the callback's type as determined by the BeanFactory. + * @param callbackType the callback's type as determined by the BeanFactory. * @param entityType the entity type to check. * @return whether the given callback should be included in the candidates for the given callback type. */ @@ -286,11 +281,9 @@ class EntityCallbackDiscoverer { * @param callbackType the source type to check against. * @return whether the given callback should be included in the candidates for the given callback type. */ - static boolean supportsEvent(EntityCallback callback, ResolvableType entityType, - ResolvableType callbackType) { + static boolean supportsEvent(EntityCallback callback, ResolvableType entityType, ResolvableType callbackType) { - return callback instanceof EntityCallbackAdapter provider - ? provider.supports(callbackType, entityType) + return callback instanceof EntityCallbackAdapter provider ? provider.supports(callbackType, entityType) : callbackType.isInstance(callback) && supportsEvent(ResolvableType.forInstance(callback), entityType); } @@ -310,13 +303,11 @@ class EntityCallbackDiscoverer { // We need both a ListableBeanFactory and BeanDefinitionRegistry here for advanced inspection. // If we don't get that, use simple inspection. - if (!(beanFactory instanceof ListableBeanFactory && beanFactory instanceof BeanDefinitionRegistry)) { + if (!(beanFactory instanceof ConfigurableListableBeanFactory bf)) { beanFactory.getBeanProvider(EntityCallback.class).stream().forEach(entityCallbacks::add); return; } - var bf = (ListableBeanFactory & BeanDefinitionRegistry) beanFactory; - for (var beanName : bf.getBeanNamesForType(EntityCallback.class)) { EntityCallback bean = EntityCallback.class.cast(bf.getBean(beanName)); @@ -328,7 +319,7 @@ class EntityCallbackDiscoverer { entityCallbacks.add(bean); } else { - BeanDefinition definition = bf.getBeanDefinition(beanName); + BeanDefinition definition = bf.getMergedBeanDefinition(beanName); entityCallbacks.add(new EntityCallbackAdapter<>(bean, definition.getResolvableType())); } } @@ -340,8 +331,8 @@ class EntityCallbackDiscoverer { * * @author Oliver Drotbohm */ - private static record EntityCallbackAdapter(EntityCallback delegate, ResolvableType type) - implements EntityCallback { + private record EntityCallbackAdapter (EntityCallback delegate, + ResolvableType type) implements EntityCallback { boolean supports(ResolvableType callbackType, ResolvableType entityType) { return callbackType.isInstance(delegate) && supportsEvent(type, entityType); @@ -351,15 +342,16 @@ class EntityCallbackDiscoverer { /** * Cache key for {@link EntityCallback}, based on event type and source type. */ - private static record CallbackCacheKey(ResolvableType callbackType, @Nullable Class entityType) - implements Comparable { + private record CallbackCacheKey(ResolvableType callbackType, + @Nullable Class entityType) implements Comparable { + + private static final Comparator COMPARATOR = Comparators. nullsHigh() // + .thenComparing(it -> it.callbackType.toString()) // + .thenComparing(it -> it.entityType.getName()); @Override public int compareTo(CallbackCacheKey other) { - - return Comparators. nullsHigh() - .thenComparing(it -> callbackType.toString()) - .thenComparing(it -> entityType.getName()).compare(this, other); + return COMPARATOR.compare(this, other); } } diff --git a/src/test/java/org/springframework/data/mapping/callback/DefaultEntityCallbacksUnitTests.java b/src/test/java/org/springframework/data/mapping/callback/DefaultEntityCallbacksUnitTests.java index 1c54ee2a1..7eaaaf6ed 100644 --- a/src/test/java/org/springframework/data/mapping/callback/DefaultEntityCallbacksUnitTests.java +++ b/src/test/java/org/springframework/data/mapping/callback/DefaultEntityCallbacksUnitTests.java @@ -170,7 +170,6 @@ class DefaultEntityCallbacksUnitTests { void detectsMultipleCallbacksWithinOneClass() { var ctx = new AnnotationConfigApplicationContext(MultipleCallbacksInOneClassConfig.class); - var callbacks = new DefaultEntityCallbacks(ctx); var personDocument = new PersonDocument(null, "Walter", null);