Add Tomcat WebSocket integration tests

This commit is contained in:
Rossen Stoyanchev 2013-09-01 16:01:15 -04:00
parent e21bbdd933
commit fee3148b0f
12 changed files with 349 additions and 76 deletions

View File

@ -322,10 +322,14 @@ project("spring-messaging") {
optional("org.projectreactor:reactor-core:1.0.0.M2") optional("org.projectreactor:reactor-core:1.0.0.M2")
optional("org.projectreactor:reactor-tcp:1.0.0.M2") optional("org.projectreactor:reactor-tcp:1.0.0.M2")
optional("com.lmax:disruptor:3.1.1") optional("com.lmax:disruptor:3.1.1")
optional("org.eclipse.jetty.websocket:websocket-server:9.0.5.v20130815")
optional("org.eclipse.jetty.websocket:websocket-client:9.0.5.v20130815")
testCompile(project(":spring-test")) testCompile(project(":spring-test"))
testCompile("com.thoughtworks.xstream:xstream:1.4.4") testCompile("com.thoughtworks.xstream:xstream:1.4.4")
testCompile("commons-dbcp:commons-dbcp:1.2.2") testCompile("commons-dbcp:commons-dbcp:1.2.2")
testCompile("javax.inject:javax.inject-tck:1") testCompile("javax.inject:javax.inject-tck:1")
testCompile("javax.servlet:javax.servlet-api:3.1.0")
testCompile("log4j:log4j:1.2.17")
testCompile("org.apache.activemq:activemq-broker:5.8.0") testCompile("org.apache.activemq:activemq-broker:5.8.0")
testCompile("org.apache.activemq:activemq-kahadb-store:5.8.0") { testCompile("org.apache.activemq:activemq-kahadb-store:5.8.0") {
exclude group: "org.springframework", module: "spring-context" exclude group: "org.springframework", module: "spring-context"
@ -334,16 +338,14 @@ project("spring-messaging") {
testCompile("org.eclipse.jetty:jetty-webapp:9.0.5.v20130815") { testCompile("org.eclipse.jetty:jetty-webapp:9.0.5.v20130815") {
exclude group: "org.eclipse.jetty.orbit", module: "javax.servlet" exclude group: "org.eclipse.jetty.orbit", module: "javax.servlet"
} }
optional("org.eclipse.jetty.websocket:websocket-server:9.0.5.v20130815") testCompile("org.apache.tomcat.embed:tomcat-embed-core:8.0-SNAPSHOT")
optional("org.eclipse.jetty.websocket:websocket-client:9.0.5.v20130815") testCompile("org.apache.tomcat.embed:tomcat-embed-logging-juli:8.0-SNAPSHOT")
testCompile("javax.servlet:javax.servlet-api:3.0.1")
testCompile("org.slf4j:slf4j-jcl:${slf4jVersion}") testCompile("org.slf4j:slf4j-jcl:${slf4jVersion}")
testCompile("log4j:log4j:1.2.17")
} }
repositories { repositories {
maven { url "https://repository.apache.org/content/repositories/snapshots" } // tomcat-websocket-* snapshots
maven { url 'http://repo.springsource.org/libs-milestone' } // reactor maven { url 'http://repo.springsource.org/libs-milestone' } // reactor
maven { url 'http://repo.springsource.org/libs-snapshot' } // reactor
} }
} }
@ -528,6 +530,7 @@ project("spring-websocket") {
optional("org.eclipse.jetty.websocket:websocket-client:9.0.5.v20130815") optional("org.eclipse.jetty.websocket:websocket-client:9.0.5.v20130815")
optional("com.fasterxml.jackson.core:jackson-databind:2.2.0") optional("com.fasterxml.jackson.core:jackson-databind:2.2.0")
optional("org.codehaus.jackson:jackson-mapper-asl:1.9.12") optional("org.codehaus.jackson:jackson-mapper-asl:1.9.12")
testCompile("org.apache.tomcat.embed:tomcat-embed-core:8.0-SNAPSHOT")
testCompile("org.slf4j:slf4j-jcl:${slf4jVersion}") testCompile("org.slf4j:slf4j-jcl:${slf4jVersion}")
testCompile("log4j:log4j:1.2.17") testCompile("log4j:log4j:1.2.17")
} }

View File

