diff --git a/spring-core/src/main/java/org/springframework/core/task/SimpleAsyncTaskExecutor.java b/spring-core/src/main/java/org/springframework/core/task/SimpleAsyncTaskExecutor.java index 54e47b8e7e7..db208a6975c 100644 --- a/spring-core/src/main/java/org/springframework/core/task/SimpleAsyncTaskExecutor.java +++ b/spring-core/src/main/java/org/springframework/core/task/SimpleAsyncTaskExecutor.java @@ -98,6 +98,8 @@ public class SimpleAsyncTaskExecutor extends CustomizableThreadCreator private volatile boolean active = true; + private volatile boolean cancelled = false; + /** * Create a new SimpleAsyncTaskExecutor with default thread name prefix. @@ -406,6 +408,7 @@ public class SimpleAsyncTaskExecutor extends CustomizableThreadCreator Set threads = this.activeThreads; if (threads != null) { if (this.cancelRemainingTasksOnClose) { + this.cancelled = true; // Early interrupt for remaining tasks on close threads.forEach(Thread::interrupt); } @@ -419,6 +422,7 @@ public class SimpleAsyncTaskExecutor extends CustomizableThreadCreator catch (InterruptedException ex) { Thread.currentThread().interrupt(); } + this.cancelled = true; } if (!this.cancelRemainingTasksOnClose) { // Late interrupt for remaining tasks after timeout @@ -429,6 +433,12 @@ public class SimpleAsyncTaskExecutor extends CustomizableThreadCreator } } + private void checkCancelled() { + if (this.cancelled) { + throw new TaskRejectedException(getClass().getSimpleName() + " has cancelled all remaining tasks"); + } + } + /** * Subclass of the general ConcurrencyThrottleSupport class, @@ -476,16 +486,27 @@ public class SimpleAsyncTaskExecutor extends CustomizableThreadCreator Thread thread = null; if (threads != null) { thread = Thread.currentThread(); - threads.add(thread); + if (isActive()) { + threads.add(thread); + } + else { + synchronized (threads) { + checkCancelled(); + threads.add(thread); + } + } } try { this.task.run(); } finally { if (threads != null) { - threads.remove(thread); - if (!isActive()) { + if (isActive()) { + threads.remove(thread); + } + else { synchronized (threads) { + threads.remove(thread); if (threads.isEmpty()) { threads.notify(); }