diff --git a/spring-core/src/main/java/org/springframework/core/task/SimpleAsyncTaskExecutor.java b/spring-core/src/main/java/org/springframework/core/task/SimpleAsyncTaskExecutor.java index 04e89192899..38684155568 100644 --- a/spring-core/src/main/java/org/springframework/core/task/SimpleAsyncTaskExecutor.java +++ b/spring-core/src/main/java/org/springframework/core/task/SimpleAsyncTaskExecutor.java @@ -19,10 +19,12 @@ package org.springframework.core.task; import java.io.Serializable; import java.util.Set; import java.util.concurrent.Callable; +import java.util.concurrent.CancellationException; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.Future; import java.util.concurrent.FutureTask; import java.util.concurrent.ThreadFactory; +import java.util.concurrent.atomic.AtomicBoolean; import org.springframework.lang.Nullable; import org.springframework.util.Assert; @@ -96,9 +98,9 @@ public class SimpleAsyncTaskExecutor extends CustomizableThreadCreator private boolean rejectTasksWhenLimitReached = false; - private volatile boolean active = true; + private final AtomicBoolean closed = new AtomicBoolean(); - private volatile boolean cancelled = false; + private boolean cancelled = false; // within activeThreads synchronization /** @@ -274,7 +276,7 @@ public class SimpleAsyncTaskExecutor extends CustomizableThreadCreator * @see #close() */ public boolean isActive() { - return this.active; + return !this.closed.get(); } /** @@ -312,14 +314,15 @@ public class SimpleAsyncTaskExecutor extends CustomizableThreadCreator public void execute(Runnable task, long startTimeout) { Assert.notNull(task, "Runnable must not be null"); if (!isActive()) { - throw new TaskRejectedException(getClass().getSimpleName() + " has been closed already"); + throw new TaskRejectedException(getClass().getSimpleName() + " is not active"); } Runnable taskToUse = (this.taskDecorator != null ? this.taskDecorator.decorate(task) : task); + Future future = (task instanceof Future f ? f : null); if (isThrottleActive() && startTimeout > TIMEOUT_IMMEDIATE) { this.concurrencyThrottle.beforeAccess(); try { - doExecute(new TaskTrackingRunnable(taskToUse)); + doExecute(new TaskTrackingRunnable(taskToUse, future)); } catch (Throwable ex) { // Release concurrency permit if thread creation fails @@ -329,7 +332,7 @@ public class SimpleAsyncTaskExecutor extends CustomizableThreadCreator } } else if (this.activeThreads != null) { - doExecute(new TaskTrackingRunnable(taskToUse)); + doExecute(new TaskTrackingRunnable(taskToUse, future)); } else { doExecute(taskToUse); @@ -405,12 +408,13 @@ public class SimpleAsyncTaskExecutor extends CustomizableThreadCreator */ @Override public void close() { - if (this.active) { - this.active = false; + if (this.closed.compareAndSet(false, true)) { Set threads = this.activeThreads; if (threads != null) { if (this.cancelRemainingTasksOnClose) { - this.cancelled = true; + synchronized (threads) { + this.cancelled = true; + } // Early interrupt for remaining tasks on close threads.forEach(Thread::interrupt); } @@ -435,9 +439,12 @@ public class SimpleAsyncTaskExecutor extends CustomizableThreadCreator } } - private void checkCancelled() { - if (this.cancelled) { - throw new TaskRejectedException(getClass().getSimpleName() + " has cancelled all remaining tasks"); + private void checkCancelled(@Nullable Future future) { + if (this.cancelled) { // within synchronization from TaskTrackingRunnable + if (future != null) { + future.cancel(false); + } + throw new CancellationException(getClass().getSimpleName() + " has cancelled all remaining tasks"); } } @@ -477,9 +484,12 @@ public class SimpleAsyncTaskExecutor extends CustomizableThreadCreator private final Runnable task; - public TaskTrackingRunnable(Runnable task) { + private final @Nullable Future future; + + public TaskTrackingRunnable(Runnable task, @Nullable Future future) { Assert.notNull(task, "Task must not be null"); this.task = task; + this.future = future; } @Override @@ -488,27 +498,19 @@ public class SimpleAsyncTaskExecutor extends CustomizableThreadCreator Thread thread = null; if (threads != null) { thread = Thread.currentThread(); - if (isActive()) { + synchronized (threads) { + checkCancelled(this.future); threads.add(thread); } - else { - synchronized (threads) { - checkCancelled(); - threads.add(thread); - } - } } try { this.task.run(); } finally { if (threads != null) { - if (isActive()) { - threads.remove(thread); - } - else { + threads.remove(thread); + if (closed.get()) { synchronized (threads) { - threads.remove(thread); if (threads.isEmpty()) { threads.notify(); } diff --git a/spring-core/src/test/java/org/springframework/core/task/SimpleAsyncTaskExecutorTests.java b/spring-core/src/test/java/org/springframework/core/task/SimpleAsyncTaskExecutorTests.java index e53773ee074..68a67f5d733 100644 --- a/spring-core/src/test/java/org/springframework/core/task/SimpleAsyncTaskExecutorTests.java +++ b/spring-core/src/test/java/org/springframework/core/task/SimpleAsyncTaskExecutorTests.java @@ -16,8 +16,11 @@ package org.springframework.core.task; +import java.util.concurrent.CancellationException; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; import org.junit.jupiter.api.Test; @@ -27,6 +30,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.assertj.core.api.Assertions.assertThatIllegalStateException; +import static org.assertj.core.api.Assertions.assertThatNoException; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.mockito.ArgumentMatchers.any; import static org.mockito.BDDMockito.willCallRealMethod; @@ -86,11 +90,11 @@ class SimpleAsyncTaskExecutorTests { *