@ -19,6 +19,8 @@ package org.springframework.messaging.simp;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.junit.After; import org.junit.After;
import org.junit.Before; import org.junit.Before;
import org.junit.runners.Parameterized.Parameter; import org.junit.runners.Parameterized.Parameter;
@ -31,7 +33,9 @@ import org.springframework.web.socket.server.DefaultHandshakeHandler;
import org.springframework.web.socket.server.HandshakeHandler; import org.springframework.web.socket.server.HandshakeHandler;
import org.springframework.web.socket.server.RequestUpgradeStrategy; import org.springframework.web.socket.server.RequestUpgradeStrategy;
import org.springframework.web.socket.server.support.JettyRequestUpgradeStrategy; import org.springframework.web.socket.server.support.JettyRequestUpgradeStrategy;
import org.springframework.web.socket.server.support.TomcatRequestUpgradeStrategy;
import reactor.util.Assert;
/** /**
@ -41,10 +45,13 @@ import org.springframework.web.socket.server.support.JettyRequestUpgradeStrategy
*/ */
public abstract class AbstractWebSocketIntegrationTests { public abstract class AbstractWebSocketIntegrationTests {
protected Log logger = LogFactory.getLog(getClass());
private static Map<Class<?>, Class<?>> upgradeStrategyConfigTypes = new HashMap<Class<?>, Class<?>>(); private static Map<Class<?>, Class<?>> upgradeStrategyConfigTypes = new HashMap<Class<?>, Class<?>>();
static { static {
upgradeStrategyConfigTypes.put(JettyTestServer.class, JettyUpgradeStrategyConfig.class); upgradeStrategyConfigTypes.put(JettyTestServer.class, JettyUpgradeStrategyConfig.class);
upgradeStrategyConfigTypes.put(TomcatTestServer.class, TomcatUpgradeStrategyConfig.class);
} }
@Parameter(0) @Parameter(0)
@ -59,15 +66,19 @@ public abstract class AbstractWebSocketIntegrationTests {
@Before @Before
public void setup() throws Exception { public void setup() throws Exception {
Class<?> upgradeStrategyConfigClass = upgradeStrategyConfigTypes.get(this.server.getClass());
Assert.notNull(upgradeStrategyConfigClass, "No UpgradeStrategyConfig class");
this.wac = new AnnotationConfigWebApplicationContext(); this.wac = new AnnotationConfigWebApplicationContext();
this.wac.register(getAnnotatedConfigClasses()); this.wac.register(getAnnotatedConfigClasses());
this.wac.register(upgradeStrategyConfigTypes.get(this.server.getClass())); this.wac.register(upgradeStrategyConfigClass);
this.wac.refresh();
if (this.webSocketClient instanceof Lifecycle) { if (this.webSocketClient instanceof Lifecycle) {
((Lifecycle) this.webSocketClient).start(); ((Lifecycle) this.webSocketClient).start();
} }
this.server.init(this.wac); this.server.deployConfig(this.wac);
this.server.start(); this.server.start();
} }
@ -80,9 +91,23 @@ public abstract class AbstractWebSocketIntegrationTests {
((Lifecycle) this.webSocketClient).stop(); ((Lifecycle) this.webSocketClient).stop();
} }
} }
finally { catch (Throwable t) {
logger.error("Failed to stop WebSocket client", t);
}
try {
this.server.undeployConfig();
}
catch (Throwable t) {
logger.error("Failed to undeploy application config", t);
}
try {
this.server.stop(); this.server.stop();
} }
catch (Throwable t) {
logger.error("Failed to stop server", t);
}
} }
protected String getWsBaseUrl() { protected String getWsBaseUrl() {
@ -110,4 +135,13 @@ public abstract class AbstractWebSocketIntegrationTests {
} }
} }
@Configuration
static class TomcatUpgradeStrategyConfig extends AbstractRequestUpgradeStrategyConfig {
@Bean
public RequestUpgradeStrategy requestUpgradeStrategy() {
return new TomcatRequestUpgradeStrategy();
}
}
} }

View File

