From 8b36736344a333c1bddaccbc46a0367d28bed9f9 Mon Sep 17 00:00:00 2001 From: Juergen Hoeller Date: Sat, 11 Oct 2025 13:25:30 +0200 Subject: [PATCH] Add concurrency throttle and flexible task callback to SyncTaskExecutor Closes gh-35460 --- .../annotation/ConcurrencyLimit.java | 7 +- .../resilience/ConcurrencyLimitTests.java | 10 +- .../core/task/SyncTaskExecutor.java | 54 ++++++-- .../core/task/TaskCallback.java | 42 ++++++ .../core/task/SyncTaskExecutorTests.java | 126 ++++++++++++++++++ 5 files changed, 225 insertions(+), 14 deletions(-) create mode 100644 spring-core/src/main/java/org/springframework/core/task/TaskCallback.java create mode 100644 spring-core/src/test/java/org/springframework/core/task/SyncTaskExecutorTests.java diff --git a/spring-context/src/main/java/org/springframework/resilience/annotation/ConcurrencyLimit.java b/spring-context/src/main/java/org/springframework/resilience/annotation/ConcurrencyLimit.java index 3e6ec382bb0..50bdb09456e 100644 --- a/spring-context/src/main/java/org/springframework/resilience/annotation/ConcurrencyLimit.java +++ b/spring-context/src/main/java/org/springframework/resilience/annotation/ConcurrencyLimit.java @@ -38,9 +38,11 @@ import org.springframework.core.annotation.AliasFor; * *

This is particularly useful with Virtual Threads where there is generally * no thread pool limit in place. For asynchronous tasks, this can be constrained - * on {@link org.springframework.core.task.SimpleAsyncTaskExecutor}; for + * on {@link org.springframework.core.task.SimpleAsyncTaskExecutor}. For * synchronous invocations, this annotation provides equivalent behavior through - * {@link org.springframework.aop.interceptor.ConcurrencyThrottleInterceptor}. + * {@link org.springframework.aop.interceptor.ConcurrencyThrottleInterceptor} + * Alternatively, consider {@link org.springframework.core.task.SyncTaskExecutor} + * and its inherited concurrency throttle (new as of 7.0) for programmatic use. * * @author Juergen Hoeller * @author Hyunsang Han @@ -49,6 +51,7 @@ import org.springframework.core.annotation.AliasFor; * @see EnableResilientMethods * @see ConcurrencyLimitBeanPostProcessor * @see org.springframework.aop.interceptor.ConcurrencyThrottleInterceptor + * @see org.springframework.core.task.SyncTaskExecutor#setConcurrencyLimit * @see org.springframework.core.task.SimpleAsyncTaskExecutor#setConcurrencyLimit */ @Target({ElementType.TYPE, ElementType.METHOD}) diff --git a/spring-context/src/test/java/org/springframework/resilience/ConcurrencyLimitTests.java b/spring-context/src/test/java/org/springframework/resilience/ConcurrencyLimitTests.java index a2fbf769574..94e22784f05 100644 --- a/spring-context/src/test/java/org/springframework/resilience/ConcurrencyLimitTests.java +++ b/spring-context/src/test/java/org/springframework/resilience/ConcurrencyLimitTests.java @@ -59,7 +59,8 @@ class ConcurrencyLimitTests { futures.add(CompletableFuture.runAsync(proxy::concurrentOperation)); } CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])).join(); - assertThat(target.counter).hasValue(0); + assertThat(target.current).hasValue(0); + assertThat(target.counter).hasValue(10); } @Test @@ -166,10 +167,12 @@ class ConcurrencyLimitTests { static class NonAnnotatedBean { + final AtomicInteger current = new AtomicInteger(); + final AtomicInteger counter = new AtomicInteger(); public void concurrentOperation() { - if (counter.incrementAndGet() > 2) { + if (current.incrementAndGet() > 2) { throw new IllegalStateException(); } try { @@ -178,7 +181,8 @@ class ConcurrencyLimitTests { catch (InterruptedException ex) { throw new IllegalStateException(ex); } - counter.decrementAndGet(); + current.decrementAndGet(); + counter.incrementAndGet(); } } diff --git a/spring-core/src/main/java/org/springframework/core/task/SyncTaskExecutor.java b/spring-core/src/main/java/org/springframework/core/task/SyncTaskExecutor.java index a56f7bd922a..df34deeaad1 100644 --- a/spring-core/src/main/java/org/springframework/core/task/SyncTaskExecutor.java +++ b/spring-core/src/main/java/org/springframework/core/task/SyncTaskExecutor.java @@ -19,12 +19,13 @@ package org.springframework.core.task; import java.io.Serializable; import org.springframework.util.Assert; +import org.springframework.util.ConcurrencyThrottleSupport; /** * {@link TaskExecutor} implementation that executes each task synchronously - * in the calling thread. - * - *

