Refactor use of TaskScheduler in WebSocket Java config
Issue: SPR-15233
This commit is contained in:
parent
190408d1dc
commit
779779de7b
|
@ -43,8 +43,6 @@ import org.springframework.web.socket.sockjs.transport.handler.WebSocketTranspor
|
|||
*/
|
||||
public abstract class AbstractWebSocketHandlerRegistration<M> implements WebSocketHandlerRegistration {
|
||||
|
||||
private final TaskScheduler sockJsTaskScheduler;
|
||||
|
||||
private final MultiValueMap<WebSocketHandler, String> handlerMap = new LinkedMultiValueMap<>();
|
||||
|
||||
private HandshakeHandler handshakeHandler;
|
||||
|
@ -55,9 +53,21 @@ public abstract class AbstractWebSocketHandlerRegistration<M> implements WebSock
|
|||
|
||||
private SockJsServiceRegistration sockJsServiceRegistration;
|
||||
|
||||
private TaskScheduler scheduler;
|
||||
|
||||
|
||||
public AbstractWebSocketHandlerRegistration() {
|
||||
}
|
||||
|
||||
/**
|
||||
* Deprecated constructor with a TaskScheduler.
|
||||
*
|
||||
* @deprecated as of 5.0 a TaskScheduler is not provided upfront, not until
|
||||
* it is obvious that it is needed, see {@link #getSockJsServiceRegistration()}.
|
||||
*/
|
||||
@Deprecated
|
||||
public AbstractWebSocketHandlerRegistration(TaskScheduler defaultTaskScheduler) {
|
||||
this.sockJsTaskScheduler = defaultTaskScheduler;
|
||||
this.scheduler = defaultTaskScheduler;
|
||||
}
|
||||
|
||||
|
||||
|
@ -98,7 +108,10 @@ public abstract class AbstractWebSocketHandlerRegistration<M> implements WebSock
|
|||
|
||||
@Override
|
||||
public SockJsServiceRegistration withSockJS() {
|
||||
this.sockJsServiceRegistration = new SockJsServiceRegistration(this.sockJsTaskScheduler);
|
||||
this.sockJsServiceRegistration = new SockJsServiceRegistration();
|
||||
if (this.scheduler != null) {
|
||||
this.sockJsServiceRegistration.setTaskScheduler(this.scheduler);
|
||||
}
|
||||
HandshakeInterceptor[] interceptors = getInterceptors();
|
||||
if (interceptors.length > 0) {
|
||||
this.sockJsServiceRegistration.setInterceptors(interceptors);
|
||||
|
@ -121,6 +134,16 @@ public abstract class AbstractWebSocketHandlerRegistration<M> implements WebSock
|
|||
return interceptors.toArray(new HandshakeInterceptor[interceptors.size()]);
|
||||
}
|
||||
|
||||
/**
|
||||
* Expose the {@code SockJsServiceRegistration} -- if SockJS is enabled or
|
||||
* {@code null} otherwise -- so that it can be configured with a TaskScheduler
|
||||
* if the application did not provide one. This should be done prior to
|
||||
* calling {@link #getMappings()}.
|
||||
*/
|
||||
protected SockJsServiceRegistration getSockJsServiceRegistration() {
|
||||
return this.sockJsServiceRegistration;
|
||||
}
|
||||
|
||||
protected final M getMappings() {
|
||||
M mappings = createMappings();
|
||||
if (this.sockJsServiceRegistration != null) {
|
||||
|
|
|
@ -41,8 +41,19 @@ public class ServletWebSocketHandlerRegistration
|
|||
extends AbstractWebSocketHandlerRegistration<MultiValueMap<HttpRequestHandler, String>> {
|
||||
|
||||
|
||||
public ServletWebSocketHandlerRegistration(TaskScheduler sockJsTaskScheduler) {
|
||||
super(sockJsTaskScheduler);
|
||||
public ServletWebSocketHandlerRegistration() {
|
||||
}
|
||||
|
||||
/**
|
||||
* Deprecated constructor with a TaskScheduler for SockJS use.
|
||||
*
|
||||
* @deprecated as of 5.0 a TaskScheduler is not provided upfront, not until
|
||||
* it is obvious that it is needed, see {@link #getSockJsServiceRegistration()}.
|
||||
*/
|
||||
@Deprecated
|
||||
@SuppressWarnings("deprecated")
|
||||
public ServletWebSocketHandlerRegistration(TaskScheduler scheduler) {
|
||||
super(scheduler);
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -25,7 +25,6 @@ import org.springframework.scheduling.TaskScheduler;
|
|||
import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler;
|
||||
import org.springframework.util.MultiValueMap;
|
||||
import org.springframework.web.HttpRequestHandler;
|
||||
import org.springframework.web.servlet.HandlerMapping;
|
||||
import org.springframework.web.servlet.handler.AbstractHandlerMapping;
|
||||
import org.springframework.web.servlet.handler.SimpleUrlHandlerMapping;
|
||||
import org.springframework.web.socket.WebSocketHandler;
|
||||
|
@ -43,21 +42,33 @@ public class ServletWebSocketHandlerRegistry implements WebSocketHandlerRegistry
|
|||
|
||||
private final List<ServletWebSocketHandlerRegistration> registrations = new ArrayList<>(4);
|
||||
|
||||
private TaskScheduler sockJsTaskScheduler;
|
||||
private TaskScheduler scheduler;
|
||||
|
||||
private int order = 1;
|
||||
|
||||
private UrlPathHelper urlPathHelper;
|
||||
|
||||
|
||||
public ServletWebSocketHandlerRegistry(ThreadPoolTaskScheduler sockJsTaskScheduler) {
|
||||
this.sockJsTaskScheduler = sockJsTaskScheduler;
|
||||
public ServletWebSocketHandlerRegistry() {
|
||||
this.scheduler = null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Deprecated constructor with a TaskScheduler for SockJS use.
|
||||
*
|
||||
* @deprecated as of 5.0 a TaskScheduler is not provided upfront, not until
|
||||
* it is obvious that it is needed, see {@link #requiresTaskScheduler()} and
|
||||
* {@link #setTaskScheduler}.
|
||||
*/
|
||||
@Deprecated
|
||||
public ServletWebSocketHandlerRegistry(ThreadPoolTaskScheduler scheduler) {
|
||||
this.scheduler = scheduler;
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public WebSocketHandlerRegistration addHandler(WebSocketHandler handler, String... paths) {
|
||||
ServletWebSocketHandlerRegistration registration =
|
||||
new ServletWebSocketHandlerRegistration(this.sockJsTaskScheduler);
|
||||
ServletWebSocketHandlerRegistration registration = new ServletWebSocketHandlerRegistration();
|
||||
registration.addHandler(handler, paths);
|
||||
this.registrations.add(registration);
|
||||
return registration;
|
||||
|
@ -88,12 +99,31 @@ public class ServletWebSocketHandlerRegistry implements WebSocketHandlerRegistry
|
|||
return this.urlPathHelper;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Return a {@link HandlerMapping} with mapped {@link HttpRequestHandler}s.
|
||||
* Whether there are any endpoint SockJS registrations without a TaskScheduler.
|
||||
* This method should be invoked just before {@link #getHandlerMapping()} to
|
||||
* allow for registrations to be made first.
|
||||
*/
|
||||
protected boolean requiresTaskScheduler() {
|
||||
return this.registrations.stream()
|
||||
.anyMatch(r -> r.getSockJsServiceRegistration() != null &&
|
||||
r.getSockJsServiceRegistration().getTaskScheduler() == null);
|
||||
}
|
||||
|
||||
/**
|
||||
* Configure a TaskScheduler for SockJS endpoints. This should be configured
|
||||
* before calling {@link #getHandlerMapping()} after checking if
|
||||
* {@link #requiresTaskScheduler()} returns {@code true}.
|
||||
*/
|
||||
protected void setTaskScheduler(TaskScheduler scheduler) {
|
||||
this.scheduler = scheduler;
|
||||
}
|
||||
|
||||
public AbstractHandlerMapping getHandlerMapping() {
|
||||
Map<String, Object> urlMap = new LinkedHashMap<>();
|
||||
for (ServletWebSocketHandlerRegistration registration : this.registrations) {
|
||||
updateTaskScheduler(registration);
|
||||
MultiValueMap<HttpRequestHandler, String> mappings = registration.getMappings();
|
||||
for (HttpRequestHandler httpHandler : mappings.keySet()) {
|
||||
for (String pattern : mappings.get(httpHandler)) {
|
||||
|
@ -110,4 +140,11 @@ public class ServletWebSocketHandlerRegistry implements WebSocketHandlerRegistry
|
|||
return hm;
|
||||
}
|
||||
|
||||
private void updateTaskScheduler(ServletWebSocketHandlerRegistration registration) {
|
||||
SockJsServiceRegistration sockJsRegistration = registration.getSockJsServiceRegistration();
|
||||
if (sockJsRegistration != null && sockJsRegistration.getTaskScheduler() == null) {
|
||||
sockJsRegistration.setTaskScheduler(this.scheduler);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -70,13 +70,29 @@ public class SockJsServiceRegistration {
|
|||
private SockJsMessageCodec messageCodec;
|
||||
|
||||
|
||||
public SockJsServiceRegistration() {
|
||||
}
|
||||
|
||||
/**
|
||||
* Deprecated constructor with a TaskScheduler.
|
||||
*
|
||||
* @deprecated as of 5.0 a TaskScheduler is not provided upfront, not until
|
||||
* it is obvious that it is needed; call {@link #getTaskScheduler()} to check
|
||||
* and then {@link #setTaskScheduler(TaskScheduler)} to set it before a call
|
||||
* to {@link #createSockJsService()}
|
||||
*/
|
||||
@Deprecated
|
||||
public SockJsServiceRegistration(TaskScheduler defaultTaskScheduler) {
|
||||
this.scheduler = defaultTaskScheduler;
|
||||
}
|
||||
|
||||
|
||||
public SockJsServiceRegistration setTaskScheduler(TaskScheduler taskScheduler) {
|
||||
this.scheduler = taskScheduler;
|
||||
/**
|
||||
* A scheduler instance to use for scheduling SockJS heart-beats.
|
||||
*/
|
||||
public SockJsServiceRegistration setTaskScheduler(TaskScheduler scheduler) {
|
||||
Assert.notNull(scheduler, "TaskScheduler is required.");
|
||||
this.scheduler = scheduler;
|
||||
return this;
|
||||
}
|
||||
|
||||
|
@ -277,6 +293,13 @@ public class SockJsServiceRegistration {
|
|||
return service;
|
||||
}
|
||||
|
||||
/**
|
||||
* Return the TaskScheduler, if configured.
|
||||
*/
|
||||
protected TaskScheduler getTaskScheduler() {
|
||||
return this.scheduler;
|
||||
}
|
||||
|
||||
private TransportHandlingSockJsService createSockJsService() {
|
||||
|
||||
Assert.state(this.transportHandlers.isEmpty() || this.transportHandlerOverrides.isEmpty(),
|
||||
|
|
|
@ -97,7 +97,8 @@ public class WebMvcStompWebSocketEndpointRegistration implements StompWebSocketE
|
|||
|
||||
@Override
|
||||
public SockJsServiceRegistration withSockJS() {
|
||||
this.registration = new SockJsServiceRegistration(this.sockJsTaskScheduler);
|
||||
this.registration = new SockJsServiceRegistration();
|
||||
this.registration.setTaskScheduler(this.sockJsTaskScheduler);
|
||||
HandshakeInterceptor[] interceptors = getInterceptors();
|
||||
if (interceptors.length > 0) {
|
||||
this.registration.setInterceptors(interceptors);
|
||||
|
|
|
@ -16,7 +16,12 @@
|
|||
|
||||
package org.springframework.web.socket.config.annotation;
|
||||
|
||||
import java.util.Date;
|
||||
import java.util.concurrent.ScheduledFuture;
|
||||
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.scheduling.TaskScheduler;
|
||||
import org.springframework.scheduling.Trigger;
|
||||
import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler;
|
||||
import org.springframework.web.servlet.HandlerMapping;
|
||||
|
||||
|
@ -28,13 +33,28 @@ import org.springframework.web.servlet.HandlerMapping;
|
|||
*/
|
||||
public class WebSocketConfigurationSupport {
|
||||
|
||||
private ServletWebSocketHandlerRegistry handlerRegistry;
|
||||
|
||||
private TaskScheduler scheduler;
|
||||
|
||||
|
||||
@Bean
|
||||
public HandlerMapping webSocketHandlerMapping() {
|
||||
ServletWebSocketHandlerRegistry registry = new ServletWebSocketHandlerRegistry(defaultSockJsTaskScheduler());
|
||||
registerWebSocketHandlers(registry);
|
||||
ServletWebSocketHandlerRegistry registry = initHandlerRegistry();
|
||||
if (registry.requiresTaskScheduler()) {
|
||||
registry.setTaskScheduler(initTaskScheduler());
|
||||
}
|
||||
return registry.getHandlerMapping();
|
||||
}
|
||||
|
||||
private ServletWebSocketHandlerRegistry initHandlerRegistry() {
|
||||
if (this.handlerRegistry == null) {
|
||||
this.handlerRegistry = new ServletWebSocketHandlerRegistry();
|
||||
registerWebSocketHandlers(this.handlerRegistry);
|
||||
}
|
||||
return this.handlerRegistry;
|
||||
}
|
||||
|
||||
protected void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
|
||||
}
|
||||
|
||||
|
@ -55,12 +75,58 @@ public class WebSocketConfigurationSupport {
|
|||
* </pre>
|
||||
*/
|
||||
@Bean
|
||||
public ThreadPoolTaskScheduler defaultSockJsTaskScheduler() {
|
||||
ThreadPoolTaskScheduler scheduler = new ThreadPoolTaskScheduler();
|
||||
scheduler.setThreadNamePrefix("SockJS-");
|
||||
scheduler.setPoolSize(Runtime.getRuntime().availableProcessors());
|
||||
scheduler.setRemoveOnCancelPolicy(true);
|
||||
public TaskScheduler defaultSockJsTaskScheduler() {
|
||||
return initTaskScheduler();
|
||||
}
|
||||
|
||||
private TaskScheduler initTaskScheduler() {
|
||||
if (this.scheduler == null) {
|
||||
ServletWebSocketHandlerRegistry registry = initHandlerRegistry();
|
||||
if (registry.requiresTaskScheduler()) {
|
||||
ThreadPoolTaskScheduler threadPoolScheduler = new ThreadPoolTaskScheduler();
|
||||
threadPoolScheduler.setThreadNamePrefix("SockJS-");
|
||||
threadPoolScheduler.setPoolSize(Runtime.getRuntime().availableProcessors());
|
||||
threadPoolScheduler.setRemoveOnCancelPolicy(true);
|
||||
this.scheduler = threadPoolScheduler;
|
||||
}
|
||||
else {
|
||||
this.scheduler = new NoOpScheduler();
|
||||
}
|
||||
}
|
||||
return scheduler;
|
||||
}
|
||||
|
||||
|
||||
private static class NoOpScheduler implements TaskScheduler {
|
||||
|
||||
@Override
|
||||
public ScheduledFuture<?> schedule(Runnable task, Trigger trigger) {
|
||||
throw new IllegalStateException("Unexpected use of scheduler.");
|
||||
}
|
||||
|
||||
@Override
|
||||
public ScheduledFuture<?> schedule(Runnable task, Date startTime) {
|
||||
throw new IllegalStateException("Unexpected use of scheduler.");
|
||||
}
|
||||
|
||||
@Override
|
||||
public ScheduledFuture<?> scheduleAtFixedRate(Runnable task, Date startTime, long period) {
|
||||
throw new IllegalStateException("Unexpected use of scheduler.");
|
||||
}
|
||||
|
||||
@Override
|
||||
public ScheduledFuture<?> scheduleAtFixedRate(Runnable task, long period) {
|
||||
throw new IllegalStateException("Unexpected use of scheduler.");
|
||||
}
|
||||
|
||||
@Override
|
||||
public ScheduledFuture<?> scheduleWithFixedDelay(Runnable task, Date startTime, long delay) {
|
||||
throw new IllegalStateException("Unexpected use of scheduler.");
|
||||
}
|
||||
|
||||
@Override
|
||||
public ScheduledFuture<?> scheduleWithFixedDelay(Runnable task, long delay) {
|
||||
throw new IllegalStateException("Unexpected use of scheduler.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -36,7 +36,10 @@ import org.springframework.web.socket.sockjs.transport.TransportType;
|
|||
import org.springframework.web.socket.sockjs.transport.handler.DefaultSockJsService;
|
||||
import org.springframework.web.socket.sockjs.transport.handler.WebSocketTransportHandler;
|
||||
|
||||
import static org.junit.Assert.*;
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertNotNull;
|
||||
import static org.junit.Assert.assertSame;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
/**
|
||||
* Test fixture for
|
||||
|
@ -54,7 +57,7 @@ public class WebSocketHandlerRegistrationTests {
|
|||
@Before
|
||||
public void setup() {
|
||||
this.taskScheduler = Mockito.mock(TaskScheduler.class);
|
||||
this.registration = new TestWebSocketHandlerRegistration(taskScheduler);
|
||||
this.registration = new TestWebSocketHandlerRegistration();
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -68,12 +71,14 @@ public class WebSocketHandlerRegistrationTests {
|
|||
Mapping m1 = mappings.get(0);
|
||||
assertEquals(handler, m1.webSocketHandler);
|
||||
assertEquals("/foo", m1.path);
|
||||
assertNotNull(m1.interceptors);
|
||||
assertEquals(1, m1.interceptors.length);
|
||||
assertEquals(OriginHandshakeInterceptor.class, m1.interceptors[0].getClass());
|
||||
|
||||
Mapping m2 = mappings.get(1);
|
||||
assertEquals(handler, m2.webSocketHandler);
|
||||
assertEquals("/bar", m2.path);
|
||||
assertNotNull(m2.interceptors);
|
||||
assertEquals(1, m2.interceptors.length);
|
||||
assertEquals(OriginHandshakeInterceptor.class, m2.interceptors[0].getClass());
|
||||
}
|
||||
|
@ -91,6 +96,7 @@ public class WebSocketHandlerRegistrationTests {
|
|||
Mapping mapping = mappings.get(0);
|
||||
assertEquals(handler, mapping.webSocketHandler);
|
||||
assertEquals("/foo", mapping.path);
|
||||
assertNotNull(mapping.interceptors);
|
||||
assertEquals(2, mapping.interceptors.length);
|
||||
assertEquals(interceptor, mapping.interceptors[0]);
|
||||
assertEquals(OriginHandshakeInterceptor.class, mapping.interceptors[1].getClass());
|
||||
|
@ -109,6 +115,7 @@ public class WebSocketHandlerRegistrationTests {
|
|||
Mapping mapping = mappings.get(0);
|
||||
assertEquals(handler, mapping.webSocketHandler);
|
||||
assertEquals("/foo", mapping.path);
|
||||
assertNotNull(mapping.interceptors);
|
||||
assertEquals(2, mapping.interceptors.length);
|
||||
assertEquals(interceptor, mapping.interceptors[0]);
|
||||
assertEquals(OriginHandshakeInterceptor.class, mapping.interceptors[1].getClass());
|
||||
|
@ -127,6 +134,7 @@ public class WebSocketHandlerRegistrationTests {
|
|||
Mapping mapping = mappings.get(0);
|
||||
assertEquals(handler, mapping.webSocketHandler);
|
||||
assertEquals("/foo", mapping.path);
|
||||
assertNotNull(mapping.interceptors);
|
||||
assertEquals(2, mapping.interceptors.length);
|
||||
assertEquals(interceptor, mapping.interceptors[0]);
|
||||
assertEquals(OriginHandshakeInterceptor.class, mapping.interceptors[1].getClass());
|
||||
|
@ -137,8 +145,12 @@ public class WebSocketHandlerRegistrationTests {
|
|||
WebSocketHandler handler = new TextWebSocketHandler();
|
||||
HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor();
|
||||
|
||||
this.registration.addHandler(handler, "/foo").addInterceptors(interceptor)
|
||||
.setAllowedOrigins("http://mydomain1.com").withSockJS();
|
||||
this.registration.addHandler(handler, "/foo")
|
||||
.addInterceptors(interceptor)
|
||||
.setAllowedOrigins("http://mydomain1.com")
|
||||
.withSockJS();
|
||||
|
||||
this.registration.getSockJsServiceRegistration().setTaskScheduler(this.taskScheduler);
|
||||
|
||||
List<Mapping> mappings = this.registration.getMappings();
|
||||
assertEquals(1, mappings.size());
|
||||
|
@ -175,6 +187,7 @@ public class WebSocketHandlerRegistrationTests {
|
|||
HandshakeHandler handshakeHandler = new DefaultHandshakeHandler();
|
||||
|
||||
this.registration.addHandler(handler, "/foo").setHandshakeHandler(handshakeHandler).withSockJS();
|
||||
this.registration.getSockJsServiceRegistration().setTaskScheduler(this.taskScheduler);
|
||||
|
||||
List<Mapping> mappings = this.registration.getMappings();
|
||||
assertEquals(1, mappings.size());
|
||||
|
@ -190,11 +203,7 @@ public class WebSocketHandlerRegistrationTests {
|
|||
}
|
||||
|
||||
|
||||
private static class TestWebSocketHandlerRegistration extends AbstractWebSocketHandlerRegistration<List<Mapping>> {
|
||||
|
||||
public TestWebSocketHandlerRegistration(TaskScheduler sockJsTaskScheduler) {
|
||||
super(sockJsTaskScheduler);
|
||||
}
|
||||
private static class TestWebSocketHandlerRegistration extends AbstractWebSocketHandlerRegistration<List<Mapping>> {
|
||||
|
||||
@Override
|
||||
protected List<Mapping> createMappings() {
|
||||
|
|
Loading…
Reference in New Issue