@ -47,10 +47,16 @@ public class JettyTestServer implements TestServer {
} }
@Override @Override
public void init(WebApplicationContext cxt) { public void deployConfig(WebApplicationContext cxt) {
ServletContextHandler handler = new ServletContextHandler(); ServletContextHandler contextHandler = new ServletContextHandler();
handler.addServlet(new ServletHolder(new DispatcherServlet(cxt)), "/"); ServletHolder servletHolder = new ServletHolder(new DispatcherServlet(cxt));
this.jettyServer.setHandler(handler); contextHandler.addServlet(servletHolder, "/");
this.jettyServer.setHandler(contextHandler);
}
@Override
public void undeployConfig() {
// Stopping jetty will undeploy the servlet
} }
@Override @Override

View File

@ -27,7 +27,9 @@ public interface TestServer {
int getPort(); int getPort();
void init(WebApplicationContext cxt); void deployConfig(WebApplicationContext cxt);
void undeployConfig();
void start() throws Exception; void start() throws Exception;

View File

@ -0,0 +1,112 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.simp;
import java.io.File;
import java.io.IOException;
import org.apache.catalina.Context;
import org.apache.catalina.connector.Connector;
import org.apache.catalina.startup.Tomcat;
import org.apache.coyote.http11.Http11NioProtocol;
import org.apache.tomcat.util.descriptor.web.ApplicationListener;
import org.apache.tomcat.websocket.server.WsListener;
import org.springframework.core.NestedRuntimeException;
import org.springframework.util.SocketUtils;
import org.springframework.web.context.WebApplicationContext;
import org.springframework.web.servlet.DispatcherServlet;
/**
* Tomcat based {@link TestServer}.
*
* @author Rossen Stoyanchev
*/
public class TomcatTestServer implements TestServer {
private static final ApplicationListener WS_APPLICATION_LISTENER =
new ApplicationListener(WsListener.class.getName(), false);
private final Tomcat tomcatServer;
private final int port;
private Context context;
public TomcatTestServer() {
this.port = SocketUtils.findAvailableTcpPort();
Connector connector = new Connector(Http11NioProtocol.class.getName());
connector.setPort(this.port);
File baseDir = createTempDir("tomcat");
String baseDirPath = baseDir.getAbsolutePath();
this.tomcatServer = new Tomcat();
this.tomcatServer.setBaseDir(baseDirPath);
this.tomcatServer.setPort(this.port);
this.tomcatServer.getService().addConnector(connector);
this.tomcatServer.setConnector(connector);
}
private File createTempDir(String prefix) {
try {
File tempFolder = File.createTempFile(prefix + ".", "." + getPort());
tempFolder.delete();
tempFolder.mkdir();
tempFolder.deleteOnExit();
return tempFolder;
}
catch (IOException ex) {
throw new NestedRuntimeException("Unable to create temp directory", ex) {};
}
}
@Override
public int getPort() {
return this.port;
}
@Override
public void deployConfig(WebApplicationContext wac) {
this.context = this.tomcatServer.addContext("", System.getProperty("java.io.tmpdir"));
this.context.addApplicationListener(WS_APPLICATION_LISTENER);
Tomcat.addServlet(context, "dispatcherServlet", new DispatcherServlet(wac));
this.context.addServletMapping("/", "dispatcherServlet");
}
@Override
public void undeployConfig() {
if (this.context != null) {
this.context.removeServletMapping("/");
this.tomcatServer.getHost().removeChild(this.context);
}
}
@Override
public void start() throws Exception {
this.tomcatServer.start();
}
@Override
public void stop() throws Exception {
this.tomcatServer.stop();
}
}

View File

@ -31,6 +31,7 @@ import org.springframework.messaging.SubscribableChannel;
import org.springframework.messaging.handler.annotation.MessageMapping; import org.springframework.messaging.handler.annotation.MessageMapping;
import org.springframework.messaging.simp.AbstractWebSocketIntegrationTests; import org.springframework.messaging.simp.AbstractWebSocketIntegrationTests;
import org.springframework.messaging.simp.JettyTestServer; import org.springframework.messaging.simp.JettyTestServer;
import org.springframework.messaging.simp.TomcatTestServer;
import org.springframework.messaging.simp.stomp.StompCommand; import org.springframework.messaging.simp.stomp.StompCommand;
import org.springframework.messaging.simp.stomp.StompTextMessageBuilder; import org.springframework.messaging.simp.stomp.StompTextMessageBuilder;
import org.springframework.messaging.support.channel.ExecutorSubscribableChannel; import org.springframework.messaging.support.channel.ExecutorSubscribableChannel;
@ -39,6 +40,7 @@ import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.adapter.TextWebSocketHandlerAdapter; import org.springframework.web.socket.adapter.TextWebSocketHandlerAdapter;
import org.springframework.web.socket.client.endpoint.StandardWebSocketClient;
import org.springframework.web.socket.client.jetty.JettyWebSocketClient; import org.springframework.web.socket.client.jetty.JettyWebSocketClient;
import org.springframework.web.socket.server.HandshakeHandler; import org.springframework.web.socket.server.HandshakeHandler;
import org.springframework.web.socket.server.config.WebSocketConfigurationSupport; import org.springframework.web.socket.server.config.WebSocketConfigurationSupport;
@ -58,7 +60,9 @@ public class WebSocketMessageBrokerConfigurationTests extends AbstractWebSocketI
@Parameters @Parameters
public static Iterable<Object[]> arguments() { public static Iterable<Object[]> arguments() {
return Arrays.asList(new Object[][] { return Arrays.asList(new Object[][] {
{ new JettyTestServer(), new JettyWebSocketClient()} }); {new JettyTestServer(), new JettyWebSocketClient()},
{new TomcatTestServer(), new StandardWebSocketClient()}
});
}; };
@ -82,12 +86,14 @@ public class WebSocketMessageBrokerConfigurationTests extends AbstractWebSocketI
TestController testController = this.wac.getBean(TestController.class); TestController testController = this.wac.getBean(TestController.class);
this.webSocketClient.doHandshake(clientHandler, getWsBaseUrl() + "/ws"); WebSocketSession session = this.webSocketClient.doHandshake(clientHandler, getWsBaseUrl() + "/ws");
assertTrue(testController.latch.await(2, TimeUnit.SECONDS)); assertTrue(testController.latch.await(2, TimeUnit.SECONDS));
session.close();
testController.latch = new CountDownLatch(1); testController.latch = new CountDownLatch(1);
this.webSocketClient.doHandshake(clientHandler, getWsBaseUrl() + "/sockjs/websocket"); session = this.webSocketClient.doHandshake(clientHandler, getWsBaseUrl() + "/sockjs/websocket");
assertTrue(testController.latch.await(2, TimeUnit.SECONDS)); assertTrue(testController.latch.await(2, TimeUnit.SECONDS));
session.close();
} }

