diff --git a/src/main/java/org/springframework/data/mapping/callback/EntityCallbackDiscoverer.java b/src/main/java/org/springframework/data/mapping/callback/EntityCallbackDiscoverer.java index 6ac23fd09..d92233eca 100644 --- a/src/main/java/org/springframework/data/mapping/callback/EntityCallbackDiscoverer.java +++ b/src/main/java/org/springframework/data/mapping/callback/EntityCallbackDiscoverer.java @@ -19,6 +19,7 @@ import java.lang.reflect.Method; import java.lang.reflect.Modifier; import java.util.ArrayList; import java.util.Collection; +import java.util.Collections; import java.util.LinkedHashSet; import java.util.List; import java.util.Map; @@ -186,7 +187,7 @@ class EntityCallbackDiscoverer { private Collection> retrieveEntityCallbacks(ResolvableType entityType, ResolvableType callbackType, @Nullable CallbackRetriever retriever) { - List> allCallbacks = new ArrayList<>(); + List> allCallbacks = null; Set> callbacks; Set callbackBeans; @@ -197,8 +198,9 @@ class EntityCallbackDiscoverer { for (EntityCallback callback : callbacks) { if (supportsEvent(callback, entityType, callbackType)) { - if (retriever != null) { - retriever.getEntityCallbacks().add(callback); + + if (allCallbacks == null) { + allCallbacks = new ArrayList<>(); } allCallbacks.add(callback); } @@ -211,7 +213,9 @@ class EntityCallbackDiscoverer { Class callbackImplType = beanFactory.getType(callbackBeanName); if (callbackImplType == null || supportsEvent(callbackImplType, entityType)) { EntityCallback callback = beanFactory.getBean(callbackBeanName, EntityCallback.class); - if (!allCallbacks.contains(callback) && supportsEvent(callback, entityType, callbackType)) { + + if ((allCallbacks == null || !allCallbacks.contains(callback)) + && supportsEvent(callback, entityType, callbackType)) { if (retriever != null) { if (beanFactory.isSingleton(callbackBeanName)) { retriever.entityCallbacks.add(callback); @@ -219,6 +223,10 @@ class EntityCallbackDiscoverer { retriever.entityCallbackBeans.add(callbackBeanName); } } + + if (allCallbacks == null) { + allCallbacks = new ArrayList<>(); + } allCallbacks.add(callback); } } @@ -229,6 +237,10 @@ class EntityCallbackDiscoverer { } } + if (allCallbacks == null) { + return Collections.emptyList(); + } + AnnotationAwareOrderComparator.sort(allCallbacks); if (retriever != null && retriever.entityCallbackBeans.isEmpty()) { @@ -272,7 +284,7 @@ class EntityCallbackDiscoverer { /** * (non-Javadoc) - * + * * @see org.springframework.beans.factory.BeanClassLoaderAware */ public void setBeanClassLoader(ClassLoader classLoader) { @@ -356,6 +368,8 @@ class EntityCallbackDiscoverer { private final Set> entityCallbacks = new LinkedHashSet<>(); + private final List> cachedEntityCallbacks = new ArrayList<>(); + private final Set entityCallbackBeans = new LinkedHashSet<>(); private final boolean preFiltered; @@ -366,6 +380,17 @@ class EntityCallbackDiscoverer { Collection> getEntityCallbacks() { + if (this.entityCallbackBeans.isEmpty()) { + + if (cachedEntityCallbacks.size() != entityCallbacks.size()) { + cachedEntityCallbacks.clear(); + cachedEntityCallbacks.addAll(entityCallbacks); + AnnotationAwareOrderComparator.sort(cachedEntityCallbacks); + } + + return cachedEntityCallbacks; + } + List> allCallbacks = new ArrayList<>( this.entityCallbacks.size() + this.entityCallbackBeans.size()); allCallbacks.addAll(this.entityCallbacks);