WebSocketMessageBrokerConfigurer allows to configure Lifecycle phase

Closes gh-32205
This commit is contained in:
rstoyanchev 2024-02-14 16:26:18 +00:00
parent f9ae54d91e
commit 504b7619bd
10 changed files with 238 additions and 21 deletions

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2023 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -120,6 +120,9 @@ public class SimpAnnotationMethodMessageHandler extends AbstractMethodMessageHan
@Nullable
private MessageHeaderInitializer headerInitializer;
@Nullable
private Integer phase;
private volatile boolean running;
private final Object lifecycleMonitor = new Object();
@ -271,6 +274,20 @@ public class SimpAnnotationMethodMessageHandler extends AbstractMethodMessageHan
return this.headerInitializer;
}
/**
* Set the phase that this handler should run in.
* <p>By default, this is {@link SmartLifecycle#DEFAULT_PHASE}.
* @since 6.1.4
*/
public void setPhase(int phase) {
this.phase = phase;
}
@Override
public int getPhase() {
return (this.phase != null ? this.phase : SmartLifecycle.super.getPhase());
}
@Override
public final void start() {

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2023 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -75,6 +75,9 @@ public abstract class AbstractBrokerMessageHandler
private boolean autoStartup = true;
@Nullable
private Integer phase;
private volatile boolean running;
private final Object lifecycleMonitor = new Object();
@ -197,6 +200,20 @@ public abstract class AbstractBrokerMessageHandler
return this.autoStartup;
}
/**
* Set the phase that this handler should run in.
* <p>By default, this is {@link SmartLifecycle#DEFAULT_PHASE}.
* @since 6.1.4
*/
public void setPhase(int phase) {
this.phase = phase;
}
@Override
public int getPhase() {
return (this.phase != null ? this.phase : SmartLifecycle.super.getPhase());
}
@Override
public void start() {

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2022 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -27,6 +27,7 @@ import org.springframework.beans.factory.BeanInitializationException;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.context.SmartLifecycle;
import org.springframework.context.annotation.Bean;
import org.springframework.context.event.SmartApplicationListener;
import org.springframework.core.task.TaskExecutor;
@ -59,6 +60,7 @@ import org.springframework.messaging.support.AbstractSubscribableChannel;
import org.springframework.messaging.support.ExecutorSubscribableChannel;
import org.springframework.messaging.support.ImmutableMessageChannelInterceptor;
import org.springframework.scheduling.TaskScheduler;
import org.springframework.scheduling.concurrent.ExecutorConfigurationSupport;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler;
import org.springframework.util.Assert;
@ -132,6 +134,9 @@ public abstract class AbstractMessageBrokerConfiguration implements ApplicationC
@Nullable
private MessageBrokerRegistry brokerRegistry;
@Nullable
private Integer phase;
/**
* Protected constructor.
@ -166,8 +171,12 @@ public abstract class AbstractMessageBrokerConfiguration implements ApplicationC
@Bean
public TaskExecutor clientInboundChannelExecutor() {
return getTaskExecutor(getClientInboundChannelRegistration(),
"clientInboundChannel-", this::defaultTaskExecutor);
ChannelRegistration registration = getClientInboundChannelRegistration();
TaskExecutor executor = getTaskExecutor(registration, "clientInboundChannel-", this::defaultTaskExecutor);
if (executor instanceof ExecutorConfigurationSupport executorSupport) {
executorSupport.setPhase(getPhase());
}
return executor;
}
protected final ChannelRegistration getClientInboundChannelRegistration() {
@ -180,6 +189,17 @@ public abstract class AbstractMessageBrokerConfiguration implements ApplicationC
return this.clientInboundChannelRegistration;
}
protected final int getPhase() {
if (this.phase == null) {
this.phase = initPhase();
}
return this.phase;
}
protected int initPhase() {
return SmartLifecycle.DEFAULT_PHASE;
}
/**
* A hook for subclasses to customize the message channel for inbound messages
* from WebSocket clients.
@ -193,17 +213,21 @@ public abstract class AbstractMessageBrokerConfiguration implements ApplicationC
ExecutorSubscribableChannel channel = new ExecutorSubscribableChannel(executor);
channel.setLogger(SimpLogging.forLog(channel.getLogger()));
ChannelRegistration reg = getClientOutboundChannelRegistration();
if (reg.hasInterceptors()) {
channel.setInterceptors(reg.getInterceptors());
ChannelRegistration registration = getClientOutboundChannelRegistration();
if (registration.hasInterceptors()) {
channel.setInterceptors(registration.getInterceptors());
}
return channel;
}
@Bean
public TaskExecutor clientOutboundChannelExecutor() {
return getTaskExecutor(getClientOutboundChannelRegistration(),
"clientOutboundChannel-", this::defaultTaskExecutor);
ChannelRegistration registration = getClientOutboundChannelRegistration();
TaskExecutor executor = getTaskExecutor(registration, "clientOutboundChannel-", this::defaultTaskExecutor);
if (executor instanceof ExecutorConfigurationSupport executorSupport) {
executorSupport.setPhase(getPhase());
}
return executor;
}
protected final ChannelRegistration getClientOutboundChannelRegistration() {
@ -244,7 +268,7 @@ public abstract class AbstractMessageBrokerConfiguration implements ApplicationC
MessageBrokerRegistry registry = getBrokerRegistry(clientInboundChannel, clientOutboundChannel);
ChannelRegistration registration = registry.getBrokerChannelRegistration();
return getTaskExecutor(registration, "brokerChannel-", () -> {
TaskExecutor executor = getTaskExecutor(registration, "brokerChannel-", () -> {
// Should never be used
ThreadPoolTaskExecutor threadPoolTaskExecutor = new ThreadPoolTaskExecutor();
threadPoolTaskExecutor.setCorePoolSize(0);
@ -252,6 +276,10 @@ public abstract class AbstractMessageBrokerConfiguration implements ApplicationC
threadPoolTaskExecutor.setQueueCapacity(0);
return threadPoolTaskExecutor;
});
if (executor instanceof ExecutorConfigurationSupport executorSupport) {
executorSupport.setPhase(getPhase());
}
return executor;
}
private TaskExecutor defaultTaskExecutor() {
@ -316,6 +344,7 @@ public abstract class AbstractMessageBrokerConfiguration implements ApplicationC
handler.setDestinationPrefixes(brokerRegistry.getApplicationDestinationPrefixes());
handler.setMessageConverter(brokerMessageConverter);
handler.setValidator(simpValidator());
handler.setPhase(getPhase());
List<HandlerMethodArgumentResolver> argumentResolvers = new ArrayList<>();
addArgumentResolvers(argumentResolvers);
@ -329,6 +358,7 @@ public abstract class AbstractMessageBrokerConfiguration implements ApplicationC
if (pathMatcher != null) {
handler.setPathMatcher(pathMatcher);
}
return handler;
}
@ -342,8 +372,11 @@ public abstract class AbstractMessageBrokerConfiguration implements ApplicationC
AbstractSubscribableChannel clientInboundChannel, AbstractSubscribableChannel clientOutboundChannel,
SimpMessagingTemplate brokerMessagingTemplate) {
return new SimpAnnotationMethodMessageHandler(
SimpAnnotationMethodMessageHandler handler = new SimpAnnotationMethodMessageHandler(
clientInboundChannel, clientOutboundChannel, brokerMessagingTemplate);
handler.setPhase(getPhase());
return handler;
}
protected void addArgumentResolvers(List<HandlerMethodArgumentResolver> argumentResolvers) {
@ -364,6 +397,7 @@ public abstract class AbstractMessageBrokerConfiguration implements ApplicationC
return null;
}
updateUserDestinationResolver(handler, userDestinationResolver, registry.getUserDestinationPrefix());
handler.setPhase(getPhase());
return handler;
}
@ -403,6 +437,7 @@ public abstract class AbstractMessageBrokerConfiguration implements ApplicationC
}
handler.setSystemSubscriptions(subscriptions);
updateUserDestinationResolver(handler, userDestinationResolver, registry.getUserDestinationPrefix());
handler.setPhase(getPhase());
return handler;
}
@ -419,6 +454,7 @@ public abstract class AbstractMessageBrokerConfiguration implements ApplicationC
if (destination != null) {
handler.setBroadcastDestination(destination);
}
handler.setPhase(getPhase());
return handler;
}
@ -446,6 +482,7 @@ public abstract class AbstractMessageBrokerConfiguration implements ApplicationC
scheduler.setThreadNamePrefix("MessageBroker-");
scheduler.setPoolSize(Runtime.getRuntime().availableProcessors());
scheduler.setRemoveOnCancelPolicy(true);
scheduler.setPhase(getPhase());
return scheduler;
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2023 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -77,6 +77,9 @@ public class UserDestinationMessageHandler implements MessageHandler, SmartLifec
private volatile boolean running;
@Nullable
private Integer phase;
private final Object lifecycleMonitor = new Object();
@ -154,6 +157,20 @@ public class UserDestinationMessageHandler implements MessageHandler, SmartLifec
return this.headerInitializer;
}
/**
* Set the phase that this handler should run in.
* <p>By default, this is {@link SmartLifecycle#DEFAULT_PHASE}.
* @since 6.1.4
*/
public void setPhase(int phase) {
this.phase = phase;
}
@Override
public int getPhase() {
return (this.phase != null ? this.phase : SmartLifecycle.super.getPhase());
}
@Override
public final void start() {

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2020 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -114,4 +114,15 @@ public class DelegatingWebSocketMessageBrokerConfiguration extends WebSocketMess
}
}
@Override
protected int initPhase() {
for (WebSocketMessageBrokerConfigurer configurer : this.configurers) {
Integer phase = configurer.getPhase();
if (phase != null) {
return phase;
}
}
return super.initPhase();
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2023 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -35,12 +35,14 @@ import org.springframework.messaging.simp.user.SimpUserRegistry;
import org.springframework.messaging.support.AbstractSubscribableChannel;
import org.springframework.scheduling.TaskScheduler;
import org.springframework.web.servlet.HandlerMapping;
import org.springframework.web.servlet.handler.AbstractHandlerMapping;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.config.WebSocketMessageBrokerStats;
import org.springframework.web.socket.handler.WebSocketHandlerDecoratorFactory;
import org.springframework.web.socket.messaging.DefaultSimpUserRegistry;
import org.springframework.web.socket.messaging.SubProtocolWebSocketHandler;
import org.springframework.web.socket.messaging.WebSocketAnnotationMethodMessageHandler;
import org.springframework.web.socket.server.support.WebSocketHandlerMapping;
/**
* Extends {@link AbstractMessageBrokerConfiguration} and adds configuration for
@ -66,8 +68,11 @@ public abstract class WebSocketMessageBrokerConfigurationSupport extends Abstrac
AbstractSubscribableChannel clientInboundChannel,AbstractSubscribableChannel clientOutboundChannel,
SimpMessagingTemplate brokerMessagingTemplate) {
return new WebSocketAnnotationMethodMessageHandler(
WebSocketAnnotationMethodMessageHandler handler = new WebSocketAnnotationMethodMessageHandler(
clientInboundChannel, clientOutboundChannel, brokerMessagingTemplate);
handler.setPhase(getPhase());
return handler;
}
@Override
@ -93,14 +98,22 @@ public abstract class WebSocketMessageBrokerConfigurationSupport extends Abstrac
}
registerStompEndpoints(registry);
OrderedMessageChannelDecorator.configureInterceptor(clientInboundChannel, registry.isPreserveReceiveOrder());
return registry.getHandlerMapping();
AbstractHandlerMapping handlerMapping = registry.getHandlerMapping();
if (handlerMapping instanceof WebSocketHandlerMapping webSocketMapping) {
webSocketMapping.setPhase(getPhase());
}
return handlerMapping;
}
@Bean
public WebSocketHandler subProtocolWebSocketHandler(
AbstractSubscribableChannel clientInboundChannel, AbstractSubscribableChannel clientOutboundChannel) {
return new SubProtocolWebSocketHandler(clientInboundChannel, clientOutboundChannel);
SubProtocolWebSocketHandler handler =
new SubProtocolWebSocketHandler(clientInboundChannel, clientOutboundChannel);
handler.setPhase(getPhase());
return handler;
}
protected WebSocketHandler decorateWebSocketHandler(WebSocketHandler handler) {

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2022 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -18,6 +18,8 @@ package org.springframework.web.socket.config.annotation;
import java.util.List;
import org.springframework.context.SmartLifecycle;
import org.springframework.lang.Nullable;
import org.springframework.messaging.converter.MessageConverter;
import org.springframework.messaging.handler.invocation.HandlerMethodArgumentResolver;
import org.springframework.messaging.handler.invocation.HandlerMethodReturnValueHandler;
@ -110,4 +112,22 @@ public interface WebSocketMessageBrokerConfigurer {
default void configureMessageBroker(MessageBrokerRegistry registry) {
}
/**
* Return the {@link SmartLifecycle#getPhase() phase} that WebSocket message
* handling beans of type {@link SmartLifecycle} should run in.
* <p>The default implementation returns {@link null} which allows other
* configurers to decide. As soon as any configurer returns a value, that
* value is used. If no configurer returns a value, then by default
* {@link SmartLifecycle#DEFAULT_PHASE} is used.
* <p>It is recommended to use a phase value such as 0 in order to ensure that
* components start before the web server in Spring Boot application. In 6.2.0,
* the default used will change to 0.
* @since 6.1.4
* @see SmartLifecycle
*/
@Nullable
default Integer getPhase() {
return null;
}
}

