diff --git a/src/main/java/org/springframework/data/repository/core/support/EventPublishingRepositoryProxyPostProcessor.java b/src/main/java/org/springframework/data/repository/core/support/EventPublishingRepositoryProxyPostProcessor.java index 053230c7a..9036e6beb 100644 --- a/src/main/java/org/springframework/data/repository/core/support/EventPublishingRepositoryProxyPostProcessor.java +++ b/src/main/java/org/springframework/data/repository/core/support/EventPublishingRepositoryProxyPostProcessor.java @@ -17,6 +17,7 @@ package org.springframework.data.repository.core.support; import java.lang.annotation.Annotation; import java.lang.reflect.Method; +import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.Map; @@ -111,7 +112,7 @@ public class EventPublishingRepositoryProxyPostProcessor implements RepositoryPr return result; } - Iterable arguments = asCollection(invocation.getArguments()[0], invocation.getMethod()); + Iterable arguments = asIterable(invocation.getArguments()[0], invocation.getMethod()); eventMethod.publishEventsFrom(arguments, publisher); @@ -144,6 +145,9 @@ public class EventPublishingRepositoryProxyPostProcessor implements RepositoryPr private static Map, EventPublishingMethod> cache = new ConcurrentReferenceHashMap<>(); private static @SuppressWarnings("null") EventPublishingMethod NONE = new EventPublishingMethod(Object.class, null, null); + private static String ILLEGAL_MODIFCATION = "Aggregate's events were modified during event publication. " + + "Make sure event listeners obtain a fresh instance of the aggregate before adding further events. " + + "Additional events found: %s."; private final Class type; private final Method publishingMethod; @@ -188,7 +192,11 @@ public class EventPublishingRepositoryProxyPostProcessor implements RepositoryPr * @param aggregates can be {@literal null}. * @param publisher must not be {@literal null}. */ - public void publishEventsFrom(Iterable aggregates, ApplicationEventPublisher publisher) { + public void publishEventsFrom(@Nullable Iterable aggregates, ApplicationEventPublisher publisher) { + + if (aggregates == null) { + return; + } for (Object aggregateRoot : aggregates) { @@ -196,10 +204,21 @@ public class EventPublishingRepositoryProxyPostProcessor implements RepositoryPr continue; } - for (Object event : asCollection(ReflectionUtils.invokeMethod(publishingMethod, aggregateRoot), null)) { + var events = asCollection(ReflectionUtils.invokeMethod(publishingMethod, aggregateRoot)); + + for (Object event : events) { publisher.publishEvent(event); } + var postPublication = asCollection(ReflectionUtils.invokeMethod(publishingMethod, aggregateRoot)); + + if (events.size() != postPublication.size()) { + + postPublication.removeAll(events); + + throw new IllegalStateException(ILLEGAL_MODIFCATION.formatted(postPublication)); + } + if (clearingMethod != null) { ReflectionUtils.invokeMethod(clearingMethod, aggregateRoot); } @@ -272,23 +291,34 @@ public class EventPublishingRepositoryProxyPostProcessor implements RepositoryPr * one-element collection, {@literal null} will become an empty collection. * * @param source can be {@literal null}. - * @return + * @return will never be {@literal null}. */ @SuppressWarnings("unchecked") - private static Iterable asCollection(@Nullable Object source, @Nullable Method method) { + private static Collection asCollection(@Nullable Object source) { if (source == null) { return Collections.emptyList(); } - if (method != null && method.getName().startsWith("saveAll")) { - return (Iterable) source; - } - if (Collection.class.isInstance(source)) { - return (Collection) source; + return new ArrayList<>((Collection) source); } return Collections.singletonList(source); } + + /** + * Returns the given source object as {@link Iterable}. + * + * @param source can be {@literal null}. + * @return will never be {@literal null}. + */ + @Nullable + @SuppressWarnings("unchecked") + private static Iterable asIterable(@Nullable Object source, @Nullable Method method) { + + return method != null && method.getName().startsWith("saveAll") + ? (Iterable) source + : asCollection(source); + } } diff --git a/src/test/java/org/springframework/data/repository/core/support/EventPublishingRepositoryProxyPostProcessorUnitTests.java b/src/test/java/org/springframework/data/repository/core/support/EventPublishingRepositoryProxyPostProcessorUnitTests.java index 9c5e1efa9..30fbc4bfb 100644 --- a/src/test/java/org/springframework/data/repository/core/support/EventPublishingRepositoryProxyPostProcessorUnitTests.java +++ b/src/test/java/org/springframework/data/repository/core/support/EventPublishingRepositoryProxyPostProcessorUnitTests.java @@ -20,6 +20,7 @@ import static org.mockito.ArgumentMatchers.*; import static org.mockito.Mockito.*; import java.lang.reflect.Method; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.Collections; @@ -324,6 +325,32 @@ class EventPublishingRepositoryProxyPostProcessorUnitTests { verify(publisher, never()).publishEvent(any()); } + @Test // GH-3116 + void rejectsEventAddedDuringProcessing() throws Throwable { + + var originalEvent = new SomeEvent(); + var eventToBeAdded = new SomeEvent(); + + var events = new ArrayList(); + events.add(originalEvent); + + var aggregate = MultipleEvents.of(events); + + doAnswer(invocation -> { + + events.add(eventToBeAdded); + return null; + + }).when(publisher).publishEvent(any(Object.class)); + + var method = EventPublishingMethod.of(MultipleEvents.class); + + assertThatIllegalStateException() + .isThrownBy(() -> method.publishEventsFrom(List.of(aggregate), publisher)) + .withMessageContaining(eventToBeAdded.toString()) + .withMessageNotContaining(originalEvent.toString()); + } + private static void mockInvocation(MethodInvocation invocation, Method method, Object parameterAndReturnValue) throws Throwable {