View File

@ -17,7 +17,6 @@
package org.springframework.web.socket.server.support; package org.springframework.web.socket.server.support;
import java.io.IOException; import java.io.IOException;
import java.lang.reflect.Method;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.Map; import java.util.Map;
@ -27,17 +26,13 @@ import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponse;
import javax.websocket.Endpoint; import javax.websocket.Endpoint;
import javax.websocket.server.ServerEndpointConfig;
import org.apache.tomcat.websocket.server.WsHandshakeRequest;
import org.apache.tomcat.websocket.server.WsHttpUpgradeHandler;
import org.apache.tomcat.websocket.server.WsServerContainer; import org.apache.tomcat.websocket.server.WsServerContainer;
import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse; import org.springframework.http.server.ServerHttpResponse;
import org.springframework.http.server.ServletServerHttpRequest; import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.http.server.ServletServerHttpResponse; import org.springframework.http.server.ServletServerHttpResponse;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.util.ReflectionUtils;
import org.springframework.web.socket.server.HandshakeFailureException; import org.springframework.web.socket.server.HandshakeFailureException;
import org.springframework.web.socket.server.endpoint.ServerEndpointRegistration; import org.springframework.web.socket.server.endpoint.ServerEndpointRegistration;
import org.springframework.web.socket.server.endpoint.ServletServerContainerFactoryBean; import org.springframework.web.socket.server.endpoint.ServletServerContainerFactoryBean;
@ -71,17 +66,6 @@ public class TomcatRequestUpgradeStrategy extends AbstractStandardUpgradeStrateg
Assert.isTrue(response instanceof ServletServerHttpResponse); Assert.isTrue(response instanceof ServletServerHttpResponse);
HttpServletResponse servletResponse = ((ServletServerHttpResponse) response).getServletResponse(); HttpServletResponse servletResponse = ((ServletServerHttpResponse) response).getServletResponse();
if (hasDoUpgrade) {
doUpgrade(servletRequest, servletResponse, acceptedProtocol, endpoint);
}
else {
upgradeTomcat80RC1(servletRequest, acceptedProtocol, endpoint);
}
}
private void doUpgrade(HttpServletRequest servletRequest, HttpServletResponse servletResponse,
String acceptedProtocol, Endpoint endpoint) {
StringBuffer requestUrl = servletRequest.getRequestURL(); StringBuffer requestUrl = servletRequest.getRequestURL();
String path = servletRequest.getRequestURI(); // shouldn't matter String path = servletRequest.getRequestURI(); // shouldn't matter
Map<String, String> pathParams = Collections.<String, String> emptyMap(); Map<String, String> pathParams = Collections.<String, String> emptyMap();
@ -108,36 +92,4 @@ public class TomcatRequestUpgradeStrategy extends AbstractStandardUpgradeStrateg
return (WsServerContainer) servletContext.getAttribute(attribute); return (WsServerContainer) servletContext.getAttribute(attribute);
} }
// FIXME: Remove this after RC2 is out
private void upgradeTomcat80RC1(HttpServletRequest request, String protocol, Endpoint endpoint) {
WsHttpUpgradeHandler upgradeHandler;
try {
upgradeHandler = request.upgrade(WsHttpUpgradeHandler.class);
}
catch (Exception e) {
throw new HandshakeFailureException("Unable to create UpgardeHandler", e);
}
WsHandshakeRequest webSocketRequest = new WsHandshakeRequest(request);
try {
Method method = ReflectionUtils.findMethod(WsHandshakeRequest.class, "finished");
ReflectionUtils.makeAccessible(method);
method.invoke(webSocketRequest);
}
catch (Exception ex) {
throw new HandshakeFailureException("Failed to upgrade HttpServletRequest", ex);
}
ServerEndpointConfig endpointConfig = new ServerEndpointRegistration("/shouldntmatter", endpoint);
upgradeHandler.preInit(endpoint, endpointConfig, getContainer(request), webSocketRequest,
protocol, Collections.<String, String> emptyMap(), request.isSecure());
}
private static boolean hasDoUpgrade = (ReflectionUtils.findMethod(WsServerContainer.class,
"doUpgrade", HttpServletRequest.class, HttpServletResponse.class,
ServerEndpointConfig.class, Map.class) != null);
} }

