Allow an existing TaskExecutor to be configured in ChannelRegistration

This commit introduces a new method to configure an existing
TaskExecutor in ChannelRegistration. Contrary to
TaskExecutorRegistration, a ThreadPoolTaskExecutor is not necessary,
and it can't be further configured. This includes the thread name
prefix.

Closes gh-32081
This commit is contained in:
Stéphane Nicoll 2024-01-24 14:34:16 +01:00
parent b7e4fa16ca
commit 8815788004
3 changed files with 198 additions and 21 deletions

View File

@ -21,6 +21,7 @@ import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Supplier;
import org.springframework.beans.factory.BeanInitializationException;
import org.springframework.beans.factory.annotation.Qualifier;
@ -62,6 +63,7 @@ import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
import org.springframework.util.CustomizableThreadCreator;
import org.springframework.util.MimeTypeUtils;
import org.springframework.util.PathMatcher;
import org.springframework.util.StringUtils;
@ -164,10 +166,8 @@ public abstract class AbstractMessageBrokerConfiguration implements ApplicationC
@Bean
public TaskExecutor clientInboundChannelExecutor() {
TaskExecutorRegistration reg = getClientInboundChannelRegistration().taskExecutor();
ThreadPoolTaskExecutor executor = reg.getTaskExecutor();
executor.setThreadNamePrefix("clientInboundChannel-");
return executor;
return getTaskExecutor(getClientInboundChannelRegistration(),
"clientInboundChannel-", this::defaultTaskExecutor);
}
protected final ChannelRegistration getClientInboundChannelRegistration() {
@ -202,10 +202,8 @@ public abstract class AbstractMessageBrokerConfiguration implements ApplicationC
@Bean
public TaskExecutor clientOutboundChannelExecutor() {
TaskExecutorRegistration reg = getClientOutboundChannelRegistration().taskExecutor();
ThreadPoolTaskExecutor executor = reg.getTaskExecutor();
executor.setThreadNamePrefix("clientOutboundChannel-");
return executor;
return getTaskExecutor(getClientOutboundChannelRegistration(),
"clientOutboundChannel-", this::defaultTaskExecutor);
}
protected final ChannelRegistration getClientOutboundChannelRegistration() {
@ -246,19 +244,31 @@ public abstract class AbstractMessageBrokerConfiguration implements ApplicationC
MessageBrokerRegistry registry = getBrokerRegistry(clientInboundChannel, clientOutboundChannel);
ChannelRegistration registration = registry.getBrokerChannelRegistration();
ThreadPoolTaskExecutor executor;
if (registration.hasTaskExecutor()) {
executor = registration.taskExecutor().getTaskExecutor();
}
else {
return getTaskExecutor(registration, "brokerChannel-", () -> {
// Should never be used
executor = new ThreadPoolTaskExecutor();
executor.setCorePoolSize(0);
executor.setMaxPoolSize(1);
executor.setQueueCapacity(0);
ThreadPoolTaskExecutor threadPoolTaskExecutor = new ThreadPoolTaskExecutor();
threadPoolTaskExecutor.setCorePoolSize(0);
threadPoolTaskExecutor.setMaxPoolSize(1);
threadPoolTaskExecutor.setQueueCapacity(0);
return threadPoolTaskExecutor;
});
}
private static TaskExecutor getTaskExecutor(ChannelRegistration registration,
String threadNamePrefix, Supplier<TaskExecutor> fallback) {
return registration.getTaskExecutor(fallback,
executor -> setThreadNamePrefix(executor, threadNamePrefix));
}
private TaskExecutor defaultTaskExecutor() {
return new TaskExecutorRegistration().getTaskExecutor();
}
private static void setThreadNamePrefix(TaskExecutor taskExecutor, String name) {
if (taskExecutor instanceof CustomizableThreadCreator ctc) {
ctc.setThreadNamePrefix(name);
}
executor.setThreadNamePrefix("brokerChannel-");
return executor;
}
/**

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2021 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -19,7 +19,10 @@ package org.springframework.messaging.simp.config;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.Consumer;
import java.util.function.Supplier;
import org.springframework.core.task.TaskExecutor;
import org.springframework.lang.Nullable;
import org.springframework.messaging.support.ChannelInterceptor;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
@ -29,6 +32,7 @@ import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
* {@link org.springframework.messaging.MessageChannel}.
*
* @author Rossen Stoyanchev
* @author Stephane Nicoll
* @since 4.0
*/
public class ChannelRegistration {
@ -36,6 +40,9 @@ public class ChannelRegistration {
@Nullable
private TaskExecutorRegistration registration;
@Nullable
private TaskExecutor executor;
private final List<ChannelInterceptor> interceptors = new ArrayList<>();
@ -59,6 +66,18 @@ public class ChannelRegistration {
return this.registration;
}
/**
* Configure the given {@link TaskExecutor} for this message channel,
* taking precedence over a {@linkplain #taskExecutor() task executor
* registration} if any.
* @param taskExecutor the task executor to use
* @since 6.1.4
*/
public ChannelRegistration executor(TaskExecutor taskExecutor) {
this.executor = taskExecutor;
return this;
}
/**
* Configure the given interceptors for this message channel,
* adding them to the channel's current list of interceptors.
@ -71,13 +90,40 @@ public class ChannelRegistration {
protected boolean hasTaskExecutor() {
return (this.registration != null);
return (this.registration != null || this.executor != null);
}
protected boolean hasInterceptors() {
return !this.interceptors.isEmpty();
}
/**
* Return the {@link TaskExecutor} to use. If no task executor has been
* configured, the {@code fallback} supplier is used to provide a fallback
* instance.
* <p>
* If the {@link TaskExecutor} to use is suitable for further customizations,
* the {@code customizer} consumer is invoked.
* @param fallback a supplier of a fallback task executor in case none is configured
* @param customizer further customizations
* @return the task executor to use
*/
protected TaskExecutor getTaskExecutor(Supplier<TaskExecutor> fallback, Consumer<TaskExecutor> customizer) {
if (this.executor != null) {
return this.executor;
}
else if (this.registration != null) {
ThreadPoolTaskExecutor registeredTaskExecutor = this.registration.getTaskExecutor();
customizer.accept(registeredTaskExecutor);
return registeredTaskExecutor;
}
else {
TaskExecutor taskExecutor = fallback.get();
customizer.accept(taskExecutor);
return taskExecutor;
}
}
protected List<ChannelInterceptor> getInterceptors() {
return this.interceptors;
}

View File

@ -0,0 +1,121 @@
/*
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.simp.config;
import java.util.function.Consumer;
import java.util.function.Supplier;
import org.junit.jupiter.api.Test;
import org.springframework.core.task.TaskExecutor;
import org.springframework.messaging.support.ChannelInterceptor;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
/**
* Tests for {@link ChannelRegistration}.
*
* @author Stephane Nicoll
*/
class ChannelRegistrationTests {
private final Supplier<TaskExecutor> fallback = mock();
private final Consumer<TaskExecutor> customizer = mock();
@Test
void emptyRegistrationUsesFallback() {
TaskExecutor fallbackTaskExecutor = mock(TaskExecutor.class);
given(this.fallback.get()).willReturn(fallbackTaskExecutor);
ChannelRegistration registration = new ChannelRegistration();
assertThat(registration.hasTaskExecutor()).isFalse();
TaskExecutor actual = registration.getTaskExecutor(this.fallback, this.customizer);
assertThat(actual).isSameAs(fallbackTaskExecutor);
verify(this.fallback).get();
verify(this.customizer).accept(fallbackTaskExecutor);
}
@Test
void emptyRegistrationDoesNotHaveInterceptors() {
ChannelRegistration registration = new ChannelRegistration();
assertThat(registration.hasInterceptors()).isFalse();
assertThat(registration.getInterceptors()).isEmpty();
}
@Test
void taskRegistrationCreatesDefaultInstance() {
ChannelRegistration registration = new ChannelRegistration();
registration.taskExecutor();
assertThat(registration.hasTaskExecutor()).isTrue();
TaskExecutor taskExecutor = registration.getTaskExecutor(this.fallback, this.customizer);
assertThat(taskExecutor).isInstanceOf(ThreadPoolTaskExecutor.class);
verifyNoInteractions(this.fallback);
verify(this.customizer).accept(taskExecutor);
}
@Test
void taskRegistrationWithExistingThreadPoolTaskExecutor() {
ThreadPoolTaskExecutor existingTaskExecutor = mock(ThreadPoolTaskExecutor.class);
ChannelRegistration registration = new ChannelRegistration();
registration.taskExecutor(existingTaskExecutor);
assertThat(registration.hasTaskExecutor()).isTrue();
TaskExecutor taskExecutor = registration.getTaskExecutor(this.fallback, this.customizer);
assertThat(taskExecutor).isSameAs(existingTaskExecutor);
verifyNoInteractions(this.fallback);
verify(this.customizer).accept(taskExecutor);
}
@Test
void configureExecutor() {
ChannelRegistration registration = new ChannelRegistration();
TaskExecutor taskExecutor = mock(TaskExecutor.class);
registration.executor(taskExecutor);
assertThat(registration.hasTaskExecutor()).isTrue();
TaskExecutor taskExecutor1 = registration.getTaskExecutor(this.fallback, this.customizer);
assertThat(taskExecutor1).isSameAs(taskExecutor);
verifyNoInteractions(this.fallback, this.customizer);
}
@Test
void configureExecutorTakesPrecedenceOverTaskRegistration() {
ChannelRegistration registration = new ChannelRegistration();
TaskExecutor taskExecutor = mock(TaskExecutor.class);
registration.executor(taskExecutor);
ThreadPoolTaskExecutor ignored = mock(ThreadPoolTaskExecutor.class);
registration.taskExecutor(ignored);
assertThat(registration.hasTaskExecutor()).isTrue();
assertThat(registration.getTaskExecutor(this.fallback, this.customizer)).isSameAs(taskExecutor);
verifyNoInteractions(ignored, this.fallback, this.customizer);
}
@Test
void configureInterceptors() {
ChannelRegistration registration = new ChannelRegistration();
ChannelInterceptor interceptor1 = mock(ChannelInterceptor.class);
registration.interceptors(interceptor1);
ChannelInterceptor interceptor2 = mock(ChannelInterceptor.class);
registration.interceptors(interceptor2);
assertThat(registration.getInterceptors()).containsExactly(interceptor1, interceptor2);
}
}