Harmonize WebSocket message broker to use Executor

This commit harmonizes the configuration of the WebSocket message
broker to use Executor rather than TaskExecutor as only the former
is enforced. This lets custom configuration to use a wider range
of implementations.

Closes gh-32129
This commit is contained in:
Stéphane Nicoll 2024-01-26 11:30:22 +01:00
parent 2fc8b13dd5
commit f526b23fd7
5 changed files with 78 additions and 79 deletions

View File

@ -21,6 +21,7 @@ import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Executor;
import java.util.function.Supplier;
import org.springframework.beans.factory.BeanInitializationException;
@ -30,7 +31,6 @@ import org.springframework.context.ApplicationContextAware;
import org.springframework.context.SmartLifecycle;
import org.springframework.context.annotation.Bean;
import org.springframework.context.event.SmartApplicationListener;
import org.springframework.core.task.TaskExecutor;
import org.springframework.lang.Nullable;
import org.springframework.messaging.MessageHandler;
import org.springframework.messaging.converter.ByteArrayMessageConverter;
@ -158,7 +158,7 @@ public abstract class AbstractMessageBrokerConfiguration implements ApplicationC
@Bean
public AbstractSubscribableChannel clientInboundChannel(
@Qualifier("clientInboundChannelExecutor") TaskExecutor executor) {
@Qualifier("clientInboundChannelExecutor") Executor executor) {
ExecutorSubscribableChannel channel = new ExecutorSubscribableChannel(executor);
channel.setLogger(SimpLogging.forLog(channel.getLogger()));
@ -170,9 +170,9 @@ public abstract class AbstractMessageBrokerConfiguration implements ApplicationC
}
@Bean
public TaskExecutor clientInboundChannelExecutor() {
public Executor clientInboundChannelExecutor() {
ChannelRegistration registration = getClientInboundChannelRegistration();
TaskExecutor executor = getTaskExecutor(registration, "clientInboundChannel-", this::defaultTaskExecutor);
Executor executor = getExecutor(registration, "clientInboundChannel-", this::defaultExecutor);
if (executor instanceof ExecutorConfigurationSupport executorSupport) {
executorSupport.setPhase(getPhase());
}
@ -209,7 +209,7 @@ public abstract class AbstractMessageBrokerConfiguration implements ApplicationC
@Bean
public AbstractSubscribableChannel clientOutboundChannel(
@Qualifier("clientOutboundChannelExecutor") TaskExecutor executor) {
@Qualifier("clientOutboundChannelExecutor") Executor executor) {
ExecutorSubscribableChannel channel = new ExecutorSubscribableChannel(executor);
channel.setLogger(SimpLogging.forLog(channel.getLogger()));
@ -221,9 +221,9 @@ public abstract class AbstractMessageBrokerConfiguration implements ApplicationC
}
@Bean
public TaskExecutor clientOutboundChannelExecutor() {
public Executor clientOutboundChannelExecutor() {
ChannelRegistration registration = getClientOutboundChannelRegistration();
TaskExecutor executor = getTaskExecutor(registration, "clientOutboundChannel-", this::defaultTaskExecutor);
Executor executor = getExecutor(registration, "clientOutboundChannel-", this::defaultExecutor);
if (executor instanceof ExecutorConfigurationSupport executorSupport) {
executorSupport.setPhase(getPhase());
}
@ -250,11 +250,11 @@ public abstract class AbstractMessageBrokerConfiguration implements ApplicationC
@Bean
public AbstractSubscribableChannel brokerChannel(
AbstractSubscribableChannel clientInboundChannel, AbstractSubscribableChannel clientOutboundChannel,
@Qualifier("brokerChannelExecutor") TaskExecutor executor) {
@Qualifier("brokerChannelExecutor") Executor executor) {
MessageBrokerRegistry registry = getBrokerRegistry(clientInboundChannel, clientOutboundChannel);
ChannelRegistration registration = registry.getBrokerChannelRegistration();
ExecutorSubscribableChannel channel = (registration.hasTaskExecutor() ?
ExecutorSubscribableChannel channel = (registration.hasExecutor() ?
new ExecutorSubscribableChannel(executor) : new ExecutorSubscribableChannel());
registration.interceptors(new ImmutableMessageChannelInterceptor());
channel.setLogger(SimpLogging.forLog(channel.getLogger()));
@ -263,18 +263,18 @@ public abstract class AbstractMessageBrokerConfiguration implements ApplicationC
}
@Bean
public TaskExecutor brokerChannelExecutor(
public Executor brokerChannelExecutor(
AbstractSubscribableChannel clientInboundChannel, AbstractSubscribableChannel clientOutboundChannel) {
MessageBrokerRegistry registry = getBrokerRegistry(clientInboundChannel, clientOutboundChannel);
ChannelRegistration registration = registry.getBrokerChannelRegistration();
TaskExecutor executor = getTaskExecutor(registration, "brokerChannel-", () -> {
Executor executor = getExecutor(registration, "brokerChannel-", () -> {
// Should never be used
ThreadPoolTaskExecutor threadPoolTaskExecutor = new ThreadPoolTaskExecutor();
threadPoolTaskExecutor.setCorePoolSize(0);
threadPoolTaskExecutor.setMaxPoolSize(1);
threadPoolTaskExecutor.setQueueCapacity(0);
return threadPoolTaskExecutor;
ThreadPoolTaskExecutor fallbackExecutor = new ThreadPoolTaskExecutor();
fallbackExecutor.setCorePoolSize(0);
fallbackExecutor.setMaxPoolSize(1);
fallbackExecutor.setQueueCapacity(0);
return fallbackExecutor;
});
if (executor instanceof ExecutorConfigurationSupport executorSupport) {
executorSupport.setPhase(getPhase());
@ -282,19 +282,19 @@ public abstract class AbstractMessageBrokerConfiguration implements ApplicationC
return executor;
}
private TaskExecutor defaultTaskExecutor() {
private Executor defaultExecutor() {
return new TaskExecutorRegistration().getTaskExecutor();
}
private static TaskExecutor getTaskExecutor(ChannelRegistration registration,
String threadNamePrefix, Supplier<TaskExecutor> fallback) {
private static Executor getExecutor(ChannelRegistration registration,
String threadNamePrefix, Supplier<Executor> fallback) {
return registration.getTaskExecutor(fallback,
return registration.getExecutor(fallback,
executor -> setThreadNamePrefix(executor, threadNamePrefix));
}
private static void setThreadNamePrefix(TaskExecutor taskExecutor, String name) {
if (taskExecutor instanceof CustomizableThreadCreator ctc) {
private static void setThreadNamePrefix(Executor executor, String name) {
if (executor instanceof CustomizableThreadCreator ctc) {
ctc.setThreadNamePrefix(name);
}
}

View File

@ -19,10 +19,10 @@ package org.springframework.messaging.simp.config;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.Executor;
import java.util.function.Consumer;
import java.util.function.Supplier;
import org.springframework.core.task.TaskExecutor;
import org.springframework.lang.Nullable;
import org.springframework.messaging.support.ChannelInterceptor;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
@ -41,7 +41,7 @@ public class ChannelRegistration {
private TaskExecutorRegistration registration;
@Nullable
private TaskExecutor executor;
private Executor executor;
private final List<ChannelInterceptor> interceptors = new ArrayList<>();
@ -67,14 +67,14 @@ public class ChannelRegistration {
}
/**
* Configure the given {@link TaskExecutor} for this message channel,
* Configure the given {@link Executor} for this message channel,
* taking precedence over a {@linkplain #taskExecutor() task executor
* registration} if any.
* @param taskExecutor the task executor to use
* @param executor the executor to use
* @since 6.1.4
*/
public ChannelRegistration executor(TaskExecutor taskExecutor) {
this.executor = taskExecutor;
public ChannelRegistration executor(Executor executor) {
this.executor = executor;
return this;
}
@ -89,7 +89,7 @@ public class ChannelRegistration {
}
protected boolean hasTaskExecutor() {
protected boolean hasExecutor() {
return (this.registration != null || this.executor != null);
}
@ -98,18 +98,17 @@ public class ChannelRegistration {
}
/**
* Return the {@link TaskExecutor} to use. If no task executor has been
* configured, the {@code fallback} supplier is used to provide a fallback
* instance.
* Return the {@link Executor} to use. If no executor has been configured,
* the {@code fallback} supplier is used to provide a fallback instance.
* <p>
* If the {@link TaskExecutor} to use is suitable for further customizations,
* If the {@link Executor} to use is suitable for further customizations,
* the {@code customizer} consumer is invoked.
* @param fallback a supplier of a fallback task executor in case none is configured
* @param fallback a supplier of a fallback executor in case none is configured
* @param customizer further customizations
* @return the task executor to use
* @since 6.1.4
* @return the executor to use
* @since 6.2
*/
protected TaskExecutor getTaskExecutor(Supplier<TaskExecutor> fallback, Consumer<TaskExecutor> customizer) {
protected Executor getExecutor(Supplier<Executor> fallback, Consumer<Executor> customizer) {
if (this.executor != null) {
return this.executor;
}
@ -119,9 +118,9 @@ public class ChannelRegistration {
return registeredTaskExecutor;
}
else {
TaskExecutor taskExecutor = fallback.get();
customizer.accept(taskExecutor);
return taskExecutor;
Executor fallbackExecutor = fallback.get();
customizer.accept(fallbackExecutor);
return fallbackExecutor;
}
}

View File

@ -16,12 +16,12 @@
package org.springframework.messaging.simp.config;
import java.util.concurrent.Executor;
import java.util.function.Consumer;
import java.util.function.Supplier;
import org.junit.jupiter.api.Test;
import org.springframework.core.task.TaskExecutor;
import org.springframework.messaging.support.ChannelInterceptor;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
@ -38,20 +38,20 @@ import static org.mockito.Mockito.verifyNoInteractions;
*/
class ChannelRegistrationTests {
private final Supplier<TaskExecutor> fallback = mock();
private final Supplier<Executor> fallback = mock();
private final Consumer<TaskExecutor> customizer = mock();
private final Consumer<Executor> customizer = mock();
@Test
void emptyRegistrationUsesFallback() {
TaskExecutor fallbackTaskExecutor = mock(TaskExecutor.class);
given(this.fallback.get()).willReturn(fallbackTaskExecutor);
Executor fallbackExecutor = mock(Executor.class);
given(this.fallback.get()).willReturn(fallbackExecutor);
ChannelRegistration registration = new ChannelRegistration();
assertThat(registration.hasTaskExecutor()).isFalse();
TaskExecutor actual = registration.getTaskExecutor(this.fallback, this.customizer);
assertThat(actual).isSameAs(fallbackTaskExecutor);
assertThat(registration.hasExecutor()).isFalse();
Executor actual = registration.getExecutor(this.fallback, this.customizer);
assertThat(actual).isSameAs(fallbackExecutor);
verify(this.fallback).get();
verify(this.customizer).accept(fallbackTaskExecutor);
verify(this.customizer).accept(fallbackExecutor);
}
@Test
@ -65,45 +65,45 @@ class ChannelRegistrationTests {
void taskRegistrationCreatesDefaultInstance() {
ChannelRegistration registration = new ChannelRegistration();
registration.taskExecutor();
assertThat(registration.hasTaskExecutor()).isTrue();
TaskExecutor taskExecutor = registration.getTaskExecutor(this.fallback, this.customizer);
assertThat(taskExecutor).isInstanceOf(ThreadPoolTaskExecutor.class);
assertThat(registration.hasExecutor()).isTrue();
Executor executor = registration.getExecutor(this.fallback, this.customizer);
assertThat(executor).isInstanceOf(ThreadPoolTaskExecutor.class);
verifyNoInteractions(this.fallback);
verify(this.customizer).accept(taskExecutor);
verify(this.customizer).accept(executor);
}
@Test
void taskRegistrationWithExistingThreadPoolTaskExecutor() {
ThreadPoolTaskExecutor existingTaskExecutor = mock(ThreadPoolTaskExecutor.class);
ThreadPoolTaskExecutor existingExecutor = mock(ThreadPoolTaskExecutor.class);
ChannelRegistration registration = new ChannelRegistration();
registration.taskExecutor(existingTaskExecutor);
assertThat(registration.hasTaskExecutor()).isTrue();
TaskExecutor taskExecutor = registration.getTaskExecutor(this.fallback, this.customizer);
assertThat(taskExecutor).isSameAs(existingTaskExecutor);
registration.taskExecutor(existingExecutor);
assertThat(registration.hasExecutor()).isTrue();
Executor executor = registration.getExecutor(this.fallback, this.customizer);
assertThat(executor).isSameAs(existingExecutor);
verifyNoInteractions(this.fallback);
verify(this.customizer).accept(taskExecutor);
verify(this.customizer).accept(executor);
}
@Test
void configureExecutor() {
ChannelRegistration registration = new ChannelRegistration();
TaskExecutor taskExecutor = mock(TaskExecutor.class);
registration.executor(taskExecutor);
assertThat(registration.hasTaskExecutor()).isTrue();
TaskExecutor taskExecutor1 = registration.getTaskExecutor(this.fallback, this.customizer);
assertThat(taskExecutor1).isSameAs(taskExecutor);
Executor executor = mock(Executor.class);
registration.executor(executor);
assertThat(registration.hasExecutor()).isTrue();
Executor actualExecutor = registration.getExecutor(this.fallback, this.customizer);
assertThat(actualExecutor).isSameAs(executor);
verifyNoInteractions(this.fallback, this.customizer);
}
@Test
void configureExecutorTakesPrecedenceOverTaskRegistration() {
ChannelRegistration registration = new ChannelRegistration();
TaskExecutor taskExecutor = mock(TaskExecutor.class);
registration.executor(taskExecutor);
Executor executor = mock(Executor.class);
registration.executor(executor);
ThreadPoolTaskExecutor ignored = mock(ThreadPoolTaskExecutor.class);
registration.taskExecutor(ignored);
assertThat(registration.hasTaskExecutor()).isTrue();
assertThat(registration.getTaskExecutor(this.fallback, this.customizer)).isSameAs(taskExecutor);
assertThat(registration.hasExecutor()).isTrue();
assertThat(registration.getExecutor(this.fallback, this.customizer)).isSameAs(executor);
verifyNoInteractions(ignored, this.fallback, this.customizer);
}

View File

@ -22,6 +22,7 @@ import java.util.List;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executor;
import org.junit.jupiter.api.Test;
@ -31,7 +32,6 @@ import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.support.StaticApplicationContext;
import org.springframework.core.Ordered;
import org.springframework.core.task.TaskExecutor;
import org.springframework.lang.Nullable;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
@ -599,20 +599,20 @@ class MessageBrokerConfigurationTests {
@Override
@Bean
public AbstractSubscribableChannel clientInboundChannel(TaskExecutor clientInboundChannelExecutor) {
public AbstractSubscribableChannel clientInboundChannel(Executor clientInboundChannelExecutor) {
return new TestChannel();
}
@Override
@Bean
public AbstractSubscribableChannel clientOutboundChannel(TaskExecutor clientOutboundChannelExecutor) {
public AbstractSubscribableChannel clientOutboundChannel(Executor clientOutboundChannelExecutor) {
return new TestChannel();
}
@Override
@Bean
public AbstractSubscribableChannel brokerChannel(AbstractSubscribableChannel clientInboundChannel,
AbstractSubscribableChannel clientOutboundChannel, TaskExecutor brokerChannelExecutor) {
AbstractSubscribableChannel clientOutboundChannel, Executor brokerChannelExecutor) {
return new TestChannel();
}
}
@ -688,21 +688,21 @@ class MessageBrokerConfigurationTests {
@Override
@Bean
public AbstractSubscribableChannel clientInboundChannel(TaskExecutor clientInboundChannelExecutor) {
public AbstractSubscribableChannel clientInboundChannel(Executor clientInboundChannelExecutor) {
// synchronous
return new ExecutorSubscribableChannel(null);
}
@Override
@Bean
public AbstractSubscribableChannel clientOutboundChannel(TaskExecutor clientOutboundChannelExecutor) {
public AbstractSubscribableChannel clientOutboundChannel(Executor clientOutboundChannelExecutor) {
return new TestChannel();
}
@Override
@Bean
public AbstractSubscribableChannel brokerChannel(AbstractSubscribableChannel clientInboundChannel,
AbstractSubscribableChannel clientOutboundChannel, TaskExecutor brokerChannelExecutor) {
AbstractSubscribableChannel clientOutboundChannel, Executor brokerChannelExecutor) {
// synchronous
return new ExecutorSubscribableChannel(null);
}

View File

@ -20,6 +20,7 @@ import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Executor;
import java.util.concurrent.ScheduledThreadPoolExecutor;
import java.util.function.Consumer;
@ -29,7 +30,6 @@ import org.springframework.context.ApplicationContext;
import org.springframework.context.annotation.AnnotationConfigApplicationContext;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.task.TaskExecutor;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageHandler;
import org.springframework.messaging.handler.annotation.MessageMapping;
@ -318,7 +318,7 @@ class WebSocketMessageBrokerConfigurationSupportTests {
@Override
@Bean
public AbstractSubscribableChannel clientInboundChannel(TaskExecutor clientInboundChannelExecutor) {
public AbstractSubscribableChannel clientInboundChannel(Executor clientInboundChannelExecutor) {
TestChannel channel = new TestChannel();
channel.setInterceptors(super.clientInboundChannel(clientInboundChannelExecutor).getInterceptors());
return channel;
@ -326,7 +326,7 @@ class WebSocketMessageBrokerConfigurationSupportTests {
@Override
@Bean
public AbstractSubscribableChannel clientOutboundChannel(TaskExecutor clientOutboundChannelExecutor) {
public AbstractSubscribableChannel clientOutboundChannel(Executor clientOutboundChannelExecutor) {
TestChannel channel = new TestChannel();
channel.setInterceptors(super.clientOutboundChannel(clientOutboundChannelExecutor).getInterceptors());
return channel;
@ -334,7 +334,7 @@ class WebSocketMessageBrokerConfigurationSupportTests {
@Override
public AbstractSubscribableChannel brokerChannel(AbstractSubscribableChannel clientInboundChannel,
AbstractSubscribableChannel clientOutboundChannel, TaskExecutor brokerChannelExecutor) {
AbstractSubscribableChannel clientOutboundChannel, Executor brokerChannelExecutor) {
TestChannel channel = new TestChannel();
channel.setInterceptors(super.brokerChannel(clientInboundChannel, clientOutboundChannel, brokerChannelExecutor).getInterceptors());
return channel;