View File

@ -1,5 +1,4 @@
/* /*
* Copyright 2002-2013 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -19,6 +18,8 @@ package org.springframework.web.socket;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.junit.After; import org.junit.After;
import org.junit.Before; import org.junit.Before;
import org.junit.runners.Parameterized.Parameter; import org.junit.runners.Parameterized.Parameter;
@ -31,7 +32,7 @@ import org.springframework.web.socket.server.DefaultHandshakeHandler;
import org.springframework.web.socket.server.HandshakeHandler; import org.springframework.web.socket.server.HandshakeHandler;
import org.springframework.web.socket.server.RequestUpgradeStrategy; import org.springframework.web.socket.server.RequestUpgradeStrategy;
import org.springframework.web.socket.server.support.JettyRequestUpgradeStrategy; import org.springframework.web.socket.server.support.JettyRequestUpgradeStrategy;
import org.springframework.web.socket.server.support.TomcatRequestUpgradeStrategy;
/** /**
@ -41,10 +42,13 @@ import org.springframework.web.socket.server.support.JettyRequestUpgradeStrategy
*/ */
public abstract class AbstractWebSocketIntegrationTests { public abstract class AbstractWebSocketIntegrationTests {
protected Log logger = LogFactory.getLog(getClass());
private static Map<Class<?>, Class<?>> upgradeStrategyConfigTypes = new HashMap<Class<?>, Class<?>>(); private static Map<Class<?>, Class<?>> upgradeStrategyConfigTypes = new HashMap<Class<?>, Class<?>>();
static { static {
upgradeStrategyConfigTypes.put(JettyTestServer.class, JettyUpgradeStrategyConfig.class); upgradeStrategyConfigTypes.put(JettyTestServer.class, JettyUpgradeStrategyConfig.class);
upgradeStrategyConfigTypes.put(TomcatTestServer.class, TomcatUpgradeStrategyConfig.class);
} }
@Parameter(0) @Parameter(0)
@ -62,12 +66,13 @@ public abstract class AbstractWebSocketIntegrationTests {
this.wac = new AnnotationConfigWebApplicationContext(); this.wac = new AnnotationConfigWebApplicationContext();
this.wac.register(getAnnotatedConfigClasses()); this.wac.register(getAnnotatedConfigClasses());
this.wac.register(upgradeStrategyConfigTypes.get(this.server.getClass())); this.wac.register(upgradeStrategyConfigTypes.get(this.server.getClass()));
this.wac.refresh();
if (this.webSocketClient instanceof Lifecycle) { if (this.webSocketClient instanceof Lifecycle) {
((Lifecycle) this.webSocketClient).start(); ((Lifecycle) this.webSocketClient).start();
} }
this.server.init(this.wac); this.server.deployConfig(this.wac);
this.server.start(); this.server.start();
} }
@ -80,9 +85,23 @@ public abstract class AbstractWebSocketIntegrationTests {
((Lifecycle) this.webSocketClient).stop(); ((Lifecycle) this.webSocketClient).stop();
} }
} }
finally { catch (Throwable t) {
logger.error("Failed to stop WebSocket client", t);
}
try {
this.server.undeployConfig();
}
catch (Throwable t) {
logger.error("Failed to undeploy application config", t);
}
try {
this.server.stop(); this.server.stop();
} }
catch (Throwable t) {
logger.error("Failed to stop server", t);
}
} }
protected String getWsBaseUrl() { protected String getWsBaseUrl() {
@ -110,4 +129,13 @@ public abstract class AbstractWebSocketIntegrationTests {
} }
} }
@Configuration
static class TomcatUpgradeStrategyConfig extends AbstractRequestUpgradeStrategyConfig {
@Bean
public RequestUpgradeStrategy requestUpgradeStrategy() {
return new TomcatRequestUpgradeStrategy();
}
}
} }