Mainly intended for testing scenarios. + * in the calling thread. This can be used for testing purposes but also for + * bounded execution in a Virtual Threads setup, relying on concurrency throttling + * as inherited from the base class: see {@link #setConcurrencyLimit} (as of 7.0). * *

Execution in the calling thread does have the advantage of participating * in its thread context, for example the thread context class loader or the @@ -37,17 +38,52 @@ import org.springframework.util.Assert; * @see SimpleAsyncTaskExecutor */ @SuppressWarnings("serial") -public class SyncTaskExecutor implements TaskExecutor, Serializable { +public class SyncTaskExecutor extends ConcurrencyThrottleSupport implements TaskExecutor, Serializable { /** - * Executes the given {@code task} synchronously, through direct - * invocation of it's {@link Runnable#run() run()} method. - * @throws IllegalArgumentException if the given {@code task} is {@code null} + * Execute the given {@code task} synchronously, through direct + * invocation of its {@link Runnable#run() run()} method. + * @throws RuntimeException if propagated from the given {@code Runnable} */ @Override public void execute(Runnable task) { - Assert.notNull(task, "Runnable must not be null"); - task.run(); + Assert.notNull(task, "Task must not be null"); + if (isThrottleActive()) { + beforeAccess(); + try { + task.run(); + } + finally { + afterAccess(); + } + } + else { + task.run(); + } + } + + /** + * Execute the given {@code task} synchronously, through direct + * invocation of its {@link TaskCallback#call() call()} method. + * @param the returned value type, if any + * @param the exception propagated, if any + * @throws E if propagated from the given {@code TaskCallback} + * @since 7.0 + */ + public V execute(TaskCallback task) throws E { + Assert.notNull(task, "Task must not be null"); + if (isThrottleActive()) { + beforeAccess(); + try { + return task.call(); + } + finally { + afterAccess(); + } + } + else { + return task.call(); + } } } diff --git a/spring-core/src/main/java/org/springframework/core/task/TaskCallback.java b/spring-core/src/main/java/org/springframework/core/task/TaskCallback.java new file mode 100644 index 00000000000..e2e8f6feda3 --- /dev/null +++ b/spring-core/src/main/java/org/springframework/core/task/TaskCallback.java @@ -0,0 +1,42 @@ +/* + * Copyright 2002-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.core.task; + +import java.util.concurrent.Callable; + +/** + * Variant of {@link Callable} with a flexible exception signature + * that can be adapted in the {@link SyncTaskExecutor#execute(TaskCallback)} + * method signature for propagating specific exception types only. + * + *

An implementation of this interface can also be passed into any + * {@code Callback}-based method such as {@link AsyncTaskExecutor#submit(Callable)} + * or {@link AsyncTaskExecutor#submitCompletable(Callable)}. It is just capable + * of adapting to flexible exception propagation in caller signatures as well. + * + * @author Juergen Hoeller + * @since 7.0 + * @param the returned value type, if any + * @param the exception propagated, if any + * @see SyncTaskExecutor#execute(TaskCallback) + */ +public interface TaskCallback extends Callable { + + @Override + V call() throws E; + +} diff --git a/spring-core/src/test/java/org/springframework/core/task/SyncTaskExecutorTests.java b/spring-core/src/test/java/org/springframework/core/task/SyncTaskExecutorTests.java new file mode 100644 index 00000000000..4fe8385ab53 --- /dev/null +++ b/spring-core/src/test/java/org/springframework/core/task/SyncTaskExecutorTests.java @@ -0,0 +1,126 @@ +/* + * Copyright 2002-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.core.task; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicInteger; + +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIOException; +import static org.assertj.core.api.Assertions.assertThatNoException; + +/** + * @author Juergen Hoeller + * @since 7.0 + */ +class SyncTaskExecutorTests { + + @Test + void plainExecution() { + SyncTaskExecutor taskExecutor = new SyncTaskExecutor(); + + ConcurrentClass target = new ConcurrentClass(); + assertThatNoException().isThrownBy(() -> taskExecutor.execute(target::concurrentOperation)); + assertThat(taskExecutor.execute(target::concurrentOperationWithResult)).isEqualTo("result"); + assertThatIOException().isThrownBy(() -> taskExecutor.execute(target::concurrentOperationWithException)); + } + + @Test + void withConcurrencyLimit() { + SyncTaskExecutor taskExecutor = new SyncTaskExecutor(); + taskExecutor.setConcurrencyLimit(2); + + ConcurrentClass target = new ConcurrentClass(); + List> futures = new ArrayList<>(10); + for (int i = 0; i < 10; i++) { + futures.add(CompletableFuture.runAsync(() -> taskExecutor.execute(target::concurrentOperation))); + } + CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])).join(); + assertThat(target.current).hasValue(0); + assertThat(target.counter).hasValue(10); + } + + @Test + void withConcurrencyLimitAndResult() { + SyncTaskExecutor taskExecutor = new SyncTaskExecutor(); + taskExecutor.setConcurrencyLimit(2); + + ConcurrentClass target = new ConcurrentClass(); + List> futures = new ArrayList<>(10); + for (int i = 0; i < 10; i++) { + futures.add(CompletableFuture.runAsync(() -> + assertThat(taskExecutor.execute(target::concurrentOperationWithResult)).isEqualTo("result"))); + } + CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])).join(); + assertThat(target.current).hasValue(0); + assertThat(target.counter).hasValue(10); + } + + @Test + void withConcurrencyLimitAndException() { + SyncTaskExecutor taskExecutor = new SyncTaskExecutor(); + taskExecutor.setConcurrencyLimit(2); + + ConcurrentClass target = new ConcurrentClass(); + List> futures = new ArrayList<>(10); + for (int i = 0; i < 10; i++) { + futures.add(CompletableFuture.runAsync(() -> + assertThatIOException().isThrownBy(() -> taskExecutor.execute(target::concurrentOperationWithException)))); + } + CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])).join(); + assertThat(target.current).hasValue(0); + assertThat(target.counter).hasValue(10); + } + + + static class ConcurrentClass { + + final AtomicInteger current = new AtomicInteger(); + + final AtomicInteger counter = new AtomicInteger(); + + public void concurrentOperation() { + if (current.incrementAndGet() > 2) { + throw new IllegalStateException(); + } + try { + Thread.sleep(10); + } + catch (InterruptedException ex) { + throw new IllegalStateException(ex); + } + current.decrementAndGet(); + counter.incrementAndGet(); + } + + public String concurrentOperationWithResult() { + concurrentOperation(); + return "result"; + } + + public String concurrentOperationWithException() throws IOException { + concurrentOperation(); + throw new IOException(); + } + } + +}