From 88157880044f8af84ca1d81012f3de0c55bbc4e5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Nicoll?= Date: Wed, 24 Jan 2024 14:34:16 +0100 Subject: [PATCH] Allow an existing TaskExecutor to be configured in ChannelRegistration This commit introduces a new method to configure an existing TaskExecutor in ChannelRegistration. Contrary to TaskExecutorRegistration, a ThreadPoolTaskExecutor is not necessary, and it can't be further configured. This includes the thread name prefix. Closes gh-32081 --- .../AbstractMessageBrokerConfiguration.java | 48 ++++--- .../simp/config/ChannelRegistration.java | 50 +++++++- .../simp/config/ChannelRegistrationTests.java | 121 ++++++++++++++++++ 3 files changed, 198 insertions(+), 21 deletions(-) create mode 100644 spring-messaging/src/test/java/org/springframework/messaging/simp/config/ChannelRegistrationTests.java diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/config/AbstractMessageBrokerConfiguration.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/config/AbstractMessageBrokerConfiguration.java index f2a74288502..87286aaab6d 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/config/AbstractMessageBrokerConfiguration.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/config/AbstractMessageBrokerConfiguration.java @@ -21,6 +21,7 @@ import java.util.Collection; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.function.Supplier; import org.springframework.beans.factory.BeanInitializationException; import org.springframework.beans.factory.annotation.Qualifier; @@ -62,6 +63,7 @@ import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor; import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler; import org.springframework.util.Assert; import org.springframework.util.ClassUtils; +import org.springframework.util.CustomizableThreadCreator; import org.springframework.util.MimeTypeUtils; import org.springframework.util.PathMatcher; import org.springframework.util.StringUtils; @@ -164,10 +166,8 @@ public abstract class AbstractMessageBrokerConfiguration implements ApplicationC @Bean public TaskExecutor clientInboundChannelExecutor() { - TaskExecutorRegistration reg = getClientInboundChannelRegistration().taskExecutor(); - ThreadPoolTaskExecutor executor = reg.getTaskExecutor(); - executor.setThreadNamePrefix("clientInboundChannel-"); - return executor; + return getTaskExecutor(getClientInboundChannelRegistration(), + "clientInboundChannel-", this::defaultTaskExecutor); } protected final ChannelRegistration getClientInboundChannelRegistration() { @@ -202,10 +202,8 @@ public abstract class AbstractMessageBrokerConfiguration implements ApplicationC @Bean public TaskExecutor clientOutboundChannelExecutor() { - TaskExecutorRegistration reg = getClientOutboundChannelRegistration().taskExecutor(); - ThreadPoolTaskExecutor executor = reg.getTaskExecutor(); - executor.setThreadNamePrefix("clientOutboundChannel-"); - return executor; + return getTaskExecutor(getClientOutboundChannelRegistration(), + "clientOutboundChannel-", this::defaultTaskExecutor); } protected final ChannelRegistration getClientOutboundChannelRegistration() { @@ -246,19 +244,31 @@ public abstract class AbstractMessageBrokerConfiguration implements ApplicationC MessageBrokerRegistry registry = getBrokerRegistry(clientInboundChannel, clientOutboundChannel); ChannelRegistration registration = registry.getBrokerChannelRegistration(); - ThreadPoolTaskExecutor executor; - if (registration.hasTaskExecutor()) { - executor = registration.taskExecutor().getTaskExecutor(); - } - else { + return getTaskExecutor(registration, "brokerChannel-", () -> { // Should never be used - executor = new ThreadPoolTaskExecutor(); - executor.setCorePoolSize(0); - executor.setMaxPoolSize(1); - executor.setQueueCapacity(0); + ThreadPoolTaskExecutor threadPoolTaskExecutor = new ThreadPoolTaskExecutor(); + threadPoolTaskExecutor.setCorePoolSize(0); + threadPoolTaskExecutor.setMaxPoolSize(1); + threadPoolTaskExecutor.setQueueCapacity(0); + return threadPoolTaskExecutor; + }); + } + + private static TaskExecutor getTaskExecutor(ChannelRegistration registration, + String threadNamePrefix, Supplier fallback) { + + return registration.getTaskExecutor(fallback, + executor -> setThreadNamePrefix(executor, threadNamePrefix)); + } + + private TaskExecutor defaultTaskExecutor() { + return new TaskExecutorRegistration().getTaskExecutor(); + } + + private static void setThreadNamePrefix(TaskExecutor taskExecutor, String name) { + if (taskExecutor instanceof CustomizableThreadCreator ctc) { + ctc.setThreadNamePrefix(name); } - executor.setThreadNamePrefix("brokerChannel-"); - return executor; } /** diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/config/ChannelRegistration.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/config/ChannelRegistration.java index baaa9f8bbd0..e8f0626def7 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/config/ChannelRegistration.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/config/ChannelRegistration.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2021 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,7 +19,10 @@ package org.springframework.messaging.simp.config; import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.function.Consumer; +import java.util.function.Supplier; +import org.springframework.core.task.TaskExecutor; import org.springframework.lang.Nullable; import org.springframework.messaging.support.ChannelInterceptor; import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor; @@ -29,6 +32,7 @@ import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor; * {@link org.springframework.messaging.MessageChannel}. * * @author Rossen Stoyanchev + * @author Stephane Nicoll * @since 4.0 */ public class ChannelRegistration { @@ -36,6 +40,9 @@ public class ChannelRegistration { @Nullable private TaskExecutorRegistration registration; + @Nullable + private TaskExecutor executor; + private final List interceptors = new ArrayList<>(); @@ -59,6 +66,18 @@ public class ChannelRegistration { return this.registration; } + /** + * Configure the given {@link TaskExecutor} for this message channel, + * taking precedence over a {@linkplain #taskExecutor() task executor + * registration} if any. + * @param taskExecutor the task executor to use + * @since 6.1.4 + */ + public ChannelRegistration executor(TaskExecutor taskExecutor) { + this.executor = taskExecutor; + return this; + } + /** * Configure the given interceptors for this message channel, * adding them to the channel's current list of interceptors. @@ -71,13 +90,40 @@ public class ChannelRegistration { protected boolean hasTaskExecutor() { - return (this.registration != null); + return (this.registration != null || this.executor != null); } protected boolean hasInterceptors() { return !this.interceptors.isEmpty(); } + /** + * Return the {@link TaskExecutor} to use. If no task executor has been + * configured, the {@code fallback} supplier is used to provide a fallback + * instance. + *

+ * If the {@link TaskExecutor} to use is suitable for further customizations, + * the {@code customizer} consumer is invoked. + * @param fallback a supplier of a fallback task executor in case none is configured + * @param customizer further customizations + * @return the task executor to use + */ + protected TaskExecutor getTaskExecutor(Supplier fallback, Consumer customizer) { + if (this.executor != null) { + return this.executor; + } + else if (this.registration != null) { + ThreadPoolTaskExecutor registeredTaskExecutor = this.registration.getTaskExecutor(); + customizer.accept(registeredTaskExecutor); + return registeredTaskExecutor; + } + else { + TaskExecutor taskExecutor = fallback.get(); + customizer.accept(taskExecutor); + return taskExecutor; + } + } + protected List getInterceptors() { return this.interceptors; } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/config/ChannelRegistrationTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/config/ChannelRegistrationTests.java new file mode 100644 index 00000000000..dc392e2437e --- /dev/null +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/config/ChannelRegistrationTests.java @@ -0,0 +1,121 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.messaging.simp.config; + +import java.util.function.Consumer; +import java.util.function.Supplier; + +import org.junit.jupiter.api.Test; + +import org.springframework.core.task.TaskExecutor; +import org.springframework.messaging.support.ChannelInterceptor; +import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; + +/** + * Tests for {@link ChannelRegistration}. + * + * @author Stephane Nicoll + */ +class ChannelRegistrationTests { + + private final Supplier fallback = mock(); + + private final Consumer customizer = mock(); + + @Test + void emptyRegistrationUsesFallback() { + TaskExecutor fallbackTaskExecutor = mock(TaskExecutor.class); + given(this.fallback.get()).willReturn(fallbackTaskExecutor); + ChannelRegistration registration = new ChannelRegistration(); + assertThat(registration.hasTaskExecutor()).isFalse(); + TaskExecutor actual = registration.getTaskExecutor(this.fallback, this.customizer); + assertThat(actual).isSameAs(fallbackTaskExecutor); + verify(this.fallback).get(); + verify(this.customizer).accept(fallbackTaskExecutor); + } + + @Test + void emptyRegistrationDoesNotHaveInterceptors() { + ChannelRegistration registration = new ChannelRegistration(); + assertThat(registration.hasInterceptors()).isFalse(); + assertThat(registration.getInterceptors()).isEmpty(); + } + + @Test + void taskRegistrationCreatesDefaultInstance() { + ChannelRegistration registration = new ChannelRegistration(); + registration.taskExecutor(); + assertThat(registration.hasTaskExecutor()).isTrue(); + TaskExecutor taskExecutor = registration.getTaskExecutor(this.fallback, this.customizer); + assertThat(taskExecutor).isInstanceOf(ThreadPoolTaskExecutor.class); + verifyNoInteractions(this.fallback); + verify(this.customizer).accept(taskExecutor); + } + + @Test + void taskRegistrationWithExistingThreadPoolTaskExecutor() { + ThreadPoolTaskExecutor existingTaskExecutor = mock(ThreadPoolTaskExecutor.class); + ChannelRegistration registration = new ChannelRegistration(); + registration.taskExecutor(existingTaskExecutor); + assertThat(registration.hasTaskExecutor()).isTrue(); + TaskExecutor taskExecutor = registration.getTaskExecutor(this.fallback, this.customizer); + assertThat(taskExecutor).isSameAs(existingTaskExecutor); + verifyNoInteractions(this.fallback); + verify(this.customizer).accept(taskExecutor); + } + + @Test + void configureExecutor() { + ChannelRegistration registration = new ChannelRegistration(); + TaskExecutor taskExecutor = mock(TaskExecutor.class); + registration.executor(taskExecutor); + assertThat(registration.hasTaskExecutor()).isTrue(); + TaskExecutor taskExecutor1 = registration.getTaskExecutor(this.fallback, this.customizer); + assertThat(taskExecutor1).isSameAs(taskExecutor); + verifyNoInteractions(this.fallback, this.customizer); + } + + @Test + void configureExecutorTakesPrecedenceOverTaskRegistration() { + ChannelRegistration registration = new ChannelRegistration(); + TaskExecutor taskExecutor = mock(TaskExecutor.class); + registration.executor(taskExecutor); + ThreadPoolTaskExecutor ignored = mock(ThreadPoolTaskExecutor.class); + registration.taskExecutor(ignored); + assertThat(registration.hasTaskExecutor()).isTrue(); + assertThat(registration.getTaskExecutor(this.fallback, this.customizer)).isSameAs(taskExecutor); + verifyNoInteractions(ignored, this.fallback, this.customizer); + + } + + @Test + void configureInterceptors() { + ChannelRegistration registration = new ChannelRegistration(); + ChannelInterceptor interceptor1 = mock(ChannelInterceptor.class); + registration.interceptors(interceptor1); + ChannelInterceptor interceptor2 = mock(ChannelInterceptor.class); + registration.interceptors(interceptor2); + assertThat(registration.getInterceptors()).containsExactly(interceptor1, interceptor2); + } + +}