View File

@ -47,10 +47,16 @@ public class JettyTestServer implements TestServer {
} }
@Override @Override
public void init(WebApplicationContext cxt) { public void deployConfig(WebApplicationContext cxt) {
ServletContextHandler handler = new ServletContextHandler(); ServletContextHandler contextHandler = new ServletContextHandler();
handler.addServlet(new ServletHolder(new DispatcherServlet(cxt)), "/"); ServletHolder servletHolder = new ServletHolder(new DispatcherServlet(cxt));
this.jettyServer.setHandler(handler); contextHandler.addServlet(servletHolder, "/");
this.jettyServer.setHandler(contextHandler);
}
@Override
public void undeployConfig() {
// Stopping jetty will undeploy the servlet
} }
@Override @Override

View File

@ -27,7 +27,9 @@ public interface TestServer {
int getPort(); int getPort();
void init(WebApplicationContext cxt); void deployConfig(WebApplicationContext cxt);
void undeployConfig();
void start() throws Exception; void start() throws Exception;

View File

@ -0,0 +1,112 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket;
import java.io.File;
import java.io.IOException;
import org.apache.catalina.Context;
import org.apache.catalina.connector.Connector;
import org.apache.catalina.startup.Tomcat;
import org.apache.coyote.http11.Http11NioProtocol;
import org.apache.tomcat.util.descriptor.web.ApplicationListener;
import org.apache.tomcat.websocket.server.WsListener;
import org.springframework.core.NestedRuntimeException;
import org.springframework.util.Assert;
import org.springframework.util.SocketUtils;
import org.springframework.web.context.WebApplicationContext;
import org.springframework.web.servlet.DispatcherServlet;
/**
* Tomcat based {@link TestServer}.
*
* @author Rossen Stoyanchev
*/
public class TomcatTestServer implements TestServer {
private static final ApplicationListener WS_APPLICATION_LISTENER =
new ApplicationListener(WsListener.class.getName(), false);
private final Tomcat tomcatServer;
private final int port;
private Context context;
public TomcatTestServer() {
this.port = SocketUtils.findAvailableTcpPort();
Connector connector = new Connector(Http11NioProtocol.class.getName());
connector.setPort(this.port);
File baseDir = createTempDir("tomcat");
String baseDirPath = baseDir.getAbsolutePath();
this.tomcatServer = new Tomcat();
this.tomcatServer.setBaseDir(baseDirPath);
this.tomcatServer.setPort(this.port);
this.tomcatServer.getService().addConnector(connector);
this.tomcatServer.setConnector(connector);
}
private File createTempDir(String prefix) {
try {
File tempFolder = File.createTempFile(prefix + ".", "." + getPort());
tempFolder.delete();
tempFolder.mkdir();
tempFolder.deleteOnExit();
return tempFolder;
}
catch (IOException ex) {
throw new NestedRuntimeException("Unable to create temp directory", ex) {};
}
}
@Override
public int getPort() {
return this.port;
}
@Override
public void deployConfig(WebApplicationContext wac) {
this.context = this.tomcatServer.addContext("", System.getProperty("java.io.tmpdir"));
this.context.addApplicationListener(WS_APPLICATION_LISTENER);
Tomcat.addServlet(context, "dispatcherServlet", new DispatcherServlet(wac));
this.context.addServletMapping("/", "dispatcherServlet");
}
@Override
public void undeployConfig() {
Assert.notNull(this.context, "deployConfig/undeployConfig must be invoked in pairs");
this.context.removeServletMapping("/");
this.tomcatServer.getHost().removeChild(this.context);
}
@Override
public void start() throws Exception {
this.tomcatServer.start();
}
@Override
public void stop() throws Exception {
this.tomcatServer.stop();
}
}

View File

@ -29,8 +29,10 @@ import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Configuration;
import org.springframework.web.socket.AbstractWebSocketIntegrationTests; import org.springframework.web.socket.AbstractWebSocketIntegrationTests;
import org.springframework.web.socket.JettyTestServer; import org.springframework.web.socket.JettyTestServer;
import org.springframework.web.socket.TomcatTestServer;
import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.adapter.WebSocketHandlerAdapter; import org.springframework.web.socket.adapter.WebSocketHandlerAdapter;
import org.springframework.web.socket.client.endpoint.StandardWebSocketClient;
import org.springframework.web.socket.client.jetty.JettyWebSocketClient; import org.springframework.web.socket.client.jetty.JettyWebSocketClient;
import org.springframework.web.socket.server.HandshakeHandler; import org.springframework.web.socket.server.HandshakeHandler;
import org.springframework.web.socket.sockjs.transport.handler.WebSocketTransportHandler; import org.springframework.web.socket.sockjs.transport.handler.WebSocketTransportHandler;
@ -49,7 +51,9 @@ public class WebSocketConfigurationTests extends AbstractWebSocketIntegrationTes
@Parameters @Parameters
public static Iterable<Object[]> arguments() { public static Iterable<Object[]> arguments() {
return Arrays.asList(new Object[][] { return Arrays.asList(new Object[][] {
{ new JettyTestServer(), new JettyWebSocketClient()} }); {new JettyTestServer(), new JettyWebSocketClient()},
{new TomcatTestServer(), new StandardWebSocketClient()}
});
}; };
@ -61,19 +65,25 @@ public class WebSocketConfigurationTests extends AbstractWebSocketIntegrationTes
@Test @Test
public void registerWebSocketHandler() throws Exception { public void registerWebSocketHandler() throws Exception {
this.webSocketClient.doHandshake(new WebSocketHandlerAdapter(), getWsBaseUrl() + "/ws"); WebSocketSession session =
this.webSocketClient.doHandshake(new WebSocketHandlerAdapter(), getWsBaseUrl() + "/ws");
TestWebSocketHandler serverHandler = this.wac.getBean(TestWebSocketHandler.class); TestWebSocketHandler serverHandler = this.wac.getBean(TestWebSocketHandler.class);
assertTrue(serverHandler.latch.await(2, TimeUnit.SECONDS)); assertTrue(serverHandler.latch.await(2, TimeUnit.SECONDS));
session.close();
} }
@Test @Test
public void registerWebSocketHandlerWithSockJS() throws Exception { public void registerWebSocketHandlerWithSockJS() throws Exception {
this.webSocketClient.doHandshake(new WebSocketHandlerAdapter(), getWsBaseUrl() + "/sockjs/websocket"); WebSocketSession session =
this.webSocketClient.doHandshake(new WebSocketHandlerAdapter(), getWsBaseUrl() + "/sockjs/websocket");
TestWebSocketHandler serverHandler = this.wac.getBean(TestWebSocketHandler.class); TestWebSocketHandler serverHandler = this.wac.getBean(TestWebSocketHandler.class);
assertTrue(serverHandler.latch.await(2, TimeUnit.SECONDS)); assertTrue(serverHandler.latch.await(2, TimeUnit.SECONDS));
session.close();
} }