This test reproduces a critical bug where OutOfMemoryError from * Thread.start() causes the executor to permanently deadlock: *

    - *
  1. beforeAccess() increments concurrencyCount - *
  2. doExecute() throws Error before thread starts - *
  3. TaskTrackingRunnable.run() never executes - *
  4. afterAccess() in finally block never called - *
  5. Subsequent tasks block forever in onLimitReached() + *
  6. beforeAccess() increments concurrencyCount + *
  7. doExecute() throws Error before thread starts + *
  8. TaskTrackingRunnable.run() never executes + *
  9. afterAccess() in finally block never called + *
  10. Subsequent tasks block forever in onLimitReached() *
* *

Test approach: The first execute() should fail with some exception @@ -131,6 +135,105 @@ class SimpleAsyncTaskExecutorTests { .isTrue(); } + @Test + void taskTerminationTimeout() throws InterruptedException{ + Future future; + try (SimpleAsyncTaskExecutor executor = new SimpleAsyncTaskExecutor()) { + executor.setTaskTerminationTimeout(500); + future = executor.submit(() -> { + try { + Thread.sleep(200); + } + catch (InterruptedException ex) { + Thread.currentThread().interrupt(); + throw new IllegalStateException(); + } + }); + Thread.sleep(100); + } + assertThatNoException().isThrownBy(future::get); + } + + @Test + void taskTerminationTimeoutWithImmediateCancel() { + AtomicBoolean finished = new AtomicBoolean(); + Future future; + try (SimpleAsyncTaskExecutor executor = new SimpleAsyncTaskExecutor()) { + executor.setTaskTerminationTimeout(100); + future = executor.submit(() -> { + if (finished.get()) { + throw new IllegalStateException(); + } + }); + } + finished.set(true); + assertThatExceptionOfType(CancellationException.class).isThrownBy(future::get); + } + + @Test + void taskTerminationTimeoutWithLateInterrupt() throws InterruptedException { + AtomicBoolean interrupted = new AtomicBoolean(); + Future future; + try (SimpleAsyncTaskExecutor executor = new SimpleAsyncTaskExecutor()) { + executor.setTaskTerminationTimeout(200); + future = executor.submit(() -> { + try { + Thread.sleep(500); + } + catch (InterruptedException ex) { + Thread.currentThread().interrupt(); + interrupted.set(true); + } + }); + Thread.sleep(100); + } + assertThatNoException().isThrownBy(future::get); + assertThat(interrupted).isTrue(); + } + + @Test + void taskTerminationTimeoutWithEarlyInterrupt() throws InterruptedException { + AtomicBoolean interrupted = new AtomicBoolean(); + Future future; + try (SimpleAsyncTaskExecutor executor = new SimpleAsyncTaskExecutor()) { + executor.setTaskTerminationTimeout(500); + executor.setCancelRemainingTasksOnClose(true); + future = executor.submit(() -> { + try { + Thread.sleep(200); + } + catch (InterruptedException ex) { + Thread.currentThread().interrupt(); + interrupted.set(true); + } + }); + Thread.sleep(100); + } + assertThatNoException().isThrownBy(future::get); + assertThat(interrupted).isTrue(); + } + + @Test + void cancelRemainingTasksOnClose() throws InterruptedException { + AtomicBoolean interrupted = new AtomicBoolean(); + Future future; + try (SimpleAsyncTaskExecutor executor = new SimpleAsyncTaskExecutor()) { + executor.setCancelRemainingTasksOnClose(true); + future = executor.submit(() -> { + try { + Thread.sleep(200); + } + catch (InterruptedException ex) { + Thread.currentThread().interrupt(); + interrupted.set(true); + } + }); + Thread.sleep(100); + } + assertThatNoException().isThrownBy(future::get); + assertThat(interrupted).isTrue(); + } + @Test void threadNameGetsSetCorrectly() { String customPrefix = "chankPop#";