View File

@ -103,6 +103,9 @@ public class SubProtocolWebSocketHandler
private final DefaultStats stats = new DefaultStats();
@Nullable
private Integer phase;
private volatile boolean running;
private final Object lifecycleMonitor = new Object();
@ -249,6 +252,20 @@ public class SubProtocolWebSocketHandler
return this.timeToFirstMessage;
}
/**
* Set the phase that this handler should run in.
* <p>By default, this is {@link SmartLifecycle#DEFAULT_PHASE}.
* @since 6.1.4
*/
public void setPhase(int phase) {
this.phase = phase;
}
@Override
public int getPhase() {
return (this.phase != null ? this.phase : SmartLifecycle.super.getPhase());
}
/**
* Return a String describing internal state and counters.
* Effectively {@code toString()} on {@link #getStats() getStats()}.

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2023 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -40,6 +40,9 @@ public class WebSocketHandlerMapping extends SimpleUrlHandlerMapping implements
private boolean webSocketUpgradeMatch;
@Nullable
private Integer phase;
private volatile boolean running;
@ -57,6 +60,20 @@ public class WebSocketHandlerMapping extends SimpleUrlHandlerMapping implements
this.webSocketUpgradeMatch = match;
}
/**
* Set the phase that this handler should run in.
* <p>By default, this is {@link SmartLifecycle#DEFAULT_PHASE}.
* @since 6.1.4
*/
public void setPhase(int phase) {
this.phase = phase;
}
@Override
public int getPhase() {
return (this.phase != null ? this.phase : SmartLifecycle.super.getPhase());
}
@Override
protected void initServletContext(ServletContext servletContext) {

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2023 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -21,6 +21,7 @@ import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ScheduledThreadPoolExecutor;
import java.util.function.Consumer;
import org.junit.jupiter.api.Test;
@ -35,6 +36,7 @@ import org.springframework.messaging.handler.annotation.MessageMapping;
import org.springframework.messaging.handler.annotation.SendTo;
import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.messaging.simp.annotation.SubscribeMapping;
import org.springframework.messaging.simp.annotation.support.SimpAnnotationMethodMessageHandler;
import org.springframework.messaging.simp.broker.DefaultSubscriptionRegistry;
import org.springframework.messaging.simp.broker.SimpleBrokerMessageHandler;
import org.springframework.messaging.simp.config.MessageBrokerRegistry;
@ -45,6 +47,7 @@ import org.springframework.messaging.support.AbstractSubscribableChannel;
import org.springframework.messaging.support.ChannelInterceptor;
import org.springframework.messaging.support.ExecutorSubscribableChannel;
import org.springframework.messaging.support.ImmutableMessageChannelInterceptor;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler;
import org.springframework.stereotype.Controller;
import org.springframework.web.servlet.handler.SimpleUrlHandlerMapping;
@ -57,6 +60,7 @@ import org.springframework.web.socket.messaging.StompSubProtocolHandler;
import org.springframework.web.socket.messaging.StompTextMessageBuilder;
import org.springframework.web.socket.messaging.SubProtocolHandler;
import org.springframework.web.socket.messaging.SubProtocolWebSocketHandler;
import org.springframework.web.socket.server.support.WebSocketHandlerMapping;
import org.springframework.web.socket.server.support.WebSocketHttpRequestHandler;
import static org.assertj.core.api.Assertions.assertThat;
@ -207,6 +211,27 @@ class WebSocketMessageBrokerConfigurationSupportTests {
assertThat(actual).matches(expected);
}
@Test
void lifecyclePhase() {
ApplicationContext context = createContext(LifecyclePhaseConfig.class);
int phase = 99;
Consumer<String> executorTester = beanName ->
assertThat(context.getBean(beanName, ThreadPoolTaskExecutor.class).getPhase()).isEqualTo(phase);
executorTester.accept("clientInboundChannelExecutor");
executorTester.accept("clientOutboundChannelExecutor");
executorTester.accept("brokerChannelExecutor");
assertThat(context.getBean(SimpAnnotationMethodMessageHandler.class).getPhase()).isEqualTo(phase);
assertThat(context.getBean(UserDestinationMessageHandler.class).getPhase()).isEqualTo(phase);
assertThat(context.getBean(ThreadPoolTaskScheduler.class).getPhase()).isEqualTo(phase);
assertThat(context.getBean(WebSocketHandlerMapping.class).getPhase()).isEqualTo(phase);
assertThat(context.getBean(SubProtocolWebSocketHandler.class).getPhase()).isEqualTo(phase);
assertThat(context.getBean(SimpleBrokerMessageHandler.class).getPhase()).isEqualTo(phase);
}
@Test
void webSocketHandlerDecorator() throws Exception {
ApplicationContext context = createContext(WebSocketHandlerDecoratorConfig.class);
@ -347,4 +372,30 @@ class WebSocketMessageBrokerConfigurationSupportTests {
}
}
@Configuration
static class LifecyclePhaseConfig extends DelegatingWebSocketMessageBrokerConfiguration {
@Bean
public WebSocketMessageBrokerConfigurer getConfigurer() {
return new WebSocketMessageBrokerConfigurer() {
@Override
public void registerStompEndpoints(StompEndpointRegistry registry) {
registry.addEndpoint("/broker");
}
@Override
public void configureMessageBroker(MessageBrokerRegistry registry) {
registry.enableSimpleBroker();
}
@Override
public Integer getPhase() {
return 99;
}
};
}
}
}