From f526b23fd728ea3eb6cf7245142f4563df34dbbd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Nicoll?= Date: Fri, 26 Jan 2024 11:30:22 +0100 Subject: [PATCH] Harmonize WebSocket message broker to use Executor This commit harmonizes the configuration of the WebSocket message broker to use Executor rather than TaskExecutor as only the former is enforced. This lets custom configuration to use a wider range of implementations. Closes gh-32129 --- .../AbstractMessageBrokerConfiguration.java | 44 +++++++-------- .../simp/config/ChannelRegistration.java | 35 ++++++------ .../simp/config/ChannelRegistrationTests.java | 56 +++++++++---------- .../MessageBrokerConfigurationTests.java | 14 ++--- ...essageBrokerConfigurationSupportTests.java | 8 +-- 5 files changed, 78 insertions(+), 79 deletions(-) diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/config/AbstractMessageBrokerConfiguration.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/config/AbstractMessageBrokerConfiguration.java index b38cfc27d29..7439c8c8308 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/config/AbstractMessageBrokerConfiguration.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/config/AbstractMessageBrokerConfiguration.java @@ -21,6 +21,7 @@ import java.util.Collection; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.concurrent.Executor; import java.util.function.Supplier; import org.springframework.beans.factory.BeanInitializationException; @@ -30,7 +31,6 @@ import org.springframework.context.ApplicationContextAware; import org.springframework.context.SmartLifecycle; import org.springframework.context.annotation.Bean; import org.springframework.context.event.SmartApplicationListener; -import org.springframework.core.task.TaskExecutor; import org.springframework.lang.Nullable; import org.springframework.messaging.MessageHandler; import org.springframework.messaging.converter.ByteArrayMessageConverter; @@ -158,7 +158,7 @@ public abstract class AbstractMessageBrokerConfiguration implements ApplicationC @Bean public AbstractSubscribableChannel clientInboundChannel( - @Qualifier("clientInboundChannelExecutor") TaskExecutor executor) { + @Qualifier("clientInboundChannelExecutor") Executor executor) { ExecutorSubscribableChannel channel = new ExecutorSubscribableChannel(executor); channel.setLogger(SimpLogging.forLog(channel.getLogger())); @@ -170,9 +170,9 @@ public abstract class AbstractMessageBrokerConfiguration implements ApplicationC } @Bean - public TaskExecutor clientInboundChannelExecutor() { + public Executor clientInboundChannelExecutor() { ChannelRegistration registration = getClientInboundChannelRegistration(); - TaskExecutor executor = getTaskExecutor(registration, "clientInboundChannel-", this::defaultTaskExecutor); + Executor executor = getExecutor(registration, "clientInboundChannel-", this::defaultExecutor); if (executor instanceof ExecutorConfigurationSupport executorSupport) { executorSupport.setPhase(getPhase()); } @@ -209,7 +209,7 @@ public abstract class AbstractMessageBrokerConfiguration implements ApplicationC @Bean public AbstractSubscribableChannel clientOutboundChannel( - @Qualifier("clientOutboundChannelExecutor") TaskExecutor executor) { + @Qualifier("clientOutboundChannelExecutor") Executor executor) { ExecutorSubscribableChannel channel = new ExecutorSubscribableChannel(executor); channel.setLogger(SimpLogging.forLog(channel.getLogger())); @@ -221,9 +221,9 @@ public abstract class AbstractMessageBrokerConfiguration implements ApplicationC } @Bean - public TaskExecutor clientOutboundChannelExecutor() { + public Executor clientOutboundChannelExecutor() { ChannelRegistration registration = getClientOutboundChannelRegistration(); - TaskExecutor executor = getTaskExecutor(registration, "clientOutboundChannel-", this::defaultTaskExecutor); + Executor executor = getExecutor(registration, "clientOutboundChannel-", this::defaultExecutor); if (executor instanceof ExecutorConfigurationSupport executorSupport) { executorSupport.setPhase(getPhase()); } @@ -250,11 +250,11 @@ public abstract class AbstractMessageBrokerConfiguration implements ApplicationC @Bean public AbstractSubscribableChannel brokerChannel( AbstractSubscribableChannel clientInboundChannel, AbstractSubscribableChannel clientOutboundChannel, - @Qualifier("brokerChannelExecutor") TaskExecutor executor) { + @Qualifier("brokerChannelExecutor") Executor executor) { MessageBrokerRegistry registry = getBrokerRegistry(clientInboundChannel, clientOutboundChannel); ChannelRegistration registration = registry.getBrokerChannelRegistration(); - ExecutorSubscribableChannel channel = (registration.hasTaskExecutor() ? + ExecutorSubscribableChannel channel = (registration.hasExecutor() ? new ExecutorSubscribableChannel(executor) : new ExecutorSubscribableChannel()); registration.interceptors(new ImmutableMessageChannelInterceptor()); channel.setLogger(SimpLogging.forLog(channel.getLogger())); @@ -263,18 +263,18 @@ public abstract class AbstractMessageBrokerConfiguration implements ApplicationC } @Bean - public TaskExecutor brokerChannelExecutor( + public Executor brokerChannelExecutor( AbstractSubscribableChannel clientInboundChannel, AbstractSubscribableChannel clientOutboundChannel) { MessageBrokerRegistry registry = getBrokerRegistry(clientInboundChannel, clientOutboundChannel); ChannelRegistration registration = registry.getBrokerChannelRegistration(); - TaskExecutor executor = getTaskExecutor(registration, "brokerChannel-", () -> { + Executor executor = getExecutor(registration, "brokerChannel-", () -> { // Should never be used - ThreadPoolTaskExecutor threadPoolTaskExecutor = new ThreadPoolTaskExecutor(); - threadPoolTaskExecutor.setCorePoolSize(0); - threadPoolTaskExecutor.setMaxPoolSize(1); - threadPoolTaskExecutor.setQueueCapacity(0); - return threadPoolTaskExecutor; + ThreadPoolTaskExecutor fallbackExecutor = new ThreadPoolTaskExecutor(); + fallbackExecutor.setCorePoolSize(0); + fallbackExecutor.setMaxPoolSize(1); + fallbackExecutor.setQueueCapacity(0); + return fallbackExecutor; }); if (executor instanceof ExecutorConfigurationSupport executorSupport) { executorSupport.setPhase(getPhase()); @@ -282,19 +282,19 @@ public abstract class AbstractMessageBrokerConfiguration implements ApplicationC return executor; } - private TaskExecutor defaultTaskExecutor() { + private Executor defaultExecutor() { return new TaskExecutorRegistration().getTaskExecutor(); } - private static TaskExecutor getTaskExecutor(ChannelRegistration registration, - String threadNamePrefix, Supplier fallback) { + private static Executor getExecutor(ChannelRegistration registration, + String threadNamePrefix, Supplier fallback) { - return registration.getTaskExecutor(fallback, + return registration.getExecutor(fallback, executor -> setThreadNamePrefix(executor, threadNamePrefix)); } - private static void setThreadNamePrefix(TaskExecutor taskExecutor, String name) { - if (taskExecutor instanceof CustomizableThreadCreator ctc) { + private static void setThreadNamePrefix(Executor executor, String name) { + if (executor instanceof CustomizableThreadCreator ctc) { ctc.setThreadNamePrefix(name); } } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/config/ChannelRegistration.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/config/ChannelRegistration.java index d2ef4bb1afa..be7502afdf7 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/config/ChannelRegistration.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/config/ChannelRegistration.java @@ -19,10 +19,10 @@ package org.springframework.messaging.simp.config; import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.concurrent.Executor; import java.util.function.Consumer; import java.util.function.Supplier; -import org.springframework.core.task.TaskExecutor; import org.springframework.lang.Nullable; import org.springframework.messaging.support.ChannelInterceptor; import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor; @@ -41,7 +41,7 @@ public class ChannelRegistration { private TaskExecutorRegistration registration; @Nullable - private TaskExecutor executor; + private Executor executor; private final List interceptors = new ArrayList<>(); @@ -67,14 +67,14 @@ public class ChannelRegistration { } /** - * Configure the given {@link TaskExecutor} for this message channel, + * Configure the given {@link Executor} for this message channel, * taking precedence over a {@linkplain #taskExecutor() task executor * registration} if any. - * @param taskExecutor the task executor to use + * @param executor the executor to use * @since 6.1.4 */ - public ChannelRegistration executor(TaskExecutor taskExecutor) { - this.executor = taskExecutor; + public ChannelRegistration executor(Executor executor) { + this.executor = executor; return this; } @@ -89,7 +89,7 @@ public class ChannelRegistration { } - protected boolean hasTaskExecutor() { + protected boolean hasExecutor() { return (this.registration != null || this.executor != null); } @@ -98,18 +98,17 @@ public class ChannelRegistration { } /** - * Return the {@link TaskExecutor} to use. If no task executor has been - * configured, the {@code fallback} supplier is used to provide a fallback - * instance. + * Return the {@link Executor} to use. If no executor has been configured, + * the {@code fallback} supplier is used to provide a fallback instance. *

- * If the {@link TaskExecutor} to use is suitable for further customizations, + * If the {@link Executor} to use is suitable for further customizations, * the {@code customizer} consumer is invoked. - * @param fallback a supplier of a fallback task executor in case none is configured + * @param fallback a supplier of a fallback executor in case none is configured * @param customizer further customizations - * @return the task executor to use - * @since 6.1.4 + * @return the executor to use + * @since 6.2 */ - protected TaskExecutor getTaskExecutor(Supplier fallback, Consumer customizer) { + protected Executor getExecutor(Supplier fallback, Consumer customizer) { if (this.executor != null) { return this.executor; } @@ -119,9 +118,9 @@ public class ChannelRegistration { return registeredTaskExecutor; } else { - TaskExecutor taskExecutor = fallback.get(); - customizer.accept(taskExecutor); - return taskExecutor; + Executor fallbackExecutor = fallback.get(); + customizer.accept(fallbackExecutor); + return fallbackExecutor; } } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/config/ChannelRegistrationTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/config/ChannelRegistrationTests.java index dc392e2437e..ea5ae896301 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/config/ChannelRegistrationTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/config/ChannelRegistrationTests.java @@ -16,12 +16,12 @@ package org.springframework.messaging.simp.config; +import java.util.concurrent.Executor; import java.util.function.Consumer; import java.util.function.Supplier; import org.junit.jupiter.api.Test; -import org.springframework.core.task.TaskExecutor; import org.springframework.messaging.support.ChannelInterceptor; import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor; @@ -38,20 +38,20 @@ import static org.mockito.Mockito.verifyNoInteractions; */ class ChannelRegistrationTests { - private final Supplier fallback = mock(); + private final Supplier fallback = mock(); - private final Consumer customizer = mock(); + private final Consumer customizer = mock(); @Test void emptyRegistrationUsesFallback() { - TaskExecutor fallbackTaskExecutor = mock(TaskExecutor.class); - given(this.fallback.get()).willReturn(fallbackTaskExecutor); + Executor fallbackExecutor = mock(Executor.class); + given(this.fallback.get()).willReturn(fallbackExecutor); ChannelRegistration registration = new ChannelRegistration(); - assertThat(registration.hasTaskExecutor()).isFalse(); - TaskExecutor actual = registration.getTaskExecutor(this.fallback, this.customizer); - assertThat(actual).isSameAs(fallbackTaskExecutor); + assertThat(registration.hasExecutor()).isFalse(); + Executor actual = registration.getExecutor(this.fallback, this.customizer); + assertThat(actual).isSameAs(fallbackExecutor); verify(this.fallback).get(); - verify(this.customizer).accept(fallbackTaskExecutor); + verify(this.customizer).accept(fallbackExecutor); } @Test @@ -65,45 +65,45 @@ class ChannelRegistrationTests { void taskRegistrationCreatesDefaultInstance() { ChannelRegistration registration = new ChannelRegistration(); registration.taskExecutor(); - assertThat(registration.hasTaskExecutor()).isTrue(); - TaskExecutor taskExecutor = registration.getTaskExecutor(this.fallback, this.customizer); - assertThat(taskExecutor).isInstanceOf(ThreadPoolTaskExecutor.class); + assertThat(registration.hasExecutor()).isTrue(); + Executor executor = registration.getExecutor(this.fallback, this.customizer); + assertThat(executor).isInstanceOf(ThreadPoolTaskExecutor.class); verifyNoInteractions(this.fallback); - verify(this.customizer).accept(taskExecutor); + verify(this.customizer).accept(executor); } @Test void taskRegistrationWithExistingThreadPoolTaskExecutor() { - ThreadPoolTaskExecutor existingTaskExecutor = mock(ThreadPoolTaskExecutor.class); + ThreadPoolTaskExecutor existingExecutor = mock(ThreadPoolTaskExecutor.class); ChannelRegistration registration = new ChannelRegistration(); - registration.taskExecutor(existingTaskExecutor); - assertThat(registration.hasTaskExecutor()).isTrue(); - TaskExecutor taskExecutor = registration.getTaskExecutor(this.fallback, this.customizer); - assertThat(taskExecutor).isSameAs(existingTaskExecutor); + registration.taskExecutor(existingExecutor); + assertThat(registration.hasExecutor()).isTrue(); + Executor executor = registration.getExecutor(this.fallback, this.customizer); + assertThat(executor).isSameAs(existingExecutor); verifyNoInteractions(this.fallback); - verify(this.customizer).accept(taskExecutor); + verify(this.customizer).accept(executor); } @Test void configureExecutor() { ChannelRegistration registration = new ChannelRegistration(); - TaskExecutor taskExecutor = mock(TaskExecutor.class); - registration.executor(taskExecutor); - assertThat(registration.hasTaskExecutor()).isTrue(); - TaskExecutor taskExecutor1 = registration.getTaskExecutor(this.fallback, this.customizer); - assertThat(taskExecutor1).isSameAs(taskExecutor); + Executor executor = mock(Executor.class); + registration.executor(executor); + assertThat(registration.hasExecutor()).isTrue(); + Executor actualExecutor = registration.getExecutor(this.fallback, this.customizer); + assertThat(actualExecutor).isSameAs(executor); verifyNoInteractions(this.fallback, this.customizer); } @Test void configureExecutorTakesPrecedenceOverTaskRegistration() { ChannelRegistration registration = new ChannelRegistration(); - TaskExecutor taskExecutor = mock(TaskExecutor.class); - registration.executor(taskExecutor); + Executor executor = mock(Executor.class); + registration.executor(executor); ThreadPoolTaskExecutor ignored = mock(ThreadPoolTaskExecutor.class); registration.taskExecutor(ignored); - assertThat(registration.hasTaskExecutor()).isTrue(); - assertThat(registration.getTaskExecutor(this.fallback, this.customizer)).isSameAs(taskExecutor); + assertThat(registration.hasExecutor()).isTrue(); + assertThat(registration.getExecutor(this.fallback, this.customizer)).isSameAs(executor); verifyNoInteractions(ignored, this.fallback, this.customizer); } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/config/MessageBrokerConfigurationTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/config/MessageBrokerConfigurationTests.java index 34ca2761b55..3445bb0a32c 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/config/MessageBrokerConfigurationTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/config/MessageBrokerConfigurationTests.java @@ -22,6 +22,7 @@ import java.util.List; import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.Executor; import org.junit.jupiter.api.Test; @@ -31,7 +32,6 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.context.support.StaticApplicationContext; import org.springframework.core.Ordered; -import org.springframework.core.task.TaskExecutor; import org.springframework.lang.Nullable; import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; @@ -599,20 +599,20 @@ class MessageBrokerConfigurationTests { @Override @Bean - public AbstractSubscribableChannel clientInboundChannel(TaskExecutor clientInboundChannelExecutor) { + public AbstractSubscribableChannel clientInboundChannel(Executor clientInboundChannelExecutor) { return new TestChannel(); } @Override @Bean - public AbstractSubscribableChannel clientOutboundChannel(TaskExecutor clientOutboundChannelExecutor) { + public AbstractSubscribableChannel clientOutboundChannel(Executor clientOutboundChannelExecutor) { return new TestChannel(); } @Override @Bean public AbstractSubscribableChannel brokerChannel(AbstractSubscribableChannel clientInboundChannel, - AbstractSubscribableChannel clientOutboundChannel, TaskExecutor brokerChannelExecutor) { + AbstractSubscribableChannel clientOutboundChannel, Executor brokerChannelExecutor) { return new TestChannel(); } } @@ -688,21 +688,21 @@ class MessageBrokerConfigurationTests { @Override @Bean - public AbstractSubscribableChannel clientInboundChannel(TaskExecutor clientInboundChannelExecutor) { + public AbstractSubscribableChannel clientInboundChannel(Executor clientInboundChannelExecutor) { // synchronous return new ExecutorSubscribableChannel(null); } @Override @Bean - public AbstractSubscribableChannel clientOutboundChannel(TaskExecutor clientOutboundChannelExecutor) { + public AbstractSubscribableChannel clientOutboundChannel(Executor clientOutboundChannelExecutor) { return new TestChannel(); } @Override @Bean public AbstractSubscribableChannel brokerChannel(AbstractSubscribableChannel clientInboundChannel, - AbstractSubscribableChannel clientOutboundChannel, TaskExecutor brokerChannelExecutor) { + AbstractSubscribableChannel clientOutboundChannel, Executor brokerChannelExecutor) { // synchronous return new ExecutorSubscribableChannel(null); } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebSocketMessageBrokerConfigurationSupportTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebSocketMessageBrokerConfigurationSupportTests.java index fce2e7f2957..e8a94d3e357 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebSocketMessageBrokerConfigurationSupportTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebSocketMessageBrokerConfigurationSupportTests.java @@ -20,6 +20,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Set; +import java.util.concurrent.Executor; import java.util.concurrent.ScheduledThreadPoolExecutor; import java.util.function.Consumer; @@ -29,7 +30,6 @@ import org.springframework.context.ApplicationContext; import org.springframework.context.annotation.AnnotationConfigApplicationContext; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; -import org.springframework.core.task.TaskExecutor; import org.springframework.messaging.Message; import org.springframework.messaging.MessageHandler; import org.springframework.messaging.handler.annotation.MessageMapping; @@ -318,7 +318,7 @@ class WebSocketMessageBrokerConfigurationSupportTests { @Override @Bean - public AbstractSubscribableChannel clientInboundChannel(TaskExecutor clientInboundChannelExecutor) { + public AbstractSubscribableChannel clientInboundChannel(Executor clientInboundChannelExecutor) { TestChannel channel = new TestChannel(); channel.setInterceptors(super.clientInboundChannel(clientInboundChannelExecutor).getInterceptors()); return channel; @@ -326,7 +326,7 @@ class WebSocketMessageBrokerConfigurationSupportTests { @Override @Bean - public AbstractSubscribableChannel clientOutboundChannel(TaskExecutor clientOutboundChannelExecutor) { + public AbstractSubscribableChannel clientOutboundChannel(Executor clientOutboundChannelExecutor) { TestChannel channel = new TestChannel(); channel.setInterceptors(super.clientOutboundChannel(clientOutboundChannelExecutor).getInterceptors()); return channel; @@ -334,7 +334,7 @@ class WebSocketMessageBrokerConfigurationSupportTests { @Override public AbstractSubscribableChannel brokerChannel(AbstractSubscribableChannel clientInboundChannel, - AbstractSubscribableChannel clientOutboundChannel, TaskExecutor brokerChannelExecutor) { + AbstractSubscribableChannel clientOutboundChannel, Executor brokerChannelExecutor) { TestChannel channel = new TestChannel(); channel.setInterceptors(super.brokerChannel(clientInboundChannel, clientOutboundChannel, brokerChannelExecutor).getInterceptors()); return channel;