diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/server/AbstractSockJsService.java b/spring-websocket/src/main/java/org/springframework/sockjs/server/AbstractSockJsService.java index e6823c0d4c..16740af556 100644 --- a/spring-websocket/src/main/java/org/springframework/sockjs/server/AbstractSockJsService.java +++ b/spring-websocket/src/main/java/org/springframework/sockjs/server/AbstractSockJsService.java @@ -25,6 +25,8 @@ import java.util.Random; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.springframework.beans.factory.DisposableBean; +import org.springframework.beans.factory.InitializingBean; import org.springframework.http.HttpMethod; import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; @@ -47,7 +49,8 @@ import org.springframework.web.util.UriUtils; * @author Rossen Stoyanchev * @since 4.0 */ -public abstract class AbstractSockJsService implements SockJsService, SockJsConfiguration { +public abstract class AbstractSockJsService + implements SockJsService, SockJsConfiguration, InitializingBean, DisposableBean { protected final Log logger = LogFactory.getLog(getClass()); @@ -64,12 +67,12 @@ public abstract class AbstractSockJsService implements SockJsService, SockJsConf private long heartbeatTime = 25 * 1000; - private TaskScheduler heartbeatScheduler; - private long disconnectDelay = 5 * 1000; private boolean webSocketsEnabled = true; + private final TaskSchedulerHolder heartbeatSchedulerHolder; + /** * Class constructor... @@ -81,14 +84,14 @@ public abstract class AbstractSockJsService implements SockJsService, SockJsConf public AbstractSockJsService(String prefix) { Assert.hasText(prefix, "prefix is required"); this.prefix = prefix; - this.heartbeatScheduler = createScheduler("SockJs-heartbeat-"); + this.heartbeatSchedulerHolder = new TaskSchedulerHolder("SockJs-heartbeat-"); } - protected TaskScheduler createScheduler(String threadNamePrefix) { - ThreadPoolTaskScheduler scheduler = new ThreadPoolTaskScheduler(); - scheduler.setThreadNamePrefix(threadNamePrefix); - scheduler.afterPropertiesSet(); - return scheduler; + public AbstractSockJsService(String prefix, TaskScheduler heartbeatScheduler) { + Assert.hasText(prefix, "prefix is required"); + Assert.notNull(heartbeatScheduler, "heartbeatScheduler is required"); + this.prefix = prefix; + this.heartbeatSchedulerHolder = new TaskSchedulerHolder(heartbeatScheduler); } /** @@ -163,12 +166,7 @@ public abstract class AbstractSockJsService implements SockJsService, SockJsConf } public TaskScheduler getHeartbeatScheduler() { - return this.heartbeatScheduler; - } - - public void setHeartbeatScheduler(TaskScheduler heartbeatScheduler) { - Assert.notNull(heartbeatScheduler, "heartbeatScheduler is required"); - this.heartbeatScheduler = heartbeatScheduler; + return this.heartbeatSchedulerHolder.getScheduler(); } /** @@ -207,6 +205,15 @@ public abstract class AbstractSockJsService implements SockJsService, SockJsConf return this.webSocketsEnabled; } + @Override + public void afterPropertiesSet() throws Exception { + this.heartbeatSchedulerHolder.initialize(); + } + + @Override + public void destroy() throws Exception { + this.heartbeatSchedulerHolder.destroy(); + } /** * TODO @@ -426,4 +433,46 @@ public abstract class AbstractSockJsService implements SockJsService, SockJsConf } }; + + /** + * Holds an externally provided or an internally managed TaskScheduler. Provides + * initialize and destroy methods have no effect if the scheduler is externally + * managed. + */ + protected static class TaskSchedulerHolder { + + private final TaskScheduler taskScheduler; + + private final boolean isDefaultTaskScheduler; + + public TaskSchedulerHolder(TaskScheduler taskScheduler) { + Assert.notNull(taskScheduler, "taskScheduler is required"); + this.taskScheduler = taskScheduler; + this.isDefaultTaskScheduler = false; + } + + public TaskSchedulerHolder(String threadNamePrefix) { + ThreadPoolTaskScheduler scheduler = new ThreadPoolTaskScheduler(); + scheduler.setThreadNamePrefix(threadNamePrefix); + this.taskScheduler = scheduler; + this.isDefaultTaskScheduler = true; + } + + public TaskScheduler getScheduler() { + return this.taskScheduler; + } + + public void initialize() { + if (this.isDefaultTaskScheduler) { + ((ThreadPoolTaskScheduler) this.taskScheduler).afterPropertiesSet(); + } + } + + public void destroy() { + if (this.isDefaultTaskScheduler) { + ((ThreadPoolTaskScheduler) this.taskScheduler).shutdown(); + } + } + } + } diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/server/support/DefaultSockJsService.java b/spring-websocket/src/main/java/org/springframework/sockjs/server/support/DefaultSockJsService.java index af0789aa02..b0f311ec64 100644 --- a/spring-websocket/src/main/java/org/springframework/sockjs/server/support/DefaultSockJsService.java +++ b/spring-websocket/src/main/java/org/springframework/sockjs/server/support/DefaultSockJsService.java @@ -29,7 +29,6 @@ import org.springframework.http.HttpStatus; import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; import org.springframework.scheduling.TaskScheduler; -import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler; import org.springframework.sockjs.SockJsHandler; import org.springframework.sockjs.SockJsSessionFactory; import org.springframework.sockjs.SockJsSessionSupport; @@ -64,7 +63,7 @@ public class DefaultSockJsService extends AbstractSockJsService implements Initi private final Map transportHandlerOverrides = new HashMap(); - private TaskScheduler sessionTimeoutScheduler; + private TaskSchedulerHolder sessionTimeoutSchedulerHolder; private final Map sessions = new ConcurrentHashMap(); @@ -73,21 +72,13 @@ public class DefaultSockJsService extends AbstractSockJsService implements Initi public DefaultSockJsService(String prefix) { super(prefix); - this.sessionTimeoutScheduler = createScheduler("SockJs-sessionTimeout-"); + this.sessionTimeoutSchedulerHolder = new TaskSchedulerHolder("SockJs-sessionTimeout-"); } - /** - * A scheduler instance to use for scheduling periodic expires session cleanup. - *

- * By default a {@link ThreadPoolTaskScheduler} with default settings is used. - */ - public TaskScheduler getSessionTimeoutScheduler() { - return this.sessionTimeoutScheduler; - } - - public void setSessionTimeoutScheduler(TaskScheduler sessionTimeoutScheduler) { + public DefaultSockJsService(String prefix, TaskScheduler heartbeatScheduler, TaskScheduler sessionTimeoutScheduler) { + super(prefix, heartbeatScheduler); Assert.notNull(sessionTimeoutScheduler, "sessionTimeoutScheduler is required"); - this.sessionTimeoutScheduler = sessionTimeoutScheduler; + this.sessionTimeoutSchedulerHolder = new TaskSchedulerHolder(sessionTimeoutScheduler); } public void setTransportHandlers(TransportHandler... handlers) { @@ -124,6 +115,8 @@ public class DefaultSockJsService extends AbstractSockJsService implements Initi @Override public void afterPropertiesSet() throws Exception { + super.afterPropertiesSet(); + if (this.transportHandlers.isEmpty()) { if (isWebSocketEnabled() && (this.transportHandlerOverrides.get(TransportType.WEBSOCKET) == null)) { this.transportHandlers.put(TransportType.WEBSOCKET, @@ -146,7 +139,9 @@ public class DefaultSockJsService extends AbstractSockJsService implements Initi configureTransportHandlers(); - this.sessionTimeoutScheduler.scheduleAtFixedRate(new Runnable() { + this.sessionTimeoutSchedulerHolder.initialize(); + + this.sessionTimeoutSchedulerHolder.getScheduler().scheduleAtFixedRate(new Runnable() { public void run() { try { int count = sessions.size(); @@ -173,6 +168,11 @@ public class DefaultSockJsService extends AbstractSockJsService implements Initi }, getDisconnectDelay()); } + @Override + public void destroy() throws Exception { + super.destroy(); + this.sessionTimeoutSchedulerHolder.destroy(); + } private void configureTransportHandlers() { for (TransportHandler h : this.transportHandlers.values()) { diff --git a/spring-websocket/src/main/java/org/springframework/websocket/server/support/TomcatRequestUpgradeStrategy.java b/spring-websocket/src/main/java/org/springframework/websocket/server/support/TomcatRequestUpgradeStrategy.java index ac74c9e5f1..5b2f3a2575 100644 --- a/spring-websocket/src/main/java/org/springframework/websocket/server/support/TomcatRequestUpgradeStrategy.java +++ b/spring-websocket/src/main/java/org/springframework/websocket/server/support/TomcatRequestUpgradeStrategy.java @@ -30,7 +30,6 @@ import org.apache.tomcat.websocket.server.WsServerContainer; import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; import org.springframework.http.server.ServletServerHttpRequest; -import org.springframework.sockjs.server.NestedSockJsRuntimeException; import org.springframework.util.Assert; import org.springframework.util.ReflectionUtils; import org.springframework.websocket.server.endpoint.EndpointRegistration; @@ -66,7 +65,7 @@ public class TomcatRequestUpgradeStrategy extends AbstractEndpointUpgradeStrateg method.invoke(webSocketRequest); } catch (Exception ex) { - throw new NestedSockJsRuntimeException("Failed to upgrade HttpServletRequest", ex); + throw new IllegalStateException("Failed to upgrade HttpServletRequest", ex); } // TODO: use ServletContext attribute when Tomcat is updated