From f0dda0e38b73a9bbc4816f0cad2bd7ccdfa7911e Mon Sep 17 00:00:00 2001 From: Rossen Stoyanchev Date: Wed, 28 Aug 2013 21:08:17 -0400 Subject: [PATCH] Add WebSocket integration tests w/ Java configuration Issue: SPR-10835 --- build.gradle | 5 +- .../SubProtocolWebSocketHandler.java | 7 ++ .../config/StompEndpointRegistration.java | 60 ++++----- .../simp/config/StompEndpointRegistry.java | 29 +---- ...cketMessageBrokerConfigurationSupport.java | 15 +-- .../AbstractWebSocketIntegrationTests.java | 104 ++++++++++++++++ .../messaging/simp/JettyTestServer.java | 69 +++++++++++ .../messaging/simp/TestServer.java | 36 ++++++ ...SocketMessageBrokerConfigurationTests.java | 113 +++++++++-------- spring-messaging/src/test/resources/log4j.xml | 4 + .../config/SockJsServiceRegistration.java | 114 +++++++---------- .../config/WebSocketConfigurationSupport.java | 10 +- .../config/WebSocketHandlerRegistration.java | 60 ++++----- .../config/WebSocketHandlerRegistry.java | 25 ++-- .../AbstractWebSocketIntegrationTests.java | 104 ++++++++++++++++ .../web/socket/JettyTestServer.java | 68 ++++++++++ .../web/socket/TestServer.java | 36 ++++++ .../jetty/JettyWebSocketClientTests.java | 2 - .../config/WebSocketConfigurationTests.java | 116 ++++++++++-------- spring-websocket/src/test/resources/log4j.xml | 4 + 20 files changed, 686 insertions(+), 295 deletions(-) create mode 100644 spring-messaging/src/test/java/org/springframework/messaging/simp/AbstractWebSocketIntegrationTests.java create mode 100644 spring-messaging/src/test/java/org/springframework/messaging/simp/JettyTestServer.java create mode 100644 spring-messaging/src/test/java/org/springframework/messaging/simp/TestServer.java create mode 100644 spring-websocket/src/test/java/org/springframework/web/socket/AbstractWebSocketIntegrationTests.java create mode 100644 spring-websocket/src/test/java/org/springframework/web/socket/JettyTestServer.java create mode 100644 spring-websocket/src/test/java/org/springframework/web/socket/TestServer.java diff --git a/build.gradle b/build.gradle index 84da1fd87c..e087211bca 100644 --- a/build.gradle +++ b/build.gradle @@ -334,7 +334,8 @@ project("spring-messaging") { testCompile("org.eclipse.jetty:jetty-webapp:9.0.5.v20130815") { exclude group: "org.eclipse.jetty.orbit", module: "javax.servlet" } - testCompile("org.eclipse.jetty.websocket:websocket-server:9.0.5.v20130815") + optional("org.eclipse.jetty.websocket:websocket-server:9.0.5.v20130815") + optional("org.eclipse.jetty.websocket:websocket-client:9.0.5.v20130815") testCompile("javax.servlet:javax.servlet-api:3.0.1") testCompile("org.slf4j:slf4j-jcl:${slf4jVersion}") testCompile("log4j:log4j:1.2.17") @@ -527,6 +528,8 @@ project("spring-websocket") { optional("org.eclipse.jetty.websocket:websocket-client:9.0.5.v20130815") optional("com.fasterxml.jackson.core:jackson-databind:2.2.0") optional("org.codehaus.jackson:jackson-mapper-asl:1.9.12") + testCompile("org.slf4j:slf4j-jcl:${slf4jVersion}") + testCompile("log4j:log4j:1.2.17") } repositories { diff --git a/spring-messaging/src/main/java/org/springframework/messaging/handler/websocket/SubProtocolWebSocketHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/handler/websocket/SubProtocolWebSocketHandler.java index 8819ed206a..01d71a9cfb 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/handler/websocket/SubProtocolWebSocketHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/handler/websocket/SubProtocolWebSocketHandler.java @@ -72,6 +72,7 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler, MessageHan this.outputChannel = outputChannel; } + /** * Configure one or more handlers to use depending on the sub-protocol requested by * the client in the WebSocket handshake request. @@ -130,6 +131,12 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler, MessageHan return this.defaultProtocolHandler; } + /** + * Return all supported protocols. + */ + public Set getSupportedProtocols() { + return this.protocolHandlers.keySet(); + } @Override public void afterConnectionEstablished(WebSocketSession session) throws Exception { diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/config/StompEndpointRegistration.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/config/StompEndpointRegistration.java index 05535810ef..556e96ba78 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/config/StompEndpointRegistration.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/config/StompEndpointRegistration.java @@ -19,11 +19,13 @@ package org.springframework.messaging.simp.config; import java.util.ArrayList; import java.util.Collection; import java.util.List; +import java.util.Set; import org.springframework.messaging.handler.websocket.SubProtocolWebSocketHandler; import org.springframework.scheduling.TaskScheduler; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; +import org.springframework.util.ObjectUtils; import org.springframework.web.HttpRequestHandler; import org.springframework.web.socket.server.DefaultHandshakeHandler; import org.springframework.web.socket.server.HandshakeHandler; @@ -46,53 +48,43 @@ public class StompEndpointRegistration { private final SubProtocolWebSocketHandler wsHandler; + private HandshakeHandler handshakeHandler; + private StompSockJsServiceRegistration sockJsServiceRegistration; - private TaskScheduler defaultTaskScheduler; + private final TaskScheduler defaultSockJsTaskScheduler; - public StompEndpointRegistration(Collection paths, SubProtocolWebSocketHandler webSocketHandler) { + public StompEndpointRegistration(Collection paths, SubProtocolWebSocketHandler webSocketHandler, + TaskScheduler defaultSockJsTaskScheduler) { + this.paths = new ArrayList(paths); this.wsHandler = webSocketHandler; + this.defaultSockJsTaskScheduler = defaultSockJsTaskScheduler; } - protected List getPaths() { - return this.paths; - } - - protected SubProtocolWebSocketHandler getSubProtocolWebSocketHandler() { - return this.wsHandler; - } - - protected StompSockJsServiceRegistration getSockJsServiceRegistration() { - return this.sockJsServiceRegistration; + public StompEndpointRegistration setHandshakeHandler(HandshakeHandler handshakeHandler) { + this.handshakeHandler = handshakeHandler; + return this; } public SockJsServiceRegistration withSockJS() { - this.sockJsServiceRegistration = new StompSockJsServiceRegistration(this.defaultTaskScheduler); + this.sockJsServiceRegistration = new StompSockJsServiceRegistration(this.defaultSockJsTaskScheduler); return this.sockJsServiceRegistration; } - protected void setDefaultTaskScheduler(TaskScheduler defaultTaskScheduler) { - this.defaultTaskScheduler = defaultTaskScheduler; - } - - protected TaskScheduler getDefaultTaskScheduler() { - return this.defaultTaskScheduler; - } - protected MultiValueMap getMappings() { MultiValueMap mappings = new LinkedMultiValueMap(); - if (getSockJsServiceRegistration() == null) { - HandshakeHandler handshakeHandler = createHandshakeHandler(); - for (String path : getPaths()) { + if (this.sockJsServiceRegistration == null) { + HandshakeHandler handshakeHandler = getOrCreateHandshakeHandler(); + for (String path : this.paths) { WebSocketHttpRequestHandler handler = new WebSocketHttpRequestHandler(this.wsHandler, handshakeHandler); mappings.add(handler, path); } } else { - SockJsService sockJsService = getSockJsServiceRegistration().getSockJsService(); + SockJsService sockJsService = this.sockJsServiceRegistration.getSockJsService(); for (String path : this.paths) { SockJsHttpRequestHandler httpHandler = new SockJsHttpRequestHandler(sockJsService, this.wsHandler); mappings.add(httpHandler, path.endsWith("/") ? path + "**" : path + "/**"); @@ -101,8 +93,20 @@ public class StompEndpointRegistration { return mappings; } - protected DefaultHandshakeHandler createHandshakeHandler() { - return new DefaultHandshakeHandler(); + private HandshakeHandler getOrCreateHandshakeHandler() { + + HandshakeHandler handler = (this.handshakeHandler != null) + ? this.handshakeHandler : new DefaultHandshakeHandler(); + + if (handler instanceof DefaultHandshakeHandler) { + DefaultHandshakeHandler defaultHandshakeHandler = (DefaultHandshakeHandler) handler; + if (ObjectUtils.isEmpty(defaultHandshakeHandler.getSupportedProtocols())) { + Set protocols = this.wsHandler.getSupportedProtocols(); + defaultHandshakeHandler.setSupportedProtocols(protocols.toArray(new String[protocols.size()])); + } + } + + return handler; } @@ -114,7 +118,7 @@ public class StompEndpointRegistration { } protected SockJsService getSockJsService() { - return super.getSockJsService(getPaths().toArray(new String[getPaths().size()])); + return super.getSockJsService(paths.toArray(new String[paths.size()])); } } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/config/StompEndpointRegistry.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/config/StompEndpointRegistry.java index b608308187..111c4acf9d 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/config/StompEndpointRegistry.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/config/StompEndpointRegistry.java @@ -51,11 +51,11 @@ public class StompEndpointRegistry { private int order = 1; - private TaskScheduler defaultTaskScheduler; + private final TaskScheduler defaultSockJsTaskScheduler; public StompEndpointRegistry(SubProtocolWebSocketHandler webSocketHandler, - MutableUserQueueSuffixResolver userQueueSuffixResolver) { + MutableUserQueueSuffixResolver userQueueSuffixResolver, TaskScheduler defaultSockJsTaskScheduler) { Assert.notNull(webSocketHandler); Assert.notNull(userQueueSuffixResolver); @@ -63,25 +63,18 @@ public class StompEndpointRegistry { this.wsHandler = webSocketHandler; this.stompHandler = new StompProtocolHandler(); this.stompHandler.setUserQueueSuffixResolver(userQueueSuffixResolver); + this.defaultSockJsTaskScheduler = defaultSockJsTaskScheduler; } public StompEndpointRegistration addEndpoint(String... paths) { this.wsHandler.addProtocolHandler(this.stompHandler); - StompEndpointRegistration r = new StompEndpointRegistration(Arrays.asList(paths), this.wsHandler); - r.setDefaultTaskScheduler(getDefaultTaskScheduler()); + StompEndpointRegistration r = new StompEndpointRegistration( + Arrays.asList(paths), this.wsHandler, this.defaultSockJsTaskScheduler); this.registrations.add(r); return r; } - protected SubProtocolWebSocketHandler getSubProtocolWebSocketHandler() { - return this.wsHandler; - } - - protected StompProtocolHandler getStompProtocolHandler() { - return this.stompHandler; - } - /** * Specify the order to use for the STOMP endpoint {@link HandlerMapping} relative to * other handler mappings configured in the Spring MVC configuration. The default @@ -91,18 +84,6 @@ public class StompEndpointRegistry { this.order = order; } - protected int getOrder() { - return this.order; - } - - protected void setDefaultTaskScheduler(TaskScheduler defaultTaskScheduler) { - this.defaultTaskScheduler = defaultTaskScheduler; - } - - protected TaskScheduler getDefaultTaskScheduler() { - return this.defaultTaskScheduler; - } - /** * Returns a handler mapping with the mapped ViewControllers; or {@code null} in case of no registrations. */ diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/config/WebSocketMessageBrokerConfigurationSupport.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/config/WebSocketMessageBrokerConfigurationSupport.java index 024a2a62fa..c11fd293ce 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/config/WebSocketMessageBrokerConfigurationSupport.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/config/WebSocketMessageBrokerConfigurationSupport.java @@ -33,6 +33,7 @@ import org.springframework.messaging.support.converter.MessageConverter; import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor; import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler; import org.springframework.web.servlet.HandlerMapping; +import org.springframework.web.socket.server.config.SockJsServiceRegistration; /** @@ -54,9 +55,8 @@ public abstract class WebSocketMessageBrokerConfigurationSupport { @Bean public HandlerMapping brokerWebSocketHandlerMapping() { - StompEndpointRegistry registry = - new StompEndpointRegistry(subProtocolWebSocketHandler(), userQueueSuffixResolver()); - registry.setDefaultTaskScheduler(brokerDefaultSockJsTaskScheduler()); + StompEndpointRegistry registry = new StompEndpointRegistry( + subProtocolWebSocketHandler(), userQueueSuffixResolver(), brokerDefaultSockJsTaskScheduler()); registerStompEndpoints(registry); return registry.getHandlerMapping(); } @@ -73,11 +73,14 @@ public abstract class WebSocketMessageBrokerConfigurationSupport { return new SimpleUserQueueSuffixResolver(); } + /** + * The default TaskScheduler to use if none is configured via + * {@link SockJsServiceRegistration#setTaskScheduler()} + */ @Bean public ThreadPoolTaskScheduler brokerDefaultSockJsTaskScheduler() { ThreadPoolTaskScheduler scheduler = new ThreadPoolTaskScheduler(); scheduler.setThreadNamePrefix("BrokerSockJS-"); - scheduler.setPoolSize(10); return scheduler; } @@ -97,9 +100,7 @@ public abstract class WebSocketMessageBrokerConfigurationSupport { @Bean public ThreadPoolTaskExecutor webSocketChannelExecutor() { ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor(); - executor.setCorePoolSize(4); - executor.setCorePoolSize(8); - executor.setThreadNamePrefix("MessageChannel-"); + executor.setThreadNamePrefix("BrokerWebSocketChannel-"); return executor; } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/AbstractWebSocketIntegrationTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/AbstractWebSocketIntegrationTests.java new file mode 100644 index 0000000000..2df63e6db6 --- /dev/null +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/AbstractWebSocketIntegrationTests.java @@ -0,0 +1,104 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.messaging.simp; + +import java.util.HashMap; +import java.util.Map; + +import org.junit.After; +import org.junit.Before; +import org.junit.runners.Parameterized.Parameter; +import org.springframework.context.Lifecycle; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.web.socket.client.WebSocketClient; +import org.springframework.web.socket.server.DefaultHandshakeHandler; +import org.springframework.web.socket.server.HandshakeHandler; +import org.springframework.web.socket.server.RequestUpgradeStrategy; +import org.springframework.web.socket.server.support.JettyRequestUpgradeStrategy; + + + +/** + * Base class for WebSocket integration tests. + * + * @author Rossen Stoyanchev + */ +public abstract class AbstractWebSocketIntegrationTests { + + private static Map, Class> upgradeStrategyConfigTypes = new HashMap, Class>(); + + static { + upgradeStrategyConfigTypes.put(JettyTestServer.class, JettyUpgradeStrategyConfig.class); + } + + @Parameter(0) + public TestServer server; + + @Parameter(1) + public WebSocketClient webSocketClient; + + + @Before + public void setup() throws Exception { + if (this.webSocketClient instanceof Lifecycle) { + ((Lifecycle) this.webSocketClient).start(); + } + } + + @After + public void teardown() throws Exception { + try { + if (this.webSocketClient instanceof Lifecycle) { + ((Lifecycle) this.webSocketClient).stop(); + } + } + finally { + this.server.stop(); + } + } + + protected String getWsBaseUrl() { + return "ws://localhost:" + this.server.getPort(); + } + + protected Class getUpgradeStrategyConfigClass() { + return upgradeStrategyConfigTypes.get(this.server.getClass()); + } + + + static abstract class AbstractRequestUpgradeStrategyConfig { + + @Bean + public HandshakeHandler handshakeHandler() { + return new DefaultHandshakeHandler(requestUpgradeStrategy()); + } + + public abstract RequestUpgradeStrategy requestUpgradeStrategy(); + } + + + @Configuration + static class JettyUpgradeStrategyConfig extends AbstractRequestUpgradeStrategyConfig { + + @Bean + public RequestUpgradeStrategy requestUpgradeStrategy() { + return new JettyRequestUpgradeStrategy(); + } + } + +} diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/JettyTestServer.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/JettyTestServer.java new file mode 100644 index 0000000000..a904ce668c --- /dev/null +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/JettyTestServer.java @@ -0,0 +1,69 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.messaging.simp; + +import org.eclipse.jetty.server.Server; +import org.eclipse.jetty.servlet.ServletContextHandler; +import org.eclipse.jetty.servlet.ServletHolder; +import org.springframework.util.SocketUtils; +import org.springframework.web.context.WebApplicationContext; +import org.springframework.web.servlet.DispatcherServlet; +import org.springframework.web.socket.TestServer; + + +/** + * Jetty based {@link TestServer}. + * + * @author Rossen Stoyanchev + */ +public class JettyTestServer implements TestServer { + + private final Server jettyServer; + + private final int port; + + + public JettyTestServer() { + this.port = SocketUtils.findAvailableTcpPort(); + this.jettyServer = new Server(this.port); + } + + @Override + public int getPort() { + return this.port; + } + + @Override + public void init(WebApplicationContext cxt) { + ServletContextHandler handler = new ServletContextHandler(); + handler.addServlet(new ServletHolder(new DispatcherServlet(cxt)), "/"); + this.jettyServer.setHandler(handler); + } + + @Override + public void start() throws Exception { + this.jettyServer.start(); + } + + @Override + public void stop() throws Exception { + if (this.jettyServer.isRunning()) { + this.jettyServer.stop(); + } + } + +} diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/TestServer.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/TestServer.java new file mode 100644 index 0000000000..a9a510fd2b --- /dev/null +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/TestServer.java @@ -0,0 +1,36 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.messaging.simp; + +import org.springframework.web.context.WebApplicationContext; + +/** + * Contract for a test server to use for integration tests. + * + * @author Rossen Stoyanchev + */ +public interface TestServer { + + int getPort(); + + void init(WebApplicationContext cxt); + + void start() throws Exception; + + void stop() throws Exception; + +} \ No newline at end of file diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/config/WebSocketMessageBrokerConfigurationTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/config/WebSocketMessageBrokerConfigurationTests.java index 1dfc7cc996..79f47d8e7e 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/config/WebSocketMessageBrokerConfigurationTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/config/WebSocketMessageBrokerConfigurationTests.java @@ -16,28 +16,37 @@ package org.springframework.messaging.simp.config; -import org.junit.Before; +import java.util.Arrays; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + import org.junit.Test; -import org.springframework.context.annotation.AnnotationConfigApplicationContext; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameters; +import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.messaging.Message; import org.springframework.messaging.SubscribableChannel; import org.springframework.messaging.handler.annotation.MessageMapping; -import org.springframework.messaging.handler.websocket.SubProtocolWebSocketHandler; +import org.springframework.messaging.simp.AbstractWebSocketIntegrationTests; +import org.springframework.messaging.simp.JettyTestServer; import org.springframework.messaging.simp.stomp.StompCommand; import org.springframework.messaging.simp.stomp.StompHeaderAccessor; import org.springframework.messaging.simp.stomp.StompMessageConverter; import org.springframework.messaging.support.MessageBuilder; import org.springframework.messaging.support.channel.ExecutorSubscribableChannel; import org.springframework.stereotype.Controller; -import org.springframework.web.servlet.HandlerMapping; -import org.springframework.web.servlet.handler.SimpleUrlHandlerMapping; +import org.springframework.web.context.support.AnnotationConfigWebApplicationContext; import org.springframework.web.socket.TextMessage; +import org.springframework.web.socket.WebSocketHandler; +import org.springframework.web.socket.WebSocketSession; +import org.springframework.web.socket.adapter.TextWebSocketHandlerAdapter; +import org.springframework.web.socket.client.jetty.JettyWebSocketClient; +import org.springframework.web.socket.server.HandshakeHandler; import org.springframework.web.socket.server.config.WebSocketConfigurationSupport; -import org.springframework.web.socket.server.support.WebSocketHttpRequestHandler; -import org.springframework.web.socket.sockjs.SockJsHttpRequestHandler; -import org.springframework.web.socket.support.TestWebSocketSession; +import org.springframework.web.socket.sockjs.transport.handler.WebSocketTransportHandler; import static org.junit.Assert.*; @@ -47,65 +56,47 @@ import static org.junit.Assert.*; * * @author Rossen Stoyanchev */ -public class WebSocketMessageBrokerConfigurationTests { +@RunWith(Parameterized.class) +public class WebSocketMessageBrokerConfigurationTests extends AbstractWebSocketIntegrationTests { + + @Parameters + public static Iterable arguments() { + return Arrays.asList(new Object[][] { + { new JettyTestServer(), new JettyWebSocketClient()} }); + }; - @Before - public void setup() { - } @Test - public void webSocketHandler() throws Exception { + public void sendMessage() throws Exception { - AnnotationConfigApplicationContext cxt = new AnnotationConfigApplicationContext(); + AnnotationConfigWebApplicationContext cxt = new AnnotationConfigWebApplicationContext(); cxt.register(TestWebSocketMessageBrokerConfiguration.class, SimpleBrokerConfigurer.class); - cxt.refresh(); + cxt.register(getUpgradeStrategyConfigClass()); - SimpleUrlHandlerMapping hm = (SimpleUrlHandlerMapping) cxt.getBean(HandlerMapping.class); - Object actual = hm.getUrlMap().get("/e1"); - - assertNotNull(actual); - assertEquals(WebSocketHttpRequestHandler.class, actual.getClass()); - - cxt.close(); - } - - @Test - public void webSocketHandlerWithSockJS() throws Exception { - - AnnotationConfigApplicationContext cxt = new AnnotationConfigApplicationContext(); - cxt.register(TestWebSocketMessageBrokerConfiguration.class, SimpleBrokerConfigurer.class); - cxt.refresh(); - - SimpleUrlHandlerMapping hm = (SimpleUrlHandlerMapping) cxt.getBean(HandlerMapping.class); - Object actual = hm.getUrlMap().get("/e2/**"); - - assertNotNull(actual); - assertEquals(SockJsHttpRequestHandler.class, actual.getClass()); - - cxt.close(); - } - - @Test - public void annotationMethodMessageHandler() throws Exception { - - AnnotationConfigApplicationContext cxt = new AnnotationConfigApplicationContext(); - cxt.register(TestWebSocketMessageBrokerConfiguration.class, SimpleBrokerConfigurer.class); - cxt.refresh(); + this.server.init(cxt); + this.server.start(); StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.SEND); headers.setDestination("/app/foo"); Message message = MessageBuilder.withPayloadAndHeaders(new byte[0], headers).build(); byte[] bytes = new StompMessageConverter().fromMessage(message); + final TextMessage webSocketMessage = new TextMessage(new String(bytes)); - TestWebSocketSession session = new TestWebSocketSession(); - session.setAcceptedProtocol("v12.stomp"); + WebSocketHandler clientHandler = new TextWebSocketHandlerAdapter() { + @Override + public void afterConnectionEstablished(WebSocketSession session) throws Exception { + session.sendMessage(webSocketMessage); + } + }; - SubProtocolWebSocketHandler wsHandler = cxt.getBean(SubProtocolWebSocketHandler.class); - wsHandler.handleMessage(session, new TextMessage(new String(bytes))); + TestController testController = cxt.getBean(TestController.class); - assertTrue(cxt.getBean(TestController.class).foo); + this.webSocketClient.doHandshake(clientHandler, getWsBaseUrl() + "/ws"); + assertTrue(testController.latch.await(2, TimeUnit.SECONDS)); - cxt.close(); + testController.latch = new CountDownLatch(1); + this.webSocketClient.doHandshake(clientHandler, getWsBaseUrl() + "/sockjs/websocket"); + assertTrue(testController.latch.await(2, TimeUnit.SECONDS)); } @@ -128,16 +119,23 @@ public class WebSocketMessageBrokerConfigurationTests { public TestController testController() { return new TestController(); } - } @Configuration static class SimpleBrokerConfigurer implements WebSocketMessageBrokerConfigurer { + @Autowired + private HandshakeHandler handshakeHandler; // can't rely on classpath for server detection + + @Override public void registerStompEndpoints(StompEndpointRegistry registry) { - registry.addEndpoint("/e1"); - registry.addEndpoint("/e2").withSockJS(); + + registry.addEndpoint("/ws") + .setHandshakeHandler(this.handshakeHandler); + + registry.addEndpoint("/sockjs").withSockJS() + .setTransportHandlerOverrides(new WebSocketTransportHandler(this.handshakeHandler));; } @Override @@ -150,12 +148,11 @@ public class WebSocketMessageBrokerConfigurationTests { @Controller private static class TestController { - private boolean foo; - + private CountDownLatch latch = new CountDownLatch(1); @MessageMapping(value="/app/foo") public void handleFoo() { - this.foo = true; + this.latch.countDown(); } } diff --git a/spring-messaging/src/test/resources/log4j.xml b/spring-messaging/src/test/resources/log4j.xml index 79762741be..cce3d56450 100644 --- a/spring-messaging/src/test/resources/log4j.xml +++ b/spring-messaging/src/test/resources/log4j.xml @@ -19,6 +19,10 @@ + + + + diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/config/SockJsServiceRegistration.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/config/SockJsServiceRegistration.java index 82ef18255c..62508e3342 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/config/SockJsServiceRegistration.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/config/SockJsServiceRegistration.java @@ -24,6 +24,7 @@ import org.springframework.scheduling.TaskScheduler; import org.springframework.util.ObjectUtils; import org.springframework.web.socket.server.HandshakeInterceptor; import org.springframework.web.socket.sockjs.SockJsService; +import org.springframework.web.socket.sockjs.transport.TransportHandler; import org.springframework.web.socket.sockjs.transport.handler.DefaultSockJsService; @@ -53,7 +54,11 @@ public class SockJsServiceRegistration { private Boolean webSocketEnabled; - private final List handshakeInterceptors = new ArrayList(); + private final List transportHandlers = new ArrayList(); + + private final List transportHandlerOverrides = new ArrayList(); + + private final List interceptors = new ArrayList(); public SockJsServiceRegistration(TaskScheduler defaultTaskScheduler) { @@ -66,10 +71,6 @@ public class SockJsServiceRegistration { return this; } - protected TaskScheduler getTaskScheduler() { - return this.taskScheduler; - } - /** * Transports which don't support cross-domain communication natively (e.g. * "eventsource", "htmlfile") rely on serving a simple page (using the @@ -87,14 +88,6 @@ public class SockJsServiceRegistration { return this; } - /** - * The URL to the SockJS JavaScript client library. - * @see #setSockJsClientLibraryUrl(String) - */ - protected String getClientLibraryUrl() { - return this.clientLibraryUrl; - } - /** * Streaming transports save responses on the client side and don't free * memory used by delivered messages. Such transports need to recycle the @@ -111,10 +104,6 @@ public class SockJsServiceRegistration { return this; } - protected Integer getStreamBytesLimit() { - return this.streamBytesLimit; - } - /** * Some load balancers do sticky sessions, but only if there is a "JSESSIONID" * cookie. Even if it is set to a dummy value, it doesn't matter since @@ -127,14 +116,6 @@ public class SockJsServiceRegistration { return this; } - /** - * Whether setting JSESSIONID cookie is necessary. - * @see #setDummySessionCookieEnabled(boolean) - */ - protected Boolean getDummySessionCookieEnabled() { - return this.sessionCookieEnabled; - } - /** * The amount of time in milliseconds when the server has not sent any * messages and after which the server should send a heartbeat frame to the @@ -147,10 +128,6 @@ public class SockJsServiceRegistration { return this; } - protected Long getHeartbeatTime() { - return this.heartbeatTime; - } - /** * The amount of time in milliseconds before a client is considered * disconnected after not having a receiving connection, i.e. an active @@ -163,13 +140,6 @@ public class SockJsServiceRegistration { return this; } - /** - * Return the amount of time in milliseconds before a client is considered disconnected. - */ - protected Long getDisconnectDelay() { - return this.disconnectDelay; - } - /** * The number of server-to-client messages that a session can cache while waiting for * the next HTTP polling request from the client. All HTTP transports use this @@ -186,13 +156,6 @@ public class SockJsServiceRegistration { return this; } - /** - * Return the size of the HTTP message cache. - */ - protected Integer getHttpMessageCacheSize() { - return this.httpMessageCacheSize; - } - /** * Some load balancers don't support WebSocket. This option can be used to * disable the WebSocket transport on the server side. @@ -204,23 +167,27 @@ public class SockJsServiceRegistration { return this; } - /** - * Whether WebSocket transport is enabled. - * @see #setWebSocketsEnabled(boolean) - */ - protected Boolean getWebSocketEnabled() { - return this.webSocketEnabled; - } - - public SockJsServiceRegistration setInterceptors(HandshakeInterceptor... interceptors) { - if (!ObjectUtils.isEmpty(interceptors)) { - this.handshakeInterceptors.addAll(Arrays.asList(interceptors)); + public SockJsServiceRegistration setTransportHandlers(TransportHandler... handlers) { + this.transportHandlers.clear(); + if (!ObjectUtils.isEmpty(handlers)) { + this.transportHandlers.addAll(Arrays.asList(handlers)); } return this; } - protected List getInterceptors() { - return this.handshakeInterceptors; + public SockJsServiceRegistration setTransportHandlerOverrides(TransportHandler... handlers) { + this.transportHandlerOverrides.clear(); + if (!ObjectUtils.isEmpty(handlers)) { + this.transportHandlerOverrides.addAll(Arrays.asList(handlers)); + } + return this; + } + + public SockJsServiceRegistration setInterceptors(HandshakeInterceptor... interceptors) { + if (!ObjectUtils.isEmpty(interceptors)) { + this.interceptors.addAll(Arrays.asList(interceptors)); + } + return this; } protected SockJsService getSockJsService(String[] sockJsPrefixes) { @@ -228,33 +195,34 @@ public class SockJsServiceRegistration { if (sockJsPrefixes != null) { service.setValidSockJsPrefixes(sockJsPrefixes); } - if (getClientLibraryUrl() != null) { - service.setSockJsClientLibraryUrl(getClientLibraryUrl()); + if (this.clientLibraryUrl != null) { + service.setSockJsClientLibraryUrl(this.clientLibraryUrl); } - if (getStreamBytesLimit() != null) { - service.setStreamBytesLimit(getStreamBytesLimit()); + if (this.streamBytesLimit != null) { + service.setStreamBytesLimit(this.streamBytesLimit); } - if (getDummySessionCookieEnabled() != null) { - service.setDummySessionCookieEnabled(getDummySessionCookieEnabled()); + if (this.sessionCookieEnabled != null) { + service.setDummySessionCookieEnabled(this.sessionCookieEnabled); } - if (getHeartbeatTime() != null) { - service.setHeartbeatTime(getHeartbeatTime()); + if (this.heartbeatTime != null) { + service.setHeartbeatTime(this.heartbeatTime); } - if (getDisconnectDelay() != null) { - service.setDisconnectDelay(getDisconnectDelay()); + if (this.disconnectDelay != null) { + service.setDisconnectDelay(this.heartbeatTime); } - if (getHttpMessageCacheSize() != null) { - service.setHttpMessageCacheSize(getHttpMessageCacheSize()); + if (this.httpMessageCacheSize != null) { + service.setHttpMessageCacheSize(this.httpMessageCacheSize); } - if (getWebSocketEnabled() != null) { - service.setWebSocketsEnabled(getWebSocketEnabled()); + if (this.webSocketEnabled != null) { + service.setWebSocketsEnabled(this.webSocketEnabled); } - service.setHandshakeInterceptors(getInterceptors()); + service.setHandshakeInterceptors(this.interceptors); return service; } - protected DefaultSockJsService createSockJsService() { - return new DefaultSockJsService(getTaskScheduler()); + private DefaultSockJsService createSockJsService() { + return new DefaultSockJsService(this.taskScheduler, this.transportHandlers, + this.transportHandlerOverrides.toArray(new TransportHandler[this.transportHandlerOverrides.size()])); } } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/config/WebSocketConfigurationSupport.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/config/WebSocketConfigurationSupport.java index c1ade7bfe9..fd8456f102 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/config/WebSocketConfigurationSupport.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/config/WebSocketConfigurationSupport.java @@ -32,8 +32,7 @@ public class WebSocketConfigurationSupport { @Bean public HandlerMapping webSocketHandlerMapping() { - WebSocketHandlerRegistry registry = new WebSocketHandlerRegistry(); - registry.setDefaultTaskScheduler(sockJsTaskScheduler()); + WebSocketHandlerRegistry registry = new WebSocketHandlerRegistry(defaultSockJsTaskScheduler()); registerWebSocketHandlers(registry); return registry.getHandlerMapping(); } @@ -41,11 +40,14 @@ public class WebSocketConfigurationSupport { protected void registerWebSocketHandlers(WebSocketHandlerRegistry registry) { } + /** + * The default TaskScheduler to use if none is configured via + * {@link SockJsServiceRegistration#setTaskScheduler()} + */ @Bean - public ThreadPoolTaskScheduler sockJsTaskScheduler() { + public ThreadPoolTaskScheduler defaultSockJsTaskScheduler() { ThreadPoolTaskScheduler scheduler = new ThreadPoolTaskScheduler(); scheduler.setThreadNamePrefix("SockJS-"); - scheduler.setPoolSize(10); return scheduler; } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/config/WebSocketHandlerRegistration.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/config/WebSocketHandlerRegistration.java index 5c7fd94d75..0503ce9f98 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/config/WebSocketHandlerRegistration.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/config/WebSocketHandlerRegistration.java @@ -48,11 +48,17 @@ public class WebSocketHandlerRegistration { private final List interceptors = new ArrayList(); + private HandshakeHandler handshakeHandler; + private SockJsServiceRegistration sockJsServiceRegistration; - private TaskScheduler defaultTaskScheduler; + private final TaskScheduler defaultTaskScheduler; + public WebSocketHandlerRegistration(TaskScheduler defaultTaskScheduler) { + this.defaultTaskScheduler = defaultTaskScheduler; + } + public WebSocketHandlerRegistration addHandler(WebSocketHandler handler, String... paths) { Assert.notNull(handler); Assert.notEmpty(paths); @@ -60,67 +66,55 @@ public class WebSocketHandlerRegistration { return this; } - protected MultiValueMap getHandlerMap() { - return this.handlerMap; + public WebSocketHandlerRegistration setHandshakeHandler(HandshakeHandler handshakeHandler) { + this.handshakeHandler = handshakeHandler; + return this; + } + + public HandshakeHandler getHandshakeHandler() { + return handshakeHandler; } public void addInterceptors(HandshakeInterceptor... interceptors) { this.interceptors.addAll(Arrays.asList(interceptors)); } - protected List getInterceptors() { - return this.interceptors; - } - public SockJsServiceRegistration withSockJS() { this.sockJsServiceRegistration = new SockJsServiceRegistration(this.defaultTaskScheduler); this.sockJsServiceRegistration.setInterceptors( - getInterceptors().toArray(new HandshakeInterceptor[getInterceptors().size()])); + this.interceptors.toArray(new HandshakeInterceptor[this.interceptors.size()])); return this.sockJsServiceRegistration; } - protected SockJsServiceRegistration getSockJsServiceRegistration() { - return this.sockJsServiceRegistration; - } - - protected void setDefaultTaskScheduler(TaskScheduler defaultTaskScheduler) { - this.defaultTaskScheduler = defaultTaskScheduler; - } - - protected TaskScheduler getDefaultTaskScheduler() { - return this.defaultTaskScheduler; - } - - protected MultiValueMap getMappings() { + MultiValueMap getMappings() { MultiValueMap mappings = new LinkedMultiValueMap(); - if (getSockJsServiceRegistration() == null) { - HandshakeHandler handshakeHandler = createHandshakeHandler(); - for (WebSocketHandler handler : getHandlerMap().keySet()) { - for (String path : getHandlerMap().get(handler)) { + if (this.sockJsServiceRegistration == null) { + HandshakeHandler handshakeHandler = getOrCreateHandshakeHandler(); + for (WebSocketHandler handler : this.handlerMap.keySet()) { + for (String path : this.handlerMap.get(handler)) { WebSocketHttpRequestHandler httpHandler = new WebSocketHttpRequestHandler(handler, handshakeHandler); - httpHandler.setHandshakeInterceptors(getInterceptors()); + httpHandler.setHandshakeInterceptors(this.interceptors); mappings.add(httpHandler, path); } } } else { - SockJsService sockJsService = getSockJsServiceRegistration().getSockJsService(getAllPrefixes()); - for (WebSocketHandler handler : getHandlerMap().keySet()) { - for (String path : getHandlerMap().get(handler)) { + SockJsService sockJsService = this.sockJsServiceRegistration.getSockJsService(getAllPrefixes()); + for (WebSocketHandler handler : this.handlerMap.keySet()) { + for (String path : this.handlerMap.get(handler)) { SockJsHttpRequestHandler httpHandler = new SockJsHttpRequestHandler(sockJsService, handler); mappings.add(httpHandler, path.endsWith("/") ? path + "**" : path + "/**"); } } - } return mappings; } - protected DefaultHandshakeHandler createHandshakeHandler() { - return new DefaultHandshakeHandler(); + private HandshakeHandler getOrCreateHandshakeHandler() { + return (this.handshakeHandler != null) ? this.handshakeHandler : new DefaultHandshakeHandler(); } - protected final String[] getAllPrefixes() { + private final String[] getAllPrefixes() { List all = new ArrayList(); for (List prefixes: this.handlerMap.values()) { all.addAll(prefixes); diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/config/WebSocketHandlerRegistry.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/config/WebSocketHandlerRegistry.java index ecfdc00d49..e4a9fccb23 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/config/WebSocketHandlerRegistry.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/config/WebSocketHandlerRegistry.java @@ -22,6 +22,7 @@ import java.util.List; import java.util.Map; import org.springframework.scheduling.TaskScheduler; +import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler; import org.springframework.util.MultiValueMap; import org.springframework.web.HttpRequestHandler; import org.springframework.web.servlet.HandlerMapping; @@ -45,10 +46,13 @@ public class WebSocketHandlerRegistry { private TaskScheduler defaultTaskScheduler; + public WebSocketHandlerRegistry(ThreadPoolTaskScheduler defaultSockJsTaskScheduler) { + this.defaultTaskScheduler = defaultSockJsTaskScheduler; + } + public WebSocketHandlerRegistration addHandler(WebSocketHandler wsHandler, String... paths) { - WebSocketHandlerRegistration r = new WebSocketHandlerRegistration(); + WebSocketHandlerRegistration r = new WebSocketHandlerRegistration(this.defaultTaskScheduler); r.addHandler(wsHandler, paths); - r.setDefaultTaskScheduler(this.defaultTaskScheduler); this.registrations.add(r); return r; } @@ -59,29 +63,16 @@ public class WebSocketHandlerRegistry { /** * Specify the order to use for WebSocket {@link HandlerMapping} relative to other - * handler mappings configured in the Spring MVC configuration. The default value is - * 1. + * handler mappings configured in the Spring MVC configuration. The default value is 1. */ public void setOrder(int order) { this.order = order; } - protected int getOrder() { - return this.order; - } - - protected void setDefaultTaskScheduler(TaskScheduler defaultTaskScheduler) { - this.defaultTaskScheduler = defaultTaskScheduler; - } - - protected TaskScheduler getDefaultTaskScheduler() { - return this.defaultTaskScheduler; - } - /** * Returns a handler mapping with the mapped ViewControllers; or {@code null} in case of no registrations. */ - protected AbstractHandlerMapping getHandlerMapping() { + AbstractHandlerMapping getHandlerMapping() { Map urlMap = new LinkedHashMap(); for (WebSocketHandlerRegistration registration : this.registrations) { MultiValueMap mappings = registration.getMappings(); diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/AbstractWebSocketIntegrationTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/AbstractWebSocketIntegrationTests.java new file mode 100644 index 0000000000..b50dc5238f --- /dev/null +++ b/spring-websocket/src/test/java/org/springframework/web/socket/AbstractWebSocketIntegrationTests.java @@ -0,0 +1,104 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket; + +import java.util.HashMap; +import java.util.Map; + +import org.junit.After; +import org.junit.Before; +import org.junit.runners.Parameterized.Parameter; +import org.springframework.context.Lifecycle; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.web.socket.client.WebSocketClient; +import org.springframework.web.socket.server.DefaultHandshakeHandler; +import org.springframework.web.socket.server.HandshakeHandler; +import org.springframework.web.socket.server.RequestUpgradeStrategy; +import org.springframework.web.socket.server.support.JettyRequestUpgradeStrategy; + + + +/** + * Base class for WebSocket integration tests. + * + * @author Rossen Stoyanchev + */ +public abstract class AbstractWebSocketIntegrationTests { + + private static Map, Class> upgradeStrategyConfigTypes = new HashMap, Class>(); + + static { + upgradeStrategyConfigTypes.put(JettyTestServer.class, JettyUpgradeStrategyConfig.class); + } + + @Parameter(0) + public TestServer server; + + @Parameter(1) + public WebSocketClient webSocketClient; + + + @Before + public void setup() throws Exception { + if (this.webSocketClient instanceof Lifecycle) { + ((Lifecycle) this.webSocketClient).start(); + } + } + + @After + public void teardown() throws Exception { + try { + if (this.webSocketClient instanceof Lifecycle) { + ((Lifecycle) this.webSocketClient).stop(); + } + } + finally { + this.server.stop(); + } + } + + protected String getWsBaseUrl() { + return "ws://localhost:" + this.server.getPort(); + } + + protected Class getUpgradeStrategyConfigClass() { + return upgradeStrategyConfigTypes.get(this.server.getClass()); + } + + + static abstract class AbstractRequestUpgradeStrategyConfig { + + @Bean + public HandshakeHandler handshakeHandler() { + return new DefaultHandshakeHandler(requestUpgradeStrategy()); + } + + public abstract RequestUpgradeStrategy requestUpgradeStrategy(); + } + + + @Configuration + static class JettyUpgradeStrategyConfig extends AbstractRequestUpgradeStrategyConfig { + + @Bean + public RequestUpgradeStrategy requestUpgradeStrategy() { + return new JettyRequestUpgradeStrategy(); + } + } + +} diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/JettyTestServer.java b/spring-websocket/src/test/java/org/springframework/web/socket/JettyTestServer.java new file mode 100644 index 0000000000..66abfb8c7b --- /dev/null +++ b/spring-websocket/src/test/java/org/springframework/web/socket/JettyTestServer.java @@ -0,0 +1,68 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket; + +import org.eclipse.jetty.server.Server; +import org.eclipse.jetty.servlet.ServletContextHandler; +import org.eclipse.jetty.servlet.ServletHolder; +import org.springframework.util.SocketUtils; +import org.springframework.web.context.WebApplicationContext; +import org.springframework.web.servlet.DispatcherServlet; + + +/** + * Jetty based {@link TestServer}. + * + * @author Rossen Stoyanchev + */ +public class JettyTestServer implements TestServer { + + private final Server jettyServer; + + private final int port; + + + public JettyTestServer() { + this.port = SocketUtils.findAvailableTcpPort(); + this.jettyServer = new Server(this.port); + } + + @Override + public int getPort() { + return this.port; + } + + @Override + public void init(WebApplicationContext cxt) { + ServletContextHandler handler = new ServletContextHandler(); + handler.addServlet(new ServletHolder(new DispatcherServlet(cxt)), "/"); + this.jettyServer.setHandler(handler); + } + + @Override + public void start() throws Exception { + this.jettyServer.start(); + } + + @Override + public void stop() throws Exception { + if (this.jettyServer.isRunning()) { + this.jettyServer.stop(); + } + } + +} diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/TestServer.java b/spring-websocket/src/test/java/org/springframework/web/socket/TestServer.java new file mode 100644 index 0000000000..39b23e0552 --- /dev/null +++ b/spring-websocket/src/test/java/org/springframework/web/socket/TestServer.java @@ -0,0 +1,36 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket; + +import org.springframework.web.context.WebApplicationContext; + +/** + * Contract for a test server to use for integration tests. + * + * @author Rossen Stoyanchev + */ +public interface TestServer { + + int getPort(); + + void init(WebApplicationContext cxt); + + void start() throws Exception; + + void stop() throws Exception; + +} \ No newline at end of file diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/client/jetty/JettyWebSocketClientTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/client/jetty/JettyWebSocketClientTests.java index a761fd9c54..f12f440cd3 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/client/jetty/JettyWebSocketClientTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/client/jetty/JettyWebSocketClientTests.java @@ -108,11 +108,9 @@ public class JettyWebSocketClientTests { factory.setCreator(new WebSocketCreator() { @Override public Object createWebSocket(UpgradeRequest req, UpgradeResponse resp) { - if (!CollectionUtils.isEmpty(req.getSubProtocols())) { resp.setAcceptedSubProtocol(req.getSubProtocols().get(0)); } - JettyWebSocketSession session = new JettyWebSocketSession(null, null); return new JettyWebSocketHandlerAdapter(webSocketHandler, session); } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/server/config/WebSocketConfigurationTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/server/config/WebSocketConfigurationTests.java index dae8faec6d..69eb590979 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/server/config/WebSocketConfigurationTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/server/config/WebSocketConfigurationTests.java @@ -18,81 +18,101 @@ package org.springframework.web.socket.server.config; import java.util.Arrays; -import org.junit.Before; import org.junit.Test; -import org.springframework.web.context.support.GenericWebApplicationContext; -import org.springframework.web.servlet.handler.SimpleUrlHandlerMapping; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameters; +import org.mockito.Mockito; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.web.context.support.AnnotationConfigWebApplicationContext; +import org.springframework.web.socket.AbstractWebSocketIntegrationTests; +import org.springframework.web.socket.JettyTestServer; import org.springframework.web.socket.WebSocketHandler; -import org.springframework.web.socket.adapter.TextWebSocketHandlerAdapter; -import org.springframework.web.socket.server.support.WebSocketHttpRequestHandler; -import org.springframework.web.socket.sockjs.SockJsHttpRequestHandler; +import org.springframework.web.socket.WebSocketSession; +import org.springframework.web.socket.client.jetty.JettyWebSocketClient; +import org.springframework.web.socket.server.HandshakeHandler; +import org.springframework.web.socket.sockjs.transport.handler.WebSocketTransportHandler; -import static org.junit.Assert.*; +import static org.mockito.Matchers.*; +import static org.mockito.Mockito.*; /** - * Test fixture for {@link WebSocketConfigurationSupport}. + * Test fixture for WebSocket Java config support. * * @author Rossen Stoyanchev */ -public class WebSocketConfigurationTests { +@RunWith(Parameterized.class) +public class WebSocketConfigurationTests extends AbstractWebSocketIntegrationTests { - private DelegatingWebSocketConfiguration config; - - private GenericWebApplicationContext context; + @Parameters + public static Iterable arguments() { + return Arrays.asList(new Object[][] { + { new JettyTestServer(), new JettyWebSocketClient()} }); + }; - @Before - public void setup() { - this.config = new DelegatingWebSocketConfiguration(); - this.context = new GenericWebApplicationContext(); - this.context.refresh(); + @Test + public void registerWebSocketHandler() throws Exception { + + AnnotationConfigWebApplicationContext cxt = new AnnotationConfigWebApplicationContext(); + cxt.register(TestWebSocketConfigurer.class, getUpgradeStrategyConfigClass()); + + this.server.init(cxt); + this.server.start(); + + WebSocketHandler clientHandler = Mockito.mock(WebSocketHandler.class); + WebSocketHandler serverHandler = cxt.getBean(WebSocketHandler.class); + + this.webSocketClient.doHandshake(clientHandler, getWsBaseUrl() + "/ws"); + + verify(serverHandler).afterConnectionEstablished(any(WebSocketSession.class)); + verify(clientHandler).afterConnectionEstablished(any(WebSocketSession.class)); } @Test - public void webSocket() throws Exception { + public void registerWebSocketHandlerWithSockJS() throws Exception { - final WebSocketHandler handler = new TextWebSocketHandlerAdapter(); + AnnotationConfigWebApplicationContext cxt = new AnnotationConfigWebApplicationContext(); + cxt.register(TestWebSocketConfigurer.class, getUpgradeStrategyConfigClass()); - WebSocketConfigurer configurer = new WebSocketConfigurer() { - @Override - public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) { - registry.addHandler(handler, "/h1"); - } - }; + this.server.init(cxt); + this.server.start(); - this.config.setConfigurers(Arrays.asList(configurer)); - SimpleUrlHandlerMapping hm = (SimpleUrlHandlerMapping) this.config.webSocketHandlerMapping(); - hm.setApplicationContext(this.context); + WebSocketHandler clientHandler = Mockito.mock(WebSocketHandler.class); + WebSocketHandler serverHandler = cxt.getBean(WebSocketHandler.class); - Object actual = hm.getUrlMap().get("/h1"); + this.webSocketClient.doHandshake(clientHandler, getWsBaseUrl() + "/sockjs/websocket"); - assertNotNull(actual); - assertEquals(WebSocketHttpRequestHandler.class, actual.getClass()); - assertEquals(1, hm.getUrlMap().size()); + verify(serverHandler).afterConnectionEstablished(any(WebSocketSession.class)); + verify(clientHandler).afterConnectionEstablished(any(WebSocketSession.class)); } - @Test - public void webSocketWithSockJS() throws Exception { - final WebSocketHandler handler = new TextWebSocketHandlerAdapter(); + @Configuration + @EnableWebSocket + static class TestWebSocketConfigurer implements WebSocketConfigurer { - WebSocketConfigurer configurer = new WebSocketConfigurer() { - @Override - public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) { - registry.addHandler(handler, "/h1").withSockJS(); - } - }; + @Autowired + private HandshakeHandler handshakeHandler; // can't rely on classpath for server detection - this.config.setConfigurers(Arrays.asList(configurer)); - SimpleUrlHandlerMapping hm = (SimpleUrlHandlerMapping) this.config.webSocketHandlerMapping(); - hm.setApplicationContext(this.context); - Object actual = hm.getUrlMap().get("/h1/**"); + @Override + public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) { - assertNotNull(actual); - assertEquals(SockJsHttpRequestHandler.class, actual.getClass()); - assertEquals(1, hm.getUrlMap().size()); + registry.addHandler(serverHandler(), "/ws") + .setHandshakeHandler(this.handshakeHandler); + + registry.addHandler(serverHandler(), "/sockjs").withSockJS() + .setTransportHandlerOverrides(new WebSocketTransportHandler(this.handshakeHandler)); + } + + @Bean + public WebSocketHandler serverHandler() { + return Mockito.mock(WebSocketHandler.class); + } } } diff --git a/spring-websocket/src/test/resources/log4j.xml b/spring-websocket/src/test/resources/log4j.xml index 4c016a6267..e72ac622e9 100644 --- a/spring-websocket/src/test/resources/log4j.xml +++ b/spring-websocket/src/test/resources/log4j.xml @@ -9,6 +9,10 @@ + + + +