Add WebSocket integration tests w/ Java configuration

Issue: SPR-10835
This commit is contained in:
Rossen Stoyanchev 2013-08-28 21:08:17 -04:00
parent 744e1ed203
commit f0dda0e38b
20 changed files with 686 additions and 295 deletions

View File

@ -334,7 +334,8 @@ project("spring-messaging") {
testCompile("org.eclipse.jetty:jetty-webapp:9.0.5.v20130815") { testCompile("org.eclipse.jetty:jetty-webapp:9.0.5.v20130815") {
exclude group: "org.eclipse.jetty.orbit", module: "javax.servlet" exclude group: "org.eclipse.jetty.orbit", module: "javax.servlet"
} }
testCompile("org.eclipse.jetty.websocket:websocket-server:9.0.5.v20130815") optional("org.eclipse.jetty.websocket:websocket-server:9.0.5.v20130815")
optional("org.eclipse.jetty.websocket:websocket-client:9.0.5.v20130815")
testCompile("javax.servlet:javax.servlet-api:3.0.1") testCompile("javax.servlet:javax.servlet-api:3.0.1")
testCompile("org.slf4j:slf4j-jcl:${slf4jVersion}") testCompile("org.slf4j:slf4j-jcl:${slf4jVersion}")
testCompile("log4j:log4j:1.2.17") testCompile("log4j:log4j:1.2.17")
@ -527,6 +528,8 @@ project("spring-websocket") {
optional("org.eclipse.jetty.websocket:websocket-client:9.0.5.v20130815") optional("org.eclipse.jetty.websocket:websocket-client:9.0.5.v20130815")
optional("com.fasterxml.jackson.core:jackson-databind:2.2.0") optional("com.fasterxml.jackson.core:jackson-databind:2.2.0")
optional("org.codehaus.jackson:jackson-mapper-asl:1.9.12") optional("org.codehaus.jackson:jackson-mapper-asl:1.9.12")
testCompile("org.slf4j:slf4j-jcl:${slf4jVersion}")
testCompile("log4j:log4j:1.2.17")
} }
repositories { repositories {

View File

@ -72,6 +72,7 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler, MessageHan
this.outputChannel = outputChannel; this.outputChannel = outputChannel;
} }
/** /**
* Configure one or more handlers to use depending on the sub-protocol requested by * Configure one or more handlers to use depending on the sub-protocol requested by
* the client in the WebSocket handshake request. * the client in the WebSocket handshake request.
@ -130,6 +131,12 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler, MessageHan
return this.defaultProtocolHandler; return this.defaultProtocolHandler;
} }
/**
* Return all supported protocols.
*/
public Set<String> getSupportedProtocols() {
return this.protocolHandlers.keySet();
}
@Override @Override
public void afterConnectionEstablished(WebSocketSession session) throws Exception { public void afterConnectionEstablished(WebSocketSession session) throws Exception {

View File

@ -19,11 +19,13 @@ package org.springframework.messaging.simp.config;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collection; import java.util.Collection;
import java.util.List; import java.util.List;
import java.util.Set;
import org.springframework.messaging.handler.websocket.SubProtocolWebSocketHandler; import org.springframework.messaging.handler.websocket.SubProtocolWebSocketHandler;
import org.springframework.scheduling.TaskScheduler; import org.springframework.scheduling.TaskScheduler;
import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap; import org.springframework.util.MultiValueMap;
import org.springframework.util.ObjectUtils;
import org.springframework.web.HttpRequestHandler; import org.springframework.web.HttpRequestHandler;
import org.springframework.web.socket.server.DefaultHandshakeHandler; import org.springframework.web.socket.server.DefaultHandshakeHandler;
import org.springframework.web.socket.server.HandshakeHandler; import org.springframework.web.socket.server.HandshakeHandler;
@ -46,53 +48,43 @@ public class StompEndpointRegistration {
private final SubProtocolWebSocketHandler wsHandler; private final SubProtocolWebSocketHandler wsHandler;
private HandshakeHandler handshakeHandler;
private StompSockJsServiceRegistration sockJsServiceRegistration; private StompSockJsServiceRegistration sockJsServiceRegistration;
private TaskScheduler defaultTaskScheduler; private final TaskScheduler defaultSockJsTaskScheduler;
public StompEndpointRegistration(Collection<String> paths, SubProtocolWebSocketHandler webSocketHandler) { public StompEndpointRegistration(Collection<String> paths, SubProtocolWebSocketHandler webSocketHandler,
TaskScheduler defaultSockJsTaskScheduler) {
this.paths = new ArrayList<String>(paths); this.paths = new ArrayList<String>(paths);
this.wsHandler = webSocketHandler; this.wsHandler = webSocketHandler;
this.defaultSockJsTaskScheduler = defaultSockJsTaskScheduler;
} }
protected List<String> getPaths() { public StompEndpointRegistration setHandshakeHandler(HandshakeHandler handshakeHandler) {
return this.paths; this.handshakeHandler = handshakeHandler;
} return this;
protected SubProtocolWebSocketHandler getSubProtocolWebSocketHandler() {
return this.wsHandler;
}
protected StompSockJsServiceRegistration getSockJsServiceRegistration() {
return this.sockJsServiceRegistration;
} }
public SockJsServiceRegistration withSockJS() { public SockJsServiceRegistration withSockJS() {
this.sockJsServiceRegistration = new StompSockJsServiceRegistration(this.defaultTaskScheduler); this.sockJsServiceRegistration = new StompSockJsServiceRegistration(this.defaultSockJsTaskScheduler);
return this.sockJsServiceRegistration; return this.sockJsServiceRegistration;
} }
protected void setDefaultTaskScheduler(TaskScheduler defaultTaskScheduler) {
this.defaultTaskScheduler = defaultTaskScheduler;
}
protected TaskScheduler getDefaultTaskScheduler() {
return this.defaultTaskScheduler;
}
protected MultiValueMap<HttpRequestHandler, String> getMappings() { protected MultiValueMap<HttpRequestHandler, String> getMappings() {
MultiValueMap<HttpRequestHandler, String> mappings = new LinkedMultiValueMap<HttpRequestHandler, String>(); MultiValueMap<HttpRequestHandler, String> mappings = new LinkedMultiValueMap<HttpRequestHandler, String>();
if (getSockJsServiceRegistration() == null) { if (this.sockJsServiceRegistration == null) {
HandshakeHandler handshakeHandler = createHandshakeHandler(); HandshakeHandler handshakeHandler = getOrCreateHandshakeHandler();
for (String path : getPaths()) { for (String path : this.paths) {
WebSocketHttpRequestHandler handler = new WebSocketHttpRequestHandler(this.wsHandler, handshakeHandler); WebSocketHttpRequestHandler handler = new WebSocketHttpRequestHandler(this.wsHandler, handshakeHandler);
mappings.add(handler, path); mappings.add(handler, path);
} }
} }
else { else {
SockJsService sockJsService = getSockJsServiceRegistration().getSockJsService(); SockJsService sockJsService = this.sockJsServiceRegistration.getSockJsService();
for (String path : this.paths) { for (String path : this.paths) {
SockJsHttpRequestHandler httpHandler = new SockJsHttpRequestHandler(sockJsService, this.wsHandler); SockJsHttpRequestHandler httpHandler = new SockJsHttpRequestHandler(sockJsService, this.wsHandler);
mappings.add(httpHandler, path.endsWith("/") ? path + "**" : path + "/**"); mappings.add(httpHandler, path.endsWith("/") ? path + "**" : path + "/**");
@ -101,8 +93,20 @@ public class StompEndpointRegistration {
return mappings; return mappings;
} }
protected DefaultHandshakeHandler createHandshakeHandler() { private HandshakeHandler getOrCreateHandshakeHandler() {
return new DefaultHandshakeHandler();
HandshakeHandler handler = (this.handshakeHandler != null)
? this.handshakeHandler : new DefaultHandshakeHandler();
if (handler instanceof DefaultHandshakeHandler) {
DefaultHandshakeHandler defaultHandshakeHandler = (DefaultHandshakeHandler) handler;
if (ObjectUtils.isEmpty(defaultHandshakeHandler.getSupportedProtocols())) {
Set<String> protocols = this.wsHandler.getSupportedProtocols();
defaultHandshakeHandler.setSupportedProtocols(protocols.toArray(new String[protocols.size()]));
}
}
return handler;
} }
@ -114,7 +118,7 @@ public class StompEndpointRegistration {
} }
protected SockJsService getSockJsService() { protected SockJsService getSockJsService() {
return super.getSockJsService(getPaths().toArray(new String[getPaths().size()])); return super.getSockJsService(paths.toArray(new String[paths.size()]));
} }
} }

View File

@ -51,11 +51,11 @@ public class StompEndpointRegistry {
private int order = 1; private int order = 1;
private TaskScheduler defaultTaskScheduler; private final TaskScheduler defaultSockJsTaskScheduler;
public StompEndpointRegistry(SubProtocolWebSocketHandler webSocketHandler, public StompEndpointRegistry(SubProtocolWebSocketHandler webSocketHandler,
MutableUserQueueSuffixResolver userQueueSuffixResolver) { MutableUserQueueSuffixResolver userQueueSuffixResolver, TaskScheduler defaultSockJsTaskScheduler) {
Assert.notNull(webSocketHandler); Assert.notNull(webSocketHandler);
Assert.notNull(userQueueSuffixResolver); Assert.notNull(userQueueSuffixResolver);
@ -63,25 +63,18 @@ public class StompEndpointRegistry {
this.wsHandler = webSocketHandler; this.wsHandler = webSocketHandler;
this.stompHandler = new StompProtocolHandler(); this.stompHandler = new StompProtocolHandler();
this.stompHandler.setUserQueueSuffixResolver(userQueueSuffixResolver); this.stompHandler.setUserQueueSuffixResolver(userQueueSuffixResolver);
this.defaultSockJsTaskScheduler = defaultSockJsTaskScheduler;
} }
public StompEndpointRegistration addEndpoint(String... paths) { public StompEndpointRegistration addEndpoint(String... paths) {
this.wsHandler.addProtocolHandler(this.stompHandler); this.wsHandler.addProtocolHandler(this.stompHandler);
StompEndpointRegistration r = new StompEndpointRegistration(Arrays.asList(paths), this.wsHandler); StompEndpointRegistration r = new StompEndpointRegistration(
r.setDefaultTaskScheduler(getDefaultTaskScheduler()); Arrays.asList(paths), this.wsHandler, this.defaultSockJsTaskScheduler);
this.registrations.add(r); this.registrations.add(r);
return r; return r;
} }
protected SubProtocolWebSocketHandler getSubProtocolWebSocketHandler() {
return this.wsHandler;
}
protected StompProtocolHandler getStompProtocolHandler() {
return this.stompHandler;
}
/** /**
* Specify the order to use for the STOMP endpoint {@link HandlerMapping} relative to * Specify the order to use for the STOMP endpoint {@link HandlerMapping} relative to
* other handler mappings configured in the Spring MVC configuration. The default * other handler mappings configured in the Spring MVC configuration. The default
@ -91,18 +84,6 @@ public class StompEndpointRegistry {
this.order = order; this.order = order;
} }
protected int getOrder() {
return this.order;
}
protected void setDefaultTaskScheduler(TaskScheduler defaultTaskScheduler) {
this.defaultTaskScheduler = defaultTaskScheduler;
}
protected TaskScheduler getDefaultTaskScheduler() {
return this.defaultTaskScheduler;
}
/** /**
* Returns a handler mapping with the mapped ViewControllers; or {@code null} in case of no registrations. * Returns a handler mapping with the mapped ViewControllers; or {@code null} in case of no registrations.
*/ */

View File

@ -33,6 +33,7 @@ import org.springframework.messaging.support.converter.MessageConverter;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor; import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler; import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler;
import org.springframework.web.servlet.HandlerMapping; import org.springframework.web.servlet.HandlerMapping;
import org.springframework.web.socket.server.config.SockJsServiceRegistration;
/** /**
@ -54,9 +55,8 @@ public abstract class WebSocketMessageBrokerConfigurationSupport {
@Bean @Bean
public HandlerMapping brokerWebSocketHandlerMapping() { public HandlerMapping brokerWebSocketHandlerMapping() {
StompEndpointRegistry registry = StompEndpointRegistry registry = new StompEndpointRegistry(
new StompEndpointRegistry(subProtocolWebSocketHandler(), userQueueSuffixResolver()); subProtocolWebSocketHandler(), userQueueSuffixResolver(), brokerDefaultSockJsTaskScheduler());
registry.setDefaultTaskScheduler(brokerDefaultSockJsTaskScheduler());
registerStompEndpoints(registry); registerStompEndpoints(registry);
return registry.getHandlerMapping(); return registry.getHandlerMapping();
} }
@ -73,11 +73,14 @@ public abstract class WebSocketMessageBrokerConfigurationSupport {
return new SimpleUserQueueSuffixResolver(); return new SimpleUserQueueSuffixResolver();
} }
/**
* The default TaskScheduler to use if none is configured via
* {@link SockJsServiceRegistration#setTaskScheduler()}
*/
@Bean @Bean
public ThreadPoolTaskScheduler brokerDefaultSockJsTaskScheduler() { public ThreadPoolTaskScheduler brokerDefaultSockJsTaskScheduler() {
ThreadPoolTaskScheduler scheduler = new ThreadPoolTaskScheduler(); ThreadPoolTaskScheduler scheduler = new ThreadPoolTaskScheduler();
scheduler.setThreadNamePrefix("BrokerSockJS-"); scheduler.setThreadNamePrefix("BrokerSockJS-");
scheduler.setPoolSize(10);
return scheduler; return scheduler;
} }
@ -97,9 +100,7 @@ public abstract class WebSocketMessageBrokerConfigurationSupport {
@Bean @Bean
public ThreadPoolTaskExecutor webSocketChannelExecutor() { public ThreadPoolTaskExecutor webSocketChannelExecutor() {
ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor(); ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor();
executor.setCorePoolSize(4); executor.setThreadNamePrefix("BrokerWebSocketChannel-");
executor.setCorePoolSize(8);
executor.setThreadNamePrefix("MessageChannel-");
return executor; return executor;
} }

View File

@ -0,0 +1,104 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.simp;
import java.util.HashMap;
import java.util.Map;
import org.junit.After;
import org.junit.Before;
import org.junit.runners.Parameterized.Parameter;
import org.springframework.context.Lifecycle;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.socket.client.WebSocketClient;
import org.springframework.web.socket.server.DefaultHandshakeHandler;
import org.springframework.web.socket.server.HandshakeHandler;
import org.springframework.web.socket.server.RequestUpgradeStrategy;
import org.springframework.web.socket.server.support.JettyRequestUpgradeStrategy;
/**
* Base class for WebSocket integration tests.
*
* @author Rossen Stoyanchev
*/
public abstract class AbstractWebSocketIntegrationTests {
private static Map<Class<?>, Class<?>> upgradeStrategyConfigTypes = new HashMap<Class<?>, Class<?>>();
static {
upgradeStrategyConfigTypes.put(JettyTestServer.class, JettyUpgradeStrategyConfig.class);
}
@Parameter(0)
public TestServer server;
@Parameter(1)
public WebSocketClient webSocketClient;
@Before
public void setup() throws Exception {
if (this.webSocketClient instanceof Lifecycle) {
((Lifecycle) this.webSocketClient).start();
}
}
@After
public void teardown() throws Exception {
try {
if (this.webSocketClient instanceof Lifecycle) {
((Lifecycle) this.webSocketClient).stop();
}
}
finally {
this.server.stop();
}
}
protected String getWsBaseUrl() {
return "ws://localhost:" + this.server.getPort();
}
protected Class<?> getUpgradeStrategyConfigClass() {
return upgradeStrategyConfigTypes.get(this.server.getClass());
}
static abstract class AbstractRequestUpgradeStrategyConfig {
@Bean
public HandshakeHandler handshakeHandler() {
return new DefaultHandshakeHandler(requestUpgradeStrategy());
}
public abstract RequestUpgradeStrategy requestUpgradeStrategy();
}
@Configuration
static class JettyUpgradeStrategyConfig extends AbstractRequestUpgradeStrategyConfig {
@Bean
public RequestUpgradeStrategy requestUpgradeStrategy() {
return new JettyRequestUpgradeStrategy();
}
}
}

View File

@ -0,0 +1,69 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.simp;
import org.eclipse.jetty.server.Server;
import org.eclipse.jetty.servlet.ServletContextHandler;
import org.eclipse.jetty.servlet.ServletHolder;
import org.springframework.util.SocketUtils;
import org.springframework.web.context.WebApplicationContext;
import org.springframework.web.servlet.DispatcherServlet;
import org.springframework.web.socket.TestServer;
/**
* Jetty based {@link TestServer}.
*
* @author Rossen Stoyanchev
*/
public class JettyTestServer implements TestServer {
private final Server jettyServer;
private final int port;
public JettyTestServer() {
this.port = SocketUtils.findAvailableTcpPort();
this.jettyServer = new Server(this.port);
}
@Override
public int getPort() {
return this.port;
}
@Override
public void init(WebApplicationContext cxt) {
ServletContextHandler handler = new ServletContextHandler();
handler.addServlet(new ServletHolder(new DispatcherServlet(cxt)), "/");
this.jettyServer.setHandler(handler);
}
@Override
public void start() throws Exception {
this.jettyServer.start();
}
@Override
public void stop() throws Exception {
if (this.jettyServer.isRunning()) {
this.jettyServer.stop();
}
}
}

View File

@ -0,0 +1,36 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.simp;
import org.springframework.web.context.WebApplicationContext;
/**
* Contract for a test server to use for integration tests.
*
* @author Rossen Stoyanchev
*/
public interface TestServer {
int getPort();
void init(WebApplicationContext cxt);
void start() throws Exception;
void stop() throws Exception;
}

View File

@ -16,28 +16,37 @@
package org.springframework.messaging.simp.config; package org.springframework.messaging.simp.config;
import org.junit.Before; import java.util.Arrays;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import org.junit.Test; import org.junit.Test;
import org.springframework.context.annotation.AnnotationConfigApplicationContext; import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.runners.Parameterized.Parameters;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Configuration;
import org.springframework.messaging.Message; import org.springframework.messaging.Message;
import org.springframework.messaging.SubscribableChannel; import org.springframework.messaging.SubscribableChannel;
import org.springframework.messaging.handler.annotation.MessageMapping; import org.springframework.messaging.handler.annotation.MessageMapping;
import org.springframework.messaging.handler.websocket.SubProtocolWebSocketHandler; import org.springframework.messaging.simp.AbstractWebSocketIntegrationTests;
import org.springframework.messaging.simp.JettyTestServer;
import org.springframework.messaging.simp.stomp.StompCommand; import org.springframework.messaging.simp.stomp.StompCommand;
import org.springframework.messaging.simp.stomp.StompHeaderAccessor; import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
import org.springframework.messaging.simp.stomp.StompMessageConverter; import org.springframework.messaging.simp.stomp.StompMessageConverter;
import org.springframework.messaging.support.MessageBuilder; import org.springframework.messaging.support.MessageBuilder;
import org.springframework.messaging.support.channel.ExecutorSubscribableChannel; import org.springframework.messaging.support.channel.ExecutorSubscribableChannel;
import org.springframework.stereotype.Controller; import org.springframework.stereotype.Controller;
import org.springframework.web.servlet.HandlerMapping; import org.springframework.web.context.support.AnnotationConfigWebApplicationContext;
import org.springframework.web.servlet.handler.SimpleUrlHandlerMapping;
import org.springframework.web.socket.TextMessage; import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.adapter.TextWebSocketHandlerAdapter;
import org.springframework.web.socket.client.jetty.JettyWebSocketClient;
import org.springframework.web.socket.server.HandshakeHandler;
import org.springframework.web.socket.server.config.WebSocketConfigurationSupport; import org.springframework.web.socket.server.config.WebSocketConfigurationSupport;
import org.springframework.web.socket.server.support.WebSocketHttpRequestHandler; import org.springframework.web.socket.sockjs.transport.handler.WebSocketTransportHandler;
import org.springframework.web.socket.sockjs.SockJsHttpRequestHandler;
import org.springframework.web.socket.support.TestWebSocketSession;
import static org.junit.Assert.*; import static org.junit.Assert.*;
@ -47,65 +56,47 @@ import static org.junit.Assert.*;
* *
* @author Rossen Stoyanchev * @author Rossen Stoyanchev
*/ */
public class WebSocketMessageBrokerConfigurationTests { @RunWith(Parameterized.class)
public class WebSocketMessageBrokerConfigurationTests extends AbstractWebSocketIntegrationTests {
@Parameters
public static Iterable<Object[]> arguments() {
return Arrays.asList(new Object[][] {
{ new JettyTestServer(), new JettyWebSocketClient()} });
};
@Before
public void setup() {
}
@Test @Test
public void webSocketHandler() throws Exception { public void sendMessage() throws Exception {
AnnotationConfigApplicationContext cxt = new AnnotationConfigApplicationContext(); AnnotationConfigWebApplicationContext cxt = new AnnotationConfigWebApplicationContext();
cxt.register(TestWebSocketMessageBrokerConfiguration.class, SimpleBrokerConfigurer.class); cxt.register(TestWebSocketMessageBrokerConfiguration.class, SimpleBrokerConfigurer.class);
cxt.refresh(); cxt.register(getUpgradeStrategyConfigClass());
SimpleUrlHandlerMapping hm = (SimpleUrlHandlerMapping) cxt.getBean(HandlerMapping.class); this.server.init(cxt);
Object actual = hm.getUrlMap().get("/e1"); this.server.start();
assertNotNull(actual);
assertEquals(WebSocketHttpRequestHandler.class, actual.getClass());
cxt.close();
}
@Test
public void webSocketHandlerWithSockJS() throws Exception {
AnnotationConfigApplicationContext cxt = new AnnotationConfigApplicationContext();
cxt.register(TestWebSocketMessageBrokerConfiguration.class, SimpleBrokerConfigurer.class);
cxt.refresh();
SimpleUrlHandlerMapping hm = (SimpleUrlHandlerMapping) cxt.getBean(HandlerMapping.class);
Object actual = hm.getUrlMap().get("/e2/**");
assertNotNull(actual);
assertEquals(SockJsHttpRequestHandler.class, actual.getClass());
cxt.close();
}
@Test
public void annotationMethodMessageHandler() throws Exception {
AnnotationConfigApplicationContext cxt = new AnnotationConfigApplicationContext();
cxt.register(TestWebSocketMessageBrokerConfiguration.class, SimpleBrokerConfigurer.class);
cxt.refresh();
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.SEND); StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.SEND);
headers.setDestination("/app/foo"); headers.setDestination("/app/foo");
Message<byte[]> message = MessageBuilder.withPayloadAndHeaders(new byte[0], headers).build(); Message<byte[]> message = MessageBuilder.withPayloadAndHeaders(new byte[0], headers).build();
byte[] bytes = new StompMessageConverter().fromMessage(message); byte[] bytes = new StompMessageConverter().fromMessage(message);
final TextMessage webSocketMessage = new TextMessage(new String(bytes));
TestWebSocketSession session = new TestWebSocketSession(); WebSocketHandler clientHandler = new TextWebSocketHandlerAdapter() {
session.setAcceptedProtocol("v12.stomp"); @Override
public void afterConnectionEstablished(WebSocketSession session) throws Exception {
session.sendMessage(webSocketMessage);
}
};
SubProtocolWebSocketHandler wsHandler = cxt.getBean(SubProtocolWebSocketHandler.class); TestController testController = cxt.getBean(TestController.class);
wsHandler.handleMessage(session, new TextMessage(new String(bytes)));
assertTrue(cxt.getBean(TestController.class).foo); this.webSocketClient.doHandshake(clientHandler, getWsBaseUrl() + "/ws");
assertTrue(testController.latch.await(2, TimeUnit.SECONDS));
cxt.close(); testController.latch = new CountDownLatch(1);
this.webSocketClient.doHandshake(clientHandler, getWsBaseUrl() + "/sockjs/websocket");
assertTrue(testController.latch.await(2, TimeUnit.SECONDS));
} }
@ -128,16 +119,23 @@ public class WebSocketMessageBrokerConfigurationTests {
public TestController testController() { public TestController testController() {
return new TestController(); return new TestController();
} }
} }
@Configuration @Configuration
static class SimpleBrokerConfigurer implements WebSocketMessageBrokerConfigurer { static class SimpleBrokerConfigurer implements WebSocketMessageBrokerConfigurer {
@Autowired
private HandshakeHandler handshakeHandler; // can't rely on classpath for server detection
@Override @Override
public void registerStompEndpoints(StompEndpointRegistry registry) { public void registerStompEndpoints(StompEndpointRegistry registry) {
registry.addEndpoint("/e1");
registry.addEndpoint("/e2").withSockJS(); registry.addEndpoint("/ws")
.setHandshakeHandler(this.handshakeHandler);
registry.addEndpoint("/sockjs").withSockJS()
.setTransportHandlerOverrides(new WebSocketTransportHandler(this.handshakeHandler));;
} }
@Override @Override
@ -150,12 +148,11 @@ public class WebSocketMessageBrokerConfigurationTests {
@Controller @Controller
private static class TestController { private static class TestController {
private boolean foo; private CountDownLatch latch = new CountDownLatch(1);
@MessageMapping(value="/app/foo") @MessageMapping(value="/app/foo")
public void handleFoo() { public void handleFoo() {
this.foo = true; this.latch.countDown();
} }
} }

View File

@ -19,6 +19,10 @@
<level value="info" /> <level value="info" />
</logger> </logger>
<logger name="org.springframework.web">
<level value="debug" />
</logger>
<!-- Root Logger --> <!-- Root Logger -->
<root> <root>
<priority value="warn" /> <priority value="warn" />

View File

@ -24,6 +24,7 @@ import org.springframework.scheduling.TaskScheduler;
import org.springframework.util.ObjectUtils; import org.springframework.util.ObjectUtils;
import org.springframework.web.socket.server.HandshakeInterceptor; import org.springframework.web.socket.server.HandshakeInterceptor;
import org.springframework.web.socket.sockjs.SockJsService; import org.springframework.web.socket.sockjs.SockJsService;
import org.springframework.web.socket.sockjs.transport.TransportHandler;
import org.springframework.web.socket.sockjs.transport.handler.DefaultSockJsService; import org.springframework.web.socket.sockjs.transport.handler.DefaultSockJsService;
@ -53,7 +54,11 @@ public class SockJsServiceRegistration {
private Boolean webSocketEnabled; private Boolean webSocketEnabled;
private final List<HandshakeInterceptor> handshakeInterceptors = new ArrayList<HandshakeInterceptor>(); private final List<TransportHandler> transportHandlers = new ArrayList<TransportHandler>();
private final List<TransportHandler> transportHandlerOverrides = new ArrayList<TransportHandler>();
private final List<HandshakeInterceptor> interceptors = new ArrayList<HandshakeInterceptor>();
public SockJsServiceRegistration(TaskScheduler defaultTaskScheduler) { public SockJsServiceRegistration(TaskScheduler defaultTaskScheduler) {
@ -66,10 +71,6 @@ public class SockJsServiceRegistration {
return this; return this;
} }
protected TaskScheduler getTaskScheduler() {
return this.taskScheduler;
}
/** /**
* Transports which don't support cross-domain communication natively (e.g. * Transports which don't support cross-domain communication natively (e.g.
* "eventsource", "htmlfile") rely on serving a simple page (using the * "eventsource", "htmlfile") rely on serving a simple page (using the
@ -87,14 +88,6 @@ public class SockJsServiceRegistration {
return this; return this;
} }
/**
* The URL to the SockJS JavaScript client library.
* @see #setSockJsClientLibraryUrl(String)
*/
protected String getClientLibraryUrl() {
return this.clientLibraryUrl;
}
/** /**
* Streaming transports save responses on the client side and don't free * Streaming transports save responses on the client side and don't free
* memory used by delivered messages. Such transports need to recycle the * memory used by delivered messages. Such transports need to recycle the
@ -111,10 +104,6 @@ public class SockJsServiceRegistration {
return this; return this;
} }
protected Integer getStreamBytesLimit() {
return this.streamBytesLimit;
}
/** /**
* Some load balancers do sticky sessions, but only if there is a "JSESSIONID" * Some load balancers do sticky sessions, but only if there is a "JSESSIONID"
* cookie. Even if it is set to a dummy value, it doesn't matter since * cookie. Even if it is set to a dummy value, it doesn't matter since
@ -127,14 +116,6 @@ public class SockJsServiceRegistration {
return this; return this;
} }
/**
* Whether setting JSESSIONID cookie is necessary.
* @see #setDummySessionCookieEnabled(boolean)
*/
protected Boolean getDummySessionCookieEnabled() {
return this.sessionCookieEnabled;
}
/** /**
* The amount of time in milliseconds when the server has not sent any * The amount of time in milliseconds when the server has not sent any
* messages and after which the server should send a heartbeat frame to the * messages and after which the server should send a heartbeat frame to the
@ -147,10 +128,6 @@ public class SockJsServiceRegistration {
return this; return this;
} }
protected Long getHeartbeatTime() {
return this.heartbeatTime;
}
/** /**
* The amount of time in milliseconds before a client is considered * The amount of time in milliseconds before a client is considered
* disconnected after not having a receiving connection, i.e. an active * disconnected after not having a receiving connection, i.e. an active
@ -163,13 +140,6 @@ public class SockJsServiceRegistration {
return this; return this;
} }
/**
* Return the amount of time in milliseconds before a client is considered disconnected.
*/
protected Long getDisconnectDelay() {
return this.disconnectDelay;
}
/** /**
* The number of server-to-client messages that a session can cache while waiting for * The number of server-to-client messages that a session can cache while waiting for
* the next HTTP polling request from the client. All HTTP transports use this * the next HTTP polling request from the client. All HTTP transports use this
@ -186,13 +156,6 @@ public class SockJsServiceRegistration {
return this; return this;
} }
/**
* Return the size of the HTTP message cache.
*/
protected Integer getHttpMessageCacheSize() {
return this.httpMessageCacheSize;
}
/** /**
* Some load balancers don't support WebSocket. This option can be used to * Some load balancers don't support WebSocket. This option can be used to
* disable the WebSocket transport on the server side. * disable the WebSocket transport on the server side.
@ -204,23 +167,27 @@ public class SockJsServiceRegistration {
return this; return this;
} }
/** public SockJsServiceRegistration setTransportHandlers(TransportHandler... handlers) {
* Whether WebSocket transport is enabled. this.transportHandlers.clear();
* @see #setWebSocketsEnabled(boolean) if (!ObjectUtils.isEmpty(handlers)) {
*/ this.transportHandlers.addAll(Arrays.asList(handlers));
protected Boolean getWebSocketEnabled() {
return this.webSocketEnabled;
}
public SockJsServiceRegistration setInterceptors(HandshakeInterceptor... interceptors) {
if (!ObjectUtils.isEmpty(interceptors)) {
this.handshakeInterceptors.addAll(Arrays.asList(interceptors));
} }
return this; return this;
} }
protected List<HandshakeInterceptor> getInterceptors() { public SockJsServiceRegistration setTransportHandlerOverrides(TransportHandler... handlers) {
return this.handshakeInterceptors; this.transportHandlerOverrides.clear();
if (!ObjectUtils.isEmpty(handlers)) {
this.transportHandlerOverrides.addAll(Arrays.asList(handlers));
}
return this;
}
public SockJsServiceRegistration setInterceptors(HandshakeInterceptor... interceptors) {
if (!ObjectUtils.isEmpty(interceptors)) {
this.interceptors.addAll(Arrays.asList(interceptors));
}
return this;
} }
protected SockJsService getSockJsService(String[] sockJsPrefixes) { protected SockJsService getSockJsService(String[] sockJsPrefixes) {
@ -228,33 +195,34 @@ public class SockJsServiceRegistration {
if (sockJsPrefixes != null) { if (sockJsPrefixes != null) {
service.setValidSockJsPrefixes(sockJsPrefixes); service.setValidSockJsPrefixes(sockJsPrefixes);
} }
if (getClientLibraryUrl() != null) { if (this.clientLibraryUrl != null) {
service.setSockJsClientLibraryUrl(getClientLibraryUrl()); service.setSockJsClientLibraryUrl(this.clientLibraryUrl);
} }
if (getStreamBytesLimit() != null) { if (this.streamBytesLimit != null) {
service.setStreamBytesLimit(getStreamBytesLimit()); service.setStreamBytesLimit(this.streamBytesLimit);
} }
if (getDummySessionCookieEnabled() != null) { if (this.sessionCookieEnabled != null) {
service.setDummySessionCookieEnabled(getDummySessionCookieEnabled()); service.setDummySessionCookieEnabled(this.sessionCookieEnabled);
} }
if (getHeartbeatTime() != null) { if (this.heartbeatTime != null) {
service.setHeartbeatTime(getHeartbeatTime()); service.setHeartbeatTime(this.heartbeatTime);
} }
if (getDisconnectDelay() != null) { if (this.disconnectDelay != null) {
service.setDisconnectDelay(getDisconnectDelay()); service.setDisconnectDelay(this.heartbeatTime);
} }
if (getHttpMessageCacheSize() != null) { if (this.httpMessageCacheSize != null) {
service.setHttpMessageCacheSize(getHttpMessageCacheSize()); service.setHttpMessageCacheSize(this.httpMessageCacheSize);
} }
if (getWebSocketEnabled() != null) { if (this.webSocketEnabled != null) {
service.setWebSocketsEnabled(getWebSocketEnabled()); service.setWebSocketsEnabled(this.webSocketEnabled);
} }
service.setHandshakeInterceptors(getInterceptors()); service.setHandshakeInterceptors(this.interceptors);
return service; return service;
} }
protected DefaultSockJsService createSockJsService() { private DefaultSockJsService createSockJsService() {
return new DefaultSockJsService(getTaskScheduler()); return new DefaultSockJsService(this.taskScheduler, this.transportHandlers,
this.transportHandlerOverrides.toArray(new TransportHandler[this.transportHandlerOverrides.size()]));
} }
} }

View File

@ -32,8 +32,7 @@ public class WebSocketConfigurationSupport {
@Bean @Bean
public HandlerMapping webSocketHandlerMapping() { public HandlerMapping webSocketHandlerMapping() {
WebSocketHandlerRegistry registry = new WebSocketHandlerRegistry(); WebSocketHandlerRegistry registry = new WebSocketHandlerRegistry(defaultSockJsTaskScheduler());
registry.setDefaultTaskScheduler(sockJsTaskScheduler());
registerWebSocketHandlers(registry); registerWebSocketHandlers(registry);
return registry.getHandlerMapping(); return registry.getHandlerMapping();
} }
@ -41,11 +40,14 @@ public class WebSocketConfigurationSupport {
protected void registerWebSocketHandlers(WebSocketHandlerRegistry registry) { protected void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
} }
/**
* The default TaskScheduler to use if none is configured via
* {@link SockJsServiceRegistration#setTaskScheduler()}
*/
@Bean @Bean
public ThreadPoolTaskScheduler sockJsTaskScheduler() { public ThreadPoolTaskScheduler defaultSockJsTaskScheduler() {
ThreadPoolTaskScheduler scheduler = new ThreadPoolTaskScheduler(); ThreadPoolTaskScheduler scheduler = new ThreadPoolTaskScheduler();
scheduler.setThreadNamePrefix("SockJS-"); scheduler.setThreadNamePrefix("SockJS-");
scheduler.setPoolSize(10);
return scheduler; return scheduler;
} }

View File

@ -48,11 +48,17 @@ public class WebSocketHandlerRegistration {
private final List<HandshakeInterceptor> interceptors = new ArrayList<HandshakeInterceptor>(); private final List<HandshakeInterceptor> interceptors = new ArrayList<HandshakeInterceptor>();
private HandshakeHandler handshakeHandler;
private SockJsServiceRegistration sockJsServiceRegistration; private SockJsServiceRegistration sockJsServiceRegistration;
private TaskScheduler defaultTaskScheduler; private final TaskScheduler defaultTaskScheduler;
public WebSocketHandlerRegistration(TaskScheduler defaultTaskScheduler) {
this.defaultTaskScheduler = defaultTaskScheduler;
}
public WebSocketHandlerRegistration addHandler(WebSocketHandler handler, String... paths) { public WebSocketHandlerRegistration addHandler(WebSocketHandler handler, String... paths) {
Assert.notNull(handler); Assert.notNull(handler);
Assert.notEmpty(paths); Assert.notEmpty(paths);
@ -60,67 +66,55 @@ public class WebSocketHandlerRegistration {
return this; return this;
} }
protected MultiValueMap<WebSocketHandler, String> getHandlerMap() { public WebSocketHandlerRegistration setHandshakeHandler(HandshakeHandler handshakeHandler) {
return this.handlerMap; this.handshakeHandler = handshakeHandler;
return this;
}
public HandshakeHandler getHandshakeHandler() {
return handshakeHandler;
} }
public void addInterceptors(HandshakeInterceptor... interceptors) { public void addInterceptors(HandshakeInterceptor... interceptors) {
this.interceptors.addAll(Arrays.asList(interceptors)); this.interceptors.addAll(Arrays.asList(interceptors));
} }
protected List<HandshakeInterceptor> getInterceptors() {
return this.interceptors;
}
public SockJsServiceRegistration withSockJS() { public SockJsServiceRegistration withSockJS() {
this.sockJsServiceRegistration = new SockJsServiceRegistration(this.defaultTaskScheduler); this.sockJsServiceRegistration = new SockJsServiceRegistration(this.defaultTaskScheduler);
this.sockJsServiceRegistration.setInterceptors( this.sockJsServiceRegistration.setInterceptors(
getInterceptors().toArray(new HandshakeInterceptor[getInterceptors().size()])); this.interceptors.toArray(new HandshakeInterceptor[this.interceptors.size()]));
return this.sockJsServiceRegistration; return this.sockJsServiceRegistration;
} }
protected SockJsServiceRegistration getSockJsServiceRegistration() { MultiValueMap<HttpRequestHandler, String> getMappings() {
return this.sockJsServiceRegistration;
}
protected void setDefaultTaskScheduler(TaskScheduler defaultTaskScheduler) {
this.defaultTaskScheduler = defaultTaskScheduler;
}
protected TaskScheduler getDefaultTaskScheduler() {
return this.defaultTaskScheduler;
}
protected MultiValueMap<HttpRequestHandler, String> getMappings() {
MultiValueMap<HttpRequestHandler, String> mappings = new LinkedMultiValueMap<HttpRequestHandler, String>(); MultiValueMap<HttpRequestHandler, String> mappings = new LinkedMultiValueMap<HttpRequestHandler, String>();
if (getSockJsServiceRegistration() == null) { if (this.sockJsServiceRegistration == null) {
HandshakeHandler handshakeHandler = createHandshakeHandler(); HandshakeHandler handshakeHandler = getOrCreateHandshakeHandler();
for (WebSocketHandler handler : getHandlerMap().keySet()) { for (WebSocketHandler handler : this.handlerMap.keySet()) {
for (String path : getHandlerMap().get(handler)) { for (String path : this.handlerMap.get(handler)) {
WebSocketHttpRequestHandler httpHandler = new WebSocketHttpRequestHandler(handler, handshakeHandler); WebSocketHttpRequestHandler httpHandler = new WebSocketHttpRequestHandler(handler, handshakeHandler);
httpHandler.setHandshakeInterceptors(getInterceptors()); httpHandler.setHandshakeInterceptors(this.interceptors);
mappings.add(httpHandler, path); mappings.add(httpHandler, path);
} }
} }
} }
else { else {
SockJsService sockJsService = getSockJsServiceRegistration().getSockJsService(getAllPrefixes()); SockJsService sockJsService = this.sockJsServiceRegistration.getSockJsService(getAllPrefixes());
for (WebSocketHandler handler : getHandlerMap().keySet()) { for (WebSocketHandler handler : this.handlerMap.keySet()) {
for (String path : getHandlerMap().get(handler)) { for (String path : this.handlerMap.get(handler)) {
SockJsHttpRequestHandler httpHandler = new SockJsHttpRequestHandler(sockJsService, handler); SockJsHttpRequestHandler httpHandler = new SockJsHttpRequestHandler(sockJsService, handler);
mappings.add(httpHandler, path.endsWith("/") ? path + "**" : path + "/**"); mappings.add(httpHandler, path.endsWith("/") ? path + "**" : path + "/**");
} }
} }
} }
return mappings; return mappings;
} }
protected DefaultHandshakeHandler createHandshakeHandler() { private HandshakeHandler getOrCreateHandshakeHandler() {
return new DefaultHandshakeHandler(); return (this.handshakeHandler != null) ? this.handshakeHandler : new DefaultHandshakeHandler();
} }
protected final String[] getAllPrefixes() { private final String[] getAllPrefixes() {
List<String> all = new ArrayList<String>(); List<String> all = new ArrayList<String>();
for (List<String> prefixes: this.handlerMap.values()) { for (List<String> prefixes: this.handlerMap.values()) {
all.addAll(prefixes); all.addAll(prefixes);

View File

@ -22,6 +22,7 @@ import java.util.List;
import java.util.Map; import java.util.Map;
import org.springframework.scheduling.TaskScheduler; import org.springframework.scheduling.TaskScheduler;
import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler;
import org.springframework.util.MultiValueMap; import org.springframework.util.MultiValueMap;
import org.springframework.web.HttpRequestHandler; import org.springframework.web.HttpRequestHandler;
import org.springframework.web.servlet.HandlerMapping; import org.springframework.web.servlet.HandlerMapping;
@ -45,10 +46,13 @@ public class WebSocketHandlerRegistry {
private TaskScheduler defaultTaskScheduler; private TaskScheduler defaultTaskScheduler;
public WebSocketHandlerRegistry(ThreadPoolTaskScheduler defaultSockJsTaskScheduler) {
this.defaultTaskScheduler = defaultSockJsTaskScheduler;
}
public WebSocketHandlerRegistration addHandler(WebSocketHandler wsHandler, String... paths) { public WebSocketHandlerRegistration addHandler(WebSocketHandler wsHandler, String... paths) {
WebSocketHandlerRegistration r = new WebSocketHandlerRegistration(); WebSocketHandlerRegistration r = new WebSocketHandlerRegistration(this.defaultTaskScheduler);
r.addHandler(wsHandler, paths); r.addHandler(wsHandler, paths);
r.setDefaultTaskScheduler(this.defaultTaskScheduler);
this.registrations.add(r); this.registrations.add(r);
return r; return r;
} }
@ -59,29 +63,16 @@ public class WebSocketHandlerRegistry {
/** /**
* Specify the order to use for WebSocket {@link HandlerMapping} relative to other * Specify the order to use for WebSocket {@link HandlerMapping} relative to other
* handler mappings configured in the Spring MVC configuration. The default value is * handler mappings configured in the Spring MVC configuration. The default value is 1.
* 1.
*/ */
public void setOrder(int order) { public void setOrder(int order) {
this.order = order; this.order = order;
} }
protected int getOrder() {
return this.order;
}
protected void setDefaultTaskScheduler(TaskScheduler defaultTaskScheduler) {
this.defaultTaskScheduler = defaultTaskScheduler;
}
protected TaskScheduler getDefaultTaskScheduler() {
return this.defaultTaskScheduler;
}
/** /**
* Returns a handler mapping with the mapped ViewControllers; or {@code null} in case of no registrations. * Returns a handler mapping with the mapped ViewControllers; or {@code null} in case of no registrations.
*/ */
protected AbstractHandlerMapping getHandlerMapping() { AbstractHandlerMapping getHandlerMapping() {
Map<String, Object> urlMap = new LinkedHashMap<String, Object>(); Map<String, Object> urlMap = new LinkedHashMap<String, Object>();
for (WebSocketHandlerRegistration registration : this.registrations) { for (WebSocketHandlerRegistration registration : this.registrations) {
MultiValueMap<HttpRequestHandler, String> mappings = registration.getMappings(); MultiValueMap<HttpRequestHandler, String> mappings = registration.getMappings();

View File

@ -0,0 +1,104 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket;
import java.util.HashMap;
import java.util.Map;
import org.junit.After;
import org.junit.Before;
import org.junit.runners.Parameterized.Parameter;
import org.springframework.context.Lifecycle;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.socket.client.WebSocketClient;
import org.springframework.web.socket.server.DefaultHandshakeHandler;
import org.springframework.web.socket.server.HandshakeHandler;
import org.springframework.web.socket.server.RequestUpgradeStrategy;
import org.springframework.web.socket.server.support.JettyRequestUpgradeStrategy;
/**
* Base class for WebSocket integration tests.
*
* @author Rossen Stoyanchev
*/
public abstract class AbstractWebSocketIntegrationTests {
private static Map<Class<?>, Class<?>> upgradeStrategyConfigTypes = new HashMap<Class<?>, Class<?>>();
static {
upgradeStrategyConfigTypes.put(JettyTestServer.class, JettyUpgradeStrategyConfig.class);
}
@Parameter(0)
public TestServer server;
@Parameter(1)
public WebSocketClient webSocketClient;
@Before
public void setup() throws Exception {
if (this.webSocketClient instanceof Lifecycle) {
((Lifecycle) this.webSocketClient).start();
}
}
@After
public void teardown() throws Exception {
try {
if (this.webSocketClient instanceof Lifecycle) {
((Lifecycle) this.webSocketClient).stop();
}
}
finally {
this.server.stop();
}
}
protected String getWsBaseUrl() {
return "ws://localhost:" + this.server.getPort();
}
protected Class<?> getUpgradeStrategyConfigClass() {
return upgradeStrategyConfigTypes.get(this.server.getClass());
}
static abstract class AbstractRequestUpgradeStrategyConfig {
@Bean
public HandshakeHandler handshakeHandler() {
return new DefaultHandshakeHandler(requestUpgradeStrategy());
}
public abstract RequestUpgradeStrategy requestUpgradeStrategy();
}
@Configuration
static class JettyUpgradeStrategyConfig extends AbstractRequestUpgradeStrategyConfig {
@Bean
public RequestUpgradeStrategy requestUpgradeStrategy() {
return new JettyRequestUpgradeStrategy();
}
}
}

View File

@ -0,0 +1,68 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket;
import org.eclipse.jetty.server.Server;
import org.eclipse.jetty.servlet.ServletContextHandler;
import org.eclipse.jetty.servlet.ServletHolder;
import org.springframework.util.SocketUtils;
import org.springframework.web.context.WebApplicationContext;
import org.springframework.web.servlet.DispatcherServlet;
/**
* Jetty based {@link TestServer}.
*
* @author Rossen Stoyanchev
*/
public class JettyTestServer implements TestServer {
private final Server jettyServer;
private final int port;
public JettyTestServer() {
this.port = SocketUtils.findAvailableTcpPort();
this.jettyServer = new Server(this.port);
}
@Override
public int getPort() {
return this.port;
}
@Override
public void init(WebApplicationContext cxt) {
ServletContextHandler handler = new ServletContextHandler();
handler.addServlet(new ServletHolder(new DispatcherServlet(cxt)), "/");
this.jettyServer.setHandler(handler);
}
@Override
public void start() throws Exception {
this.jettyServer.start();
}
@Override
public void stop() throws Exception {
if (this.jettyServer.isRunning()) {
this.jettyServer.stop();
}
}
}

View File

@ -0,0 +1,36 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket;
import org.springframework.web.context.WebApplicationContext;
/**
* Contract for a test server to use for integration tests.
*
* @author Rossen Stoyanchev
*/
public interface TestServer {
int getPort();
void init(WebApplicationContext cxt);
void start() throws Exception;
void stop() throws Exception;
}

View File

@ -108,11 +108,9 @@ public class JettyWebSocketClientTests {
factory.setCreator(new WebSocketCreator() { factory.setCreator(new WebSocketCreator() {
@Override @Override
public Object createWebSocket(UpgradeRequest req, UpgradeResponse resp) { public Object createWebSocket(UpgradeRequest req, UpgradeResponse resp) {
if (!CollectionUtils.isEmpty(req.getSubProtocols())) { if (!CollectionUtils.isEmpty(req.getSubProtocols())) {
resp.setAcceptedSubProtocol(req.getSubProtocols().get(0)); resp.setAcceptedSubProtocol(req.getSubProtocols().get(0));
} }
JettyWebSocketSession session = new JettyWebSocketSession(null, null); JettyWebSocketSession session = new JettyWebSocketSession(null, null);
return new JettyWebSocketHandlerAdapter(webSocketHandler, session); return new JettyWebSocketHandlerAdapter(webSocketHandler, session);
} }

View File

@ -18,81 +18,101 @@ package org.springframework.web.socket.server.config;
import java.util.Arrays; import java.util.Arrays;
import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.springframework.web.context.support.GenericWebApplicationContext; import org.junit.runner.RunWith;
import org.springframework.web.servlet.handler.SimpleUrlHandlerMapping; import org.junit.runners.Parameterized;
import org.junit.runners.Parameterized.Parameters;
import org.mockito.Mockito;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.context.support.AnnotationConfigWebApplicationContext;
import org.springframework.web.socket.AbstractWebSocketIntegrationTests;
import org.springframework.web.socket.JettyTestServer;
import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.adapter.TextWebSocketHandlerAdapter; import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.server.support.WebSocketHttpRequestHandler; import org.springframework.web.socket.client.jetty.JettyWebSocketClient;
import org.springframework.web.socket.sockjs.SockJsHttpRequestHandler; import org.springframework.web.socket.server.HandshakeHandler;
import org.springframework.web.socket.sockjs.transport.handler.WebSocketTransportHandler;
import static org.junit.Assert.*; import static org.mockito.Matchers.*;
import static org.mockito.Mockito.*;
/** /**
* Test fixture for {@link WebSocketConfigurationSupport}. * Test fixture for WebSocket Java config support.
* *
* @author Rossen Stoyanchev * @author Rossen Stoyanchev
*/ */
public class WebSocketConfigurationTests { @RunWith(Parameterized.class)
public class WebSocketConfigurationTests extends AbstractWebSocketIntegrationTests {
private DelegatingWebSocketConfiguration config; @Parameters
public static Iterable<Object[]> arguments() {
private GenericWebApplicationContext context; return Arrays.asList(new Object[][] {
{ new JettyTestServer(), new JettyWebSocketClient()} });
};
@Before @Test
public void setup() { public void registerWebSocketHandler() throws Exception {
this.config = new DelegatingWebSocketConfiguration();
this.context = new GenericWebApplicationContext(); AnnotationConfigWebApplicationContext cxt = new AnnotationConfigWebApplicationContext();
this.context.refresh(); cxt.register(TestWebSocketConfigurer.class, getUpgradeStrategyConfigClass());
this.server.init(cxt);
this.server.start();
WebSocketHandler clientHandler = Mockito.mock(WebSocketHandler.class);
WebSocketHandler serverHandler = cxt.getBean(WebSocketHandler.class);
this.webSocketClient.doHandshake(clientHandler, getWsBaseUrl() + "/ws");
verify(serverHandler).afterConnectionEstablished(any(WebSocketSession.class));
verify(clientHandler).afterConnectionEstablished(any(WebSocketSession.class));
} }
@Test @Test
public void webSocket() throws Exception { public void registerWebSocketHandlerWithSockJS() throws Exception {
AnnotationConfigWebApplicationContext cxt = new AnnotationConfigWebApplicationContext();
cxt.register(TestWebSocketConfigurer.class, getUpgradeStrategyConfigClass());
this.server.init(cxt);
this.server.start();
WebSocketHandler clientHandler = Mockito.mock(WebSocketHandler.class);
WebSocketHandler serverHandler = cxt.getBean(WebSocketHandler.class);
this.webSocketClient.doHandshake(clientHandler, getWsBaseUrl() + "/sockjs/websocket");
verify(serverHandler).afterConnectionEstablished(any(WebSocketSession.class));
verify(clientHandler).afterConnectionEstablished(any(WebSocketSession.class));
}
@Configuration
@EnableWebSocket
static class TestWebSocketConfigurer implements WebSocketConfigurer {
@Autowired
private HandshakeHandler handshakeHandler; // can't rely on classpath for server detection
final WebSocketHandler handler = new TextWebSocketHandlerAdapter();
WebSocketConfigurer configurer = new WebSocketConfigurer() {
@Override @Override
public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) { public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
registry.addHandler(handler, "/h1");
}
};
this.config.setConfigurers(Arrays.asList(configurer)); registry.addHandler(serverHandler(), "/ws")
SimpleUrlHandlerMapping hm = (SimpleUrlHandlerMapping) this.config.webSocketHandlerMapping(); .setHandshakeHandler(this.handshakeHandler);
hm.setApplicationContext(this.context);
Object actual = hm.getUrlMap().get("/h1"); registry.addHandler(serverHandler(), "/sockjs").withSockJS()
.setTransportHandlerOverrides(new WebSocketTransportHandler(this.handshakeHandler));
assertNotNull(actual);
assertEquals(WebSocketHttpRequestHandler.class, actual.getClass());
assertEquals(1, hm.getUrlMap().size());
} }
@Test @Bean
public void webSocketWithSockJS() throws Exception { public WebSocketHandler serverHandler() {
return Mockito.mock(WebSocketHandler.class);
final WebSocketHandler handler = new TextWebSocketHandlerAdapter();
WebSocketConfigurer configurer = new WebSocketConfigurer() {
@Override
public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
registry.addHandler(handler, "/h1").withSockJS();
} }
};
this.config.setConfigurers(Arrays.asList(configurer));
SimpleUrlHandlerMapping hm = (SimpleUrlHandlerMapping) this.config.webSocketHandlerMapping();
hm.setApplicationContext(this.context);
Object actual = hm.getUrlMap().get("/h1/**");
assertNotNull(actual);
assertEquals(SockJsHttpRequestHandler.class, actual.getClass());
assertEquals(1, hm.getUrlMap().size());
} }
} }

View File

@ -10,6 +10,10 @@
</layout> </layout>
</appender> </appender>
<logger name="org.springframework.web">
<level value="debug" />
</logger>
<logger name="org.springframework.web.socket"> <logger name="org.springframework.web.socket">
<level value="debug" /> <level value="debug" />
</logger> </logger>