Browse Source

Use merged bean definitions for EntityCallback type lookup.

We now use the merged bean definition to resolve the defined EntityCallback type.

Previously, we used just the bean definition that might have contained no type hints because of ASM-parsed configuration classes.

Closes #2853
pull/2860/head
Mark Paluch 3 years ago
parent
commit
f1b7952ea5
No known key found for this signature in database
GPG Key ID: 4406B84C1661DCD1
  1. 7
      src/main/java/org/springframework/data/mapping/callback/DefaultEntityCallbacks.java
  2. 68
      src/main/java/org/springframework/data/mapping/callback/EntityCallbackDiscoverer.java
  3. 1
      src/test/java/org/springframework/data/mapping/callback/DefaultEntityCallbacksUnitTests.java

7
src/main/java/org/springframework/data/mapping/callback/DefaultEntityCallbacks.java

@ -22,6 +22,7 @@ import java.util.function.BiFunction; @@ -22,6 +22,7 @@ import java.util.function.BiFunction;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.beans.factory.BeanFactory;
import org.springframework.context.support.GenericApplicationContext;
import org.springframework.core.ResolvableType;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
@ -57,7 +58,8 @@ class DefaultEntityCallbacks implements EntityCallbacks { @@ -57,7 +58,8 @@ class DefaultEntityCallbacks implements EntityCallbacks {
* @param beanFactory must not be {@literal null}.
*/
DefaultEntityCallbacks(BeanFactory beanFactory) {
this.callbackDiscoverer = new EntityCallbackDiscoverer(beanFactory);
this.callbackDiscoverer = new EntityCallbackDiscoverer(
beanFactory instanceof GenericApplicationContext ac ? ac.getBeanFactory() : beanFactory);
}
@Override
@ -93,8 +95,7 @@ class DefaultEntityCallbacks implements EntityCallbacks { @@ -93,8 +95,7 @@ class DefaultEntityCallbacks implements EntityCallbacks {
this.callbackDiscoverer.addEntityCallback(callback);
}
static class SimpleEntityCallbackInvoker
implements org.springframework.data.mapping.callback.EntityCallbackInvoker {
static class SimpleEntityCallbackInvoker implements org.springframework.data.mapping.callback.EntityCallbackInvoker {
@Override
public <T> T invokeCallback(EntityCallback<T> callback, T entity,

68
src/main/java/org/springframework/data/mapping/callback/EntityCallbackDiscoverer.java

@ -19,6 +19,7 @@ import java.lang.reflect.Method; @@ -19,6 +19,7 @@ import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
@ -28,10 +29,9 @@ import java.util.function.BiFunction; @@ -28,10 +29,9 @@ import java.util.function.BiFunction;
import org.springframework.aop.framework.AopProxyUtils;
import org.springframework.beans.factory.BeanFactory;
import org.springframework.beans.factory.ListableBeanFactory;
import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.config.ConfigurableBeanFactory;
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.core.ResolvableType;
import org.springframework.core.annotation.AnnotationAwareOrderComparator;
import org.springframework.lang.Nullable;
@ -56,7 +56,6 @@ class EntityCallbackDiscoverer { @@ -56,7 +56,6 @@ class EntityCallbackDiscoverer {
private final Map<Class<?>, ResolvableType> entityTypeCache = new ConcurrentReferenceHashMap<>(64);
@Nullable private ClassLoader beanClassLoader;
@Nullable private BeanFactory beanFactory;
private Object retrievalMutex = this.defaultRetriever;
@ -104,12 +103,13 @@ class EntityCallbackDiscoverer { @@ -104,12 +103,13 @@ class EntityCallbackDiscoverer {
* Return a {@link Collection} of all {@link EntityCallback}s matching the given entity type. Non-matching callbacks
* get excluded early.
*
* @param entity the entity to be called back for. Allows for excluding non-matching callbacks early, based on
* cached matching information.
* @param entity the entity to be called back for. Allows for excluding non-matching callbacks early, based on cached
* matching information.
* @param callbackType the source callback type.
* @return a {@link Collection} of {@link EntityCallback}s.
* @see EntityCallback
*/
@SuppressWarnings({ "unchecked", "rawtypes" })
<T extends S, S> Collection<EntityCallback<S>> getEntityCallbacks(Class<T> entity, ResolvableType callbackType) {
Class<?> sourceType = entity;
@ -121,7 +121,7 @@ class EntityCallbackDiscoverer { @@ -121,7 +121,7 @@ class EntityCallbackDiscoverer {
return (Collection) retriever.getEntityCallbacks();
}
if (this.beanClassLoader == null || ClassUtils.isCacheSafe(entity.getClass(), this.beanClassLoader)
if (this.beanClassLoader == null || ClassUtils.isCacheSafe(entity, this.beanClassLoader)
&& (sourceType == null || ClassUtils.isCacheSafe(sourceType, this.beanClassLoader))) {
// Fully synchronized building and caching of a CallbackRetriever
@ -163,8 +163,8 @@ class EntityCallbackDiscoverer { @@ -163,8 +163,8 @@ class EntityCallbackDiscoverer {
* @param retriever the {@link CallbackRetriever}, if supposed to populate one (for caching purposes)
* @return the pre-filtered list of entity callbacks for the given entity and callback type.
*/
private Collection<EntityCallback<?>> retrieveEntityCallbacks(ResolvableType entityType,
ResolvableType callbackType, @Nullable CallbackRetriever retriever) {
private Collection<EntityCallback<?>> retrieveEntityCallbacks(ResolvableType entityType, ResolvableType callbackType,
@Nullable CallbackRetriever retriever) {
List<EntityCallback<?>> allCallbacks = new ArrayList<>();
Set<EntityCallback<?>> callbacks;
@ -198,16 +198,14 @@ class EntityCallbackDiscoverer { @@ -198,16 +198,14 @@ class EntityCallbackDiscoverer {
}
/**
* Set the {@link BeanFactory} and optionally {@link #setBeanClassLoader(ClassLoader) class loader} if not set.
* Pre-loads {@link EntityCallback} beans by scanning the {@link BeanFactory}.
* Set the {@link BeanFactory} and optionally class loader if not set. Pre-loads {@link EntityCallback} beans by
* scanning the {@link BeanFactory}.
*
* @param beanFactory must not be {@literal null}.
* @see org.springframework.beans.factory.BeanFactoryAware#setBeanFactory(BeanFactory)
*/
public void setBeanFactory(BeanFactory beanFactory) {
this.beanFactory = beanFactory;
if (beanFactory instanceof ConfigurableBeanFactory cbf) {
if (this.beanClassLoader == null) {
@ -228,10 +226,8 @@ class EntityCallbackDiscoverer { @@ -228,10 +226,8 @@ class EntityCallbackDiscoverer {
ReflectionUtils.doWithMethods(callbackType, methods::add, method -> {
if (!Modifier.isPublic(method.getModifiers())
|| method.getParameterCount() != args.length + 1
|| method.isBridge()
|| ReflectionUtils.isObjectMethod(method)) {
if (!Modifier.isPublic(method.getModifiers()) || method.getParameterCount() != args.length + 1
|| method.isBridge() || ReflectionUtils.isObjectMethod(method)) {
return false;
}
@ -242,9 +238,8 @@ class EntityCallbackDiscoverer { @@ -242,9 +238,8 @@ class EntityCallbackDiscoverer {
return methods.iterator().next();
}
throw new IllegalStateException(
"%s does not define a callback method accepting %s and %s additional arguments".formatted(
ClassUtils.getShortName(callbackType), ClassUtils.getShortName(entityType), args.length));
throw new IllegalStateException("%s does not define a callback method accepting %s and %s additional arguments"
.formatted(ClassUtils.getShortName(callbackType), ClassUtils.getShortName(entityType), args.length));
}
static <T> BiFunction<EntityCallback<T>, T, Object> computeCallbackInvokerFunction(EntityCallback<T> callback,
@ -267,10 +262,10 @@ class EntityCallbackDiscoverer { @@ -267,10 +262,10 @@ class EntityCallbackDiscoverer {
* Filter a callback early through checking its generically declared entity type before trying to instantiate it.
* <p>
* If this method returns {@literal true} for a given callback as a first pass, the callback instance will get
* retrieved and fully evaluated through a {@link #supportsEvent(EntityCallback, ResolvableType, ResolvableType)}
* call afterwards.
* retrieved and fully evaluated through a {@link #supportsEvent(EntityCallback, ResolvableType, ResolvableType)} call
* afterwards.
*
* @param callback the callback's type as determined by the BeanFactory.
* @param callbackType the callback's type as determined by the BeanFactory.
* @param entityType the entity type to check.
* @return whether the given callback should be included in the candidates for the given callback type.
*/
@ -286,11 +281,9 @@ class EntityCallbackDiscoverer { @@ -286,11 +281,9 @@ class EntityCallbackDiscoverer {
* @param callbackType the source type to check against.
* @return whether the given callback should be included in the candidates for the given callback type.
*/
static boolean supportsEvent(EntityCallback<?> callback, ResolvableType entityType,
ResolvableType callbackType) {
static boolean supportsEvent(EntityCallback<?> callback, ResolvableType entityType, ResolvableType callbackType) {
return callback instanceof EntityCallbackAdapter<?> provider
? provider.supports(callbackType, entityType)
return callback instanceof EntityCallbackAdapter<?> provider ? provider.supports(callbackType, entityType)
: callbackType.isInstance(callback) && supportsEvent(ResolvableType.forInstance(callback), entityType);
}
@ -310,13 +303,11 @@ class EntityCallbackDiscoverer { @@ -310,13 +303,11 @@ class EntityCallbackDiscoverer {
// We need both a ListableBeanFactory and BeanDefinitionRegistry here for advanced inspection.
// If we don't get that, use simple inspection.
if (!(beanFactory instanceof ListableBeanFactory && beanFactory instanceof BeanDefinitionRegistry)) {
if (!(beanFactory instanceof ConfigurableListableBeanFactory bf)) {
beanFactory.getBeanProvider(EntityCallback.class).stream().forEach(entityCallbacks::add);
return;
}
var bf = (ListableBeanFactory & BeanDefinitionRegistry) beanFactory;
for (var beanName : bf.getBeanNamesForType(EntityCallback.class)) {
EntityCallback<?> bean = EntityCallback.class.cast(bf.getBean(beanName));
@ -328,7 +319,7 @@ class EntityCallbackDiscoverer { @@ -328,7 +319,7 @@ class EntityCallbackDiscoverer {
entityCallbacks.add(bean);
} else {
BeanDefinition definition = bf.getBeanDefinition(beanName);
BeanDefinition definition = bf.getMergedBeanDefinition(beanName);
entityCallbacks.add(new EntityCallbackAdapter<>(bean, definition.getResolvableType()));
}
}
@ -340,8 +331,8 @@ class EntityCallbackDiscoverer { @@ -340,8 +331,8 @@ class EntityCallbackDiscoverer {
*
* @author Oliver Drotbohm
*/
private static record EntityCallbackAdapter<T>(EntityCallback<T> delegate, ResolvableType type)
implements EntityCallback<T> {
private record EntityCallbackAdapter<T> (EntityCallback<T> delegate,
ResolvableType type) implements EntityCallback<T> {
boolean supports(ResolvableType callbackType, ResolvableType entityType) {
return callbackType.isInstance(delegate) && supportsEvent(type, entityType);
@ -351,15 +342,16 @@ class EntityCallbackDiscoverer { @@ -351,15 +342,16 @@ class EntityCallbackDiscoverer {
/**
* Cache key for {@link EntityCallback}, based on event type and source type.
*/
private static record CallbackCacheKey(ResolvableType callbackType, @Nullable Class<?> entityType)
implements Comparable<CallbackCacheKey> {
private record CallbackCacheKey(ResolvableType callbackType,
@Nullable Class<?> entityType) implements Comparable<CallbackCacheKey> {
private static final Comparator<CallbackCacheKey> COMPARATOR = Comparators.<CallbackCacheKey> nullsHigh() //
.thenComparing(it -> it.callbackType.toString()) //
.thenComparing(it -> it.entityType.getName());
@Override
public int compareTo(CallbackCacheKey other) {
return Comparators.<CallbackCacheKey> nullsHigh()
.thenComparing(it -> callbackType.toString())
.thenComparing(it -> entityType.getName()).compare(this, other);
return COMPARATOR.compare(this, other);
}
}

1
src/test/java/org/springframework/data/mapping/callback/DefaultEntityCallbacksUnitTests.java

@ -170,7 +170,6 @@ class DefaultEntityCallbacksUnitTests { @@ -170,7 +170,6 @@ class DefaultEntityCallbacksUnitTests {
void detectsMultipleCallbacksWithinOneClass() {
var ctx = new AnnotationConfigApplicationContext(MultipleCallbacksInOneClassConfig.class);
var callbacks = new DefaultEntityCallbacks(ctx);
var personDocument = new PersonDocument(null, "Walter", null);

Loading…
Cancel
Save