Introduce ListenableFuture to WebSocketClient

Issue: SPR-10888
This commit is contained in:
Rossen Stoyanchev 2013-09-06 12:28:21 -04:00
parent 71e76196fe
commit 62921683fd
16 changed files with 234 additions and 149 deletions

View File

@ -83,7 +83,7 @@ public class AnnotationMethodIntegrationTests extends AbstractWebSocketIntegrati
public void simpleController() throws Exception {
TextMessage message = create(StompCommand.SEND).headers("destination:/app/simple").build();
WebSocketSession session = doHandshake(new TestClientWebSocketHandler(0, message), "/ws");
WebSocketSession session = doHandshake(new TestClientWebSocketHandler(0, message), "/ws").get();
SimpleController controller = this.wac.getBean(SimpleController.class);
try {
@ -104,7 +104,7 @@ public class AnnotationMethodIntegrationTests extends AbstractWebSocketIntegrati
"destination:/app/topic/increment").body("5").build();
TestClientWebSocketHandler clientHandler = new TestClientWebSocketHandler(1, message1, message2);
WebSocketSession session = doHandshake(clientHandler, "/ws");
WebSocketSession session = doHandshake(clientHandler, "/ws").get();
try {
assertTrue(clientHandler.latch.await(2, TimeUnit.SECONDS));

View File

@ -28,6 +28,7 @@ import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.http.HttpHeaders;
import org.springframework.util.Assert;
import org.springframework.util.concurrent.ListenableFuture;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.util.UriComponentsBuilder;
@ -60,8 +61,8 @@ public abstract class AbstractWebSocketClient implements WebSocketClient {
@Override
public WebSocketSession doHandshake(WebSocketHandler webSocketHandler, String uriTemplate,
Object... uriVars) throws WebSocketConnectFailureException {
public ListenableFuture<WebSocketSession> doHandshake(WebSocketHandler webSocketHandler,
String uriTemplate, Object... uriVars) {
Assert.notNull(uriTemplate, "uriTemplate must not be null");
URI uri = UriComponentsBuilder.fromUriString(uriTemplate).buildAndExpand(uriVars).encode().toUri();
@ -69,8 +70,8 @@ public abstract class AbstractWebSocketClient implements WebSocketClient {
}
@Override
public final WebSocketSession doHandshake(WebSocketHandler webSocketHandler,
HttpHeaders headers, URI uri) throws WebSocketConnectFailureException {
public final ListenableFuture<WebSocketSession> doHandshake(WebSocketHandler webSocketHandler,
HttpHeaders headers, URI uri) {
Assert.notNull(webSocketHandler, "webSocketHandler must not be null");
Assert.notNull(uri, "uri must not be null");
@ -111,12 +112,9 @@ public abstract class AbstractWebSocketClient implements WebSocketClient {
* @param handshakeAttributes attributes to make available via
* {@link WebSocketSession#getHandshakeAttributes()}; currently always an empty map.
*
* @return the established WebSocket session
*
* @throws WebSocketConnectFailureException
* @return the established WebSocket session wrapped in a ListenableFuture.
*/
protected abstract WebSocketSession doHandshakeInternal(WebSocketHandler webSocketHandler,
HttpHeaders headers, URI uri, List<String> subProtocols,
Map<String, Object> handshakeAttributes) throws WebSocketConnectFailureException;
protected abstract ListenableFuture<WebSocketSession> doHandshakeInternal(WebSocketHandler webSocketHandler,
HttpHeaders headers, URI uri, List<String> subProtocols, Map<String, Object> handshakeAttributes);
}

View File

@ -21,8 +21,6 @@ import java.net.URI;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.context.SmartLifecycle;
import org.springframework.core.task.SimpleAsyncTaskExecutor;
import org.springframework.core.task.TaskExecutor;
import org.springframework.web.util.UriComponentsBuilder;
/**
@ -48,8 +46,6 @@ public abstract class ConnectionManagerSupport implements SmartLifecycle {
private int phase = Integer.MAX_VALUE;
private final TaskExecutor taskExecutor = new SimpleAsyncTaskExecutor("EndpointConnectionManager-");
private final Object lifecycleMonitor = new Object();
@ -126,28 +122,16 @@ public abstract class ConnectionManagerSupport implements SmartLifecycle {
}
protected void startInternal() {
if (logger.isDebugEnabled()) {
logger.debug("Starting " + this.getClass().getSimpleName());
}
this.isRunning = true;
this.taskExecutor.execute(new Runnable() {
@Override
public void run() {
synchronized (lifecycleMonitor) {
try {
logger.info("Connecting to WebSocket at " + uri);
openConnection();
logger.info("Successfully connected");
}
catch (Throwable ex) {
logger.error("Failed to connect", ex);
}
}
synchronized (lifecycleMonitor) {
if (logger.isDebugEnabled()) {
logger.debug("Starting " + this.getClass().getSimpleName());
}
});
this.isRunning = true;
openConnection();
}
}
protected abstract void openConnection() throws Exception;
protected abstract void openConnection();
@Override
public final void stop() {

View File

@ -19,6 +19,7 @@ package org.springframework.web.socket.client;
import java.net.URI;
import org.springframework.http.HttpHeaders;
import org.springframework.util.concurrent.ListenableFuture;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketSession;
@ -34,10 +35,9 @@ import org.springframework.web.socket.WebSocketSession;
*/
public interface WebSocketClient {
WebSocketSession doHandshake(WebSocketHandler webSocketHandler,
String uriTemplate, Object... uriVariables) throws WebSocketConnectFailureException;
ListenableFuture<WebSocketSession> doHandshake(WebSocketHandler webSocketHandler,
String uriTemplate, Object... uriVariables);
WebSocketSession doHandshake(WebSocketHandler webSocketHandler, HttpHeaders headers, URI uri)
throws WebSocketConnectFailureException;
ListenableFuture<WebSocketSession> doHandshake(WebSocketHandler webSocketHandler, HttpHeaders headers, URI uri);
}

View File

@ -1,38 +0,0 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket.client;
import org.springframework.core.NestedRuntimeException;
/**
* Thrown when a WebSocket connection to a server could not be established.
*
* @author Rossen Stoyanchev
* @since 4.0
*/
@SuppressWarnings("serial")
public class WebSocketConnectFailureException extends NestedRuntimeException {
public WebSocketConnectFailureException(String msg, Throwable cause) {
super(msg, cause);
}
public WebSocketConnectFailureException(String msg) {
super(msg);
}
}

View File

@ -20,6 +20,8 @@ import java.util.List;
import org.springframework.context.SmartLifecycle;
import org.springframework.http.HttpHeaders;
import org.springframework.util.concurrent.ListenableFuture;
import org.springframework.util.concurrent.ListenableFutureCallback;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.support.LoggingWebSocketHandlerDecorator;
@ -129,8 +131,24 @@ public class WebSocketConnectionManager extends ConnectionManagerSupport {
}
@Override
protected void openConnection() throws Exception {
this.webSocketSession = this.client.doHandshake(this.webSocketHandler, this.headers, getUri());
protected void openConnection() {
logger.info("Connecting to WebSocket at " + getUri());
ListenableFuture<WebSocketSession> future =
this.client.doHandshake(this.webSocketHandler, this.headers, getUri());
future.addCallback(new ListenableFutureCallback<WebSocketSession>() {
@Override
public void onSuccess(WebSocketSession result) {
webSocketSession = result;
logger.info("Successfully connected");
}
@Override
public void onFailure(Throwable t) {
logger.error("Failed to connect", t);
}
});
}
@Override

View File

@ -24,6 +24,9 @@ import javax.websocket.server.ServerEndpoint;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.BeanFactory;
import org.springframework.beans.factory.BeanFactoryAware;
import org.springframework.core.task.SimpleAsyncTaskExecutor;
import org.springframework.core.task.TaskExecutor;
import org.springframework.util.Assert;
import org.springframework.web.socket.client.ConnectionManagerSupport;
import org.springframework.web.socket.support.BeanCreatingHandlerProvider;
@ -46,6 +49,8 @@ public class AnnotatedEndpointConnectionManager extends ConnectionManagerSupport
private Session session;
private TaskExecutor taskExecutor = new SimpleAsyncTaskExecutor("AnnotatedEndpointConnectionManager-");
public AnnotatedEndpointConnectionManager(Object endpoint, String uriTemplate, Object... uriVariables) {
super(uriTemplate, uriVariables);
@ -75,10 +80,38 @@ public class AnnotatedEndpointConnectionManager extends ConnectionManagerSupport
}
}
/**
* Set a {@link TaskExecutor} to use to open the connection.
* By default {@link SimpleAsyncTaskExecutor} is used.
*/
public void setTaskExecutor(TaskExecutor taskExecutor) {
Assert.notNull(taskExecutor, "taskExecutor is required");
this.taskExecutor = taskExecutor;
}
/**
* Return the configured {@link TaskExecutor}.
*/
public TaskExecutor getTaskExecutor() {
return this.taskExecutor;
}
@Override
protected void openConnection() throws Exception {
Object endpoint = (this.endpoint != null) ? this.endpoint : this.endpointProvider.getHandler();
this.session = this.webSocketContainer.connectToServer(endpoint, getUri());
protected void openConnection() {
this.taskExecutor.execute(new Runnable() {
@Override
public void run() {
try {
logger.info("Connecting to WebSocket at " + getUri());
Object endpointToUse = (endpoint != null) ? endpoint : endpointProvider.getHandler();
session = webSocketContainer.connectToServer(endpointToUse, getUri());
logger.info("Successfully connected");
}
catch (Throwable ex) {
logger.error("Failed to connect", ex);
}
}
});
}
@Override

View File

@ -32,6 +32,8 @@ import javax.websocket.WebSocketContainer;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.BeanFactory;
import org.springframework.beans.factory.BeanFactoryAware;
import org.springframework.core.task.SimpleAsyncTaskExecutor;
import org.springframework.core.task.TaskExecutor;
import org.springframework.util.Assert;
import org.springframework.web.socket.client.ConnectionManagerSupport;
import org.springframework.web.socket.support.BeanCreatingHandlerProvider;
@ -58,6 +60,8 @@ public class EndpointConnectionManager extends ConnectionManagerSupport implemen
private Session session;
private TaskExecutor taskExecutor = new SimpleAsyncTaskExecutor("EndpointConnectionManager-");
public EndpointConnectionManager(Endpoint endpoint, String uriTemplate, Object... uriVariables) {
super(uriTemplate, uriVariables);
@ -109,11 +113,40 @@ public class EndpointConnectionManager extends ConnectionManagerSupport implemen
}
}
/**
* Set a {@link TaskExecutor} to use to open connections.
* By default {@link SimpleAsyncTaskExecutor} is used.
*/
public void setTaskExecutor(TaskExecutor taskExecutor) {
Assert.notNull(taskExecutor, "taskExecutor is required");
this.taskExecutor = taskExecutor;
}
/**
* Return the configured {@link TaskExecutor}.
*/
public TaskExecutor getTaskExecutor() {
return this.taskExecutor;
}
@Override
protected void openConnection() throws Exception {
Endpoint endpoint = (this.endpoint != null) ? this.endpoint : this.endpointProvider.getHandler();
ClientEndpointConfig endpointConfig = this.configBuilder.build();
this.session = getWebSocketContainer().connectToServer(endpoint, endpointConfig, getUri());
protected void openConnection() {
this.taskExecutor.execute(new Runnable() {
@Override
public void run() {
try {
logger.info("Connecting to WebSocket at " + getUri());
Endpoint endpointToUse = (endpoint != null) ? endpoint : endpointProvider.getHandler();
ClientEndpointConfig endpointConfig = configBuilder.build();
session = getWebSocketContainer().connectToServer(endpointToUse, endpointConfig, getUri());
logger.info("Successfully connected");
}
catch (Throwable ex) {
logger.error("Failed to connect", ex);
}
}
});
}
@Override

View File

@ -23,6 +23,7 @@ import java.net.UnknownHostException;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.concurrent.Callable;
import javax.websocket.ClientEndpointConfig;
import javax.websocket.ClientEndpointConfig.Configurator;
@ -31,14 +32,17 @@ import javax.websocket.Endpoint;
import javax.websocket.HandshakeResponse;
import javax.websocket.WebSocketContainer;
import org.springframework.core.task.AsyncListenableTaskExecutor;
import org.springframework.core.task.SimpleAsyncTaskExecutor;
import org.springframework.core.task.TaskExecutor;
import org.springframework.http.HttpHeaders;
import org.springframework.util.Assert;
import org.springframework.util.concurrent.ListenableFuture;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.adapter.StandardWebSocketHandlerAdapter;
import org.springframework.web.socket.adapter.StandardWebSocketSession;
import org.springframework.web.socket.client.AbstractWebSocketClient;
import org.springframework.web.socket.client.WebSocketConnectFailureException;
/**
* Initiates WebSocket requests to a WebSocket server programatically through the standard
@ -51,6 +55,9 @@ public class StandardWebSocketClient extends AbstractWebSocketClient {
private final WebSocketContainer webSocketContainer;
private AsyncListenableTaskExecutor taskExecutor =
new SimpleAsyncTaskExecutor("WebSocketClient-");
/**
* Default constructor that calls {@code ContainerProvider.getWebSocketContainer()} to
@ -71,31 +78,45 @@ public class StandardWebSocketClient extends AbstractWebSocketClient {
this.webSocketContainer = webSocketContainer;
}
/**
* Set a {@link TaskExecutor} to use to open the connection.
* By default {@link SimpleAsyncTaskExecutor} is used.
*/
public void setTaskExecutor(AsyncListenableTaskExecutor taskExecutor) {
Assert.notNull(taskExecutor, "taskExecutor is required");
this.taskExecutor = taskExecutor;
}
/**
* Return the configured {@link TaskExecutor}.
*/
public AsyncListenableTaskExecutor getTaskExecutor() {
return this.taskExecutor;
}
@Override
protected WebSocketSession doHandshakeInternal(WebSocketHandler webSocketHandler,
HttpHeaders headers, URI uri, List<String> protocols,
Map<String, Object> handshakeAttributes) throws WebSocketConnectFailureException {
protected ListenableFuture<WebSocketSession> doHandshakeInternal(WebSocketHandler webSocketHandler,
HttpHeaders headers, final URI uri, List<String> protocols, Map<String, Object> handshakeAttributes) {
int port = getPort(uri);
InetSocketAddress localAddress = new InetSocketAddress(getLocalHost(), port);
InetSocketAddress remoteAddress = new InetSocketAddress(uri.getHost(), port);
StandardWebSocketSession session = new StandardWebSocketSession(headers,
final StandardWebSocketSession session = new StandardWebSocketSession(headers,
handshakeAttributes, localAddress, remoteAddress);
ClientEndpointConfig.Builder configBuidler = ClientEndpointConfig.Builder.create();
final ClientEndpointConfig.Builder configBuidler = ClientEndpointConfig.Builder.create();
configBuidler.configurator(new StandardWebSocketClientConfigurator(headers));
configBuidler.preferredSubprotocols(protocols);
final Endpoint endpoint = new StandardWebSocketHandlerAdapter(webSocketHandler, session);
try {
Endpoint endpoint = new StandardWebSocketHandlerAdapter(webSocketHandler, session);
this.webSocketContainer.connectToServer(endpoint, configBuidler.build(), uri);
return session;
}
catch (Exception e) {
throw new WebSocketConnectFailureException("Failed to connect to " + uri, e);
}
return this.taskExecutor.submitListenable(new Callable<WebSocketSession>() {
@Override
public WebSocketSession call() throws Exception {
webSocketContainer.connectToServer(endpoint, configBuidler.build(), uri);
return session;
}
});
}
private InetAddress getLocalHost() {

View File

@ -20,19 +20,24 @@ import java.net.URI;
import java.security.Principal;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.Future;
import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.client.ClientUpgradeRequest;
import org.eclipse.jetty.websocket.client.WebSocketClient;
import org.springframework.context.SmartLifecycle;
import org.springframework.core.task.AsyncListenableTaskExecutor;
import org.springframework.core.task.SimpleAsyncTaskExecutor;
import org.springframework.core.task.TaskExecutor;
import org.springframework.http.HttpHeaders;
import org.springframework.util.Assert;
import org.springframework.util.concurrent.ListenableFuture;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.adapter.JettyWebSocketHandlerAdapter;
import org.springframework.web.socket.adapter.JettyWebSocketSession;
import org.springframework.web.socket.client.AbstractWebSocketClient;
import org.springframework.web.socket.client.WebSocketConnectFailureException;
import org.springframework.web.util.UriComponents;
import org.springframework.web.util.UriComponentsBuilder;
@ -53,6 +58,8 @@ public class JettyWebSocketClient extends AbstractWebSocketClient implements Sma
private final Object lifecycleMonitor = new Object();
private AsyncListenableTaskExecutor taskExecutor = new SimpleAsyncTaskExecutor("WebSocketClient-");
/**
* Default constructor that creates an instance of
@ -71,6 +78,22 @@ public class JettyWebSocketClient extends AbstractWebSocketClient implements Sma
}
/**
* Set a {@link TaskExecutor} to use to open the connection.
* By default {@link SimpleAsyncTaskExecutor} is used.
*/
public void setTaskExecutor(AsyncListenableTaskExecutor taskExecutor) {
Assert.notNull(taskExecutor, "taskExecutor is required");
this.taskExecutor = taskExecutor;
}
/**
* Return the configured {@link TaskExecutor}.
*/
public AsyncListenableTaskExecutor getTaskExecutor() {
return this.taskExecutor;
}
public void setAutoStartup(boolean autoStartup) {
this.autoStartup = autoStartup;
}
@ -137,39 +160,37 @@ public class JettyWebSocketClient extends AbstractWebSocketClient implements Sma
}
@Override
public WebSocketSession doHandshake(WebSocketHandler webSocketHandler, String uriTemplate, Object... uriVars)
throws WebSocketConnectFailureException {
public ListenableFuture<WebSocketSession> doHandshake(WebSocketHandler webSocketHandler,
String uriTemplate, Object... uriVars) {
UriComponents uriComponents = UriComponentsBuilder.fromUriString(uriTemplate).buildAndExpand(uriVars).encode();
return doHandshake(webSocketHandler, null, uriComponents.toUri());
}
@Override
public WebSocketSession doHandshakeInternal(WebSocketHandler wsHandler, HttpHeaders headers,
URI uri, List<String> protocols, Map<String, Object> handshakeAttributes)
throws WebSocketConnectFailureException {
public ListenableFuture<WebSocketSession> doHandshakeInternal(WebSocketHandler wsHandler,
HttpHeaders headers, final URI uri, List<String> protocols, Map<String, Object> handshakeAttributes) {
ClientUpgradeRequest request = new ClientUpgradeRequest();
final ClientUpgradeRequest request = new ClientUpgradeRequest();
request.setSubProtocols(protocols);
for (String header : headers.keySet()) {
request.setHeader(header, headers.get(header));
}
Principal user = getUser();
JettyWebSocketSession wsSession = new JettyWebSocketSession(user, handshakeAttributes);
JettyWebSocketHandlerAdapter listener = new JettyWebSocketHandlerAdapter(wsHandler, wsSession);
final JettyWebSocketSession wsSession = new JettyWebSocketSession(user, handshakeAttributes);
final JettyWebSocketHandlerAdapter listener = new JettyWebSocketHandlerAdapter(wsHandler, wsSession);
try {
Future<Session> future = this.client.connect(listener, uri, request);
future.get();
return wsSession;
}
catch (Exception e) {
throw new WebSocketConnectFailureException("Failed to connect to " + uri, e);
}
return this.taskExecutor.submitListenable(new Callable<WebSocketSession>() {
@Override
public WebSocketSession call() throws Exception {
Future<Session> future = client.connect(listener, uri, request);
future.get();
return wsSession;
}
});
}
/**
* @return the user to make available through {@link WebSocketSession#getPrincipal()};
* by default this method returns {@code null}

View File

@ -26,6 +26,7 @@ import org.junit.runners.Parameterized.Parameter;
import org.springframework.context.Lifecycle;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.util.concurrent.ListenableFuture;
import org.springframework.web.context.support.AnnotationConfigWebApplicationContext;
import org.springframework.web.socket.client.WebSocketClient;
import org.springframework.web.socket.server.DefaultHandshakeHandler;
@ -108,7 +109,7 @@ public abstract class AbstractWebSocketIntegrationTests {
return "ws://localhost:" + this.server.getPort();
}
protected WebSocketSession doHandshake(WebSocketHandler clientHandler, String endpointPath) {
protected ListenableFuture<WebSocketSession> doHandshake(WebSocketHandler clientHandler, String endpointPath) {
return this.webSocketClient.doHandshake(clientHandler, getWsBaseUrl() + endpointPath);
}

View File

@ -25,7 +25,6 @@ import org.apache.catalina.startup.Tomcat;
import org.apache.coyote.http11.Http11NioProtocol;
import org.apache.tomcat.util.descriptor.web.ApplicationListener;
import org.apache.tomcat.websocket.server.WsListener;
import org.springframework.core.NestedRuntimeException;
import org.springframework.util.SocketUtils;
import org.springframework.web.context.WebApplicationContext;
import org.springframework.web.servlet.DispatcherServlet;
@ -74,7 +73,7 @@ public class TomcatWebSocketTestServer implements WebSocketTestServer {
return tempFolder;
}
catch (IOException ex) {
throw new NestedRuntimeException("Unable to create temp directory", ex) {};
throw new RuntimeException("Unable to create temp directory", ex);
}
}

View File

@ -19,19 +19,21 @@ package org.springframework.web.socket.client;
import java.net.URI;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.Callable;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.springframework.context.SmartLifecycle;
import org.springframework.http.HttpHeaders;
import org.springframework.util.concurrent.ListenableFuture;
import org.springframework.util.concurrent.ListenableFutureTask;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.adapter.WebSocketHandlerAdapter;
import org.springframework.web.socket.support.LoggingWebSocketHandlerDecorator;
import org.springframework.web.socket.support.WebSocketHandlerDecorator;
import org.springframework.web.util.UriComponentsBuilder;
import static org.junit.Assert.*;
import static org.mockito.Mockito.*;
/**
* Test fixture for {@link WebSocketConnectionManager}.
@ -45,26 +47,20 @@ public class WebSocketConnectionManagerTests {
List<String> subprotocols = Arrays.asList("abc");
WebSocketClient client = mock(WebSocketClient.class);
TestLifecycleWebSocketClient client = new TestLifecycleWebSocketClient(false);
WebSocketHandler handler = new WebSocketHandlerAdapter();
WebSocketConnectionManager manager = new WebSocketConnectionManager(client, handler , "/path/{id}", "123");
manager.setSubProtocols(subprotocols);
manager.openConnection();
ArgumentCaptor<WebSocketHandlerDecorator> captor = ArgumentCaptor.forClass(WebSocketHandlerDecorator.class);
ArgumentCaptor<HttpHeaders> headersCaptor = ArgumentCaptor.forClass(HttpHeaders.class);
ArgumentCaptor<URI> uriCaptor = ArgumentCaptor.forClass(URI.class);
verify(client).doHandshake(captor.capture(), headersCaptor.capture(), uriCaptor.capture());
HttpHeaders expectedHeaders = new HttpHeaders();
expectedHeaders.setSecWebSocketProtocol(subprotocols);
assertEquals(expectedHeaders, headersCaptor.getValue());
assertEquals(new URI("/path/123"), uriCaptor.getValue());
assertEquals(expectedHeaders, client.headers);
assertEquals(new URI("/path/123"), client.uri);
WebSocketHandlerDecorator loggingHandler = captor.getValue();
WebSocketHandlerDecorator loggingHandler = (WebSocketHandlerDecorator) client.webSocketHandler;
assertEquals(LoggingWebSocketHandlerDecorator.class, loggingHandler.getClass());
assertSame(handler, loggingHandler.getDelegate());
@ -103,6 +99,13 @@ public class WebSocketConnectionManagerTests {
private boolean running;
private WebSocketHandler webSocketHandler;
private HttpHeaders headers;
private URI uri;
public TestLifecycleWebSocketClient(boolean running) {
this.running = running;
}
@ -138,15 +141,27 @@ public class WebSocketConnectionManagerTests {
}
@Override
public WebSocketSession doHandshake(WebSocketHandler webSocketHandler, String uriTemplate, Object... uriVariables)
throws WebSocketConnectFailureException {
return null;
public ListenableFuture<WebSocketSession> doHandshake(WebSocketHandler webSocketHandler,
String uriTemplate, Object... uriVars) {
URI uri = UriComponentsBuilder.fromUriString(uriTemplate).buildAndExpand(uriVars).encode().toUri();
return doHandshake(webSocketHandler, null, uri);
}
@Override
public WebSocketSession doHandshake(WebSocketHandler webSocketHandler, HttpHeaders headers, URI uri)
throws WebSocketConnectFailureException {
return null;
public ListenableFuture<WebSocketSession> doHandshake(WebSocketHandler webSocketHandler,
HttpHeaders headers, URI uri) {
this.webSocketHandler = webSocketHandler;
this.headers = headers;
this.uri = uri;
return new ListenableFutureTask<WebSocketSession>(new Callable<WebSocketSession>() {
@Override
public WebSocketSession call() throws Exception {
return null;
}
});
}
}

View File

@ -66,7 +66,7 @@ public class StandardWebSocketClientTests {
@Test
public void localAddress() throws Exception {
URI uri = new URI("ws://example.com/abc");
WebSocketSession session = this.wsClient.doHandshake(this.wsHandler, this.headers, uri);
WebSocketSession session = this.wsClient.doHandshake(this.wsHandler, this.headers, uri).get();
assertNotNull(session.getLocalAddress());
assertEquals(80, session.getLocalAddress().getPort());
@ -75,7 +75,7 @@ public class StandardWebSocketClientTests {
@Test
public void localAddressWss() throws Exception {
URI uri = new URI("wss://example.com/abc");
WebSocketSession session = this.wsClient.doHandshake(this.wsHandler, this.headers, uri);
WebSocketSession session = this.wsClient.doHandshake(this.wsHandler, this.headers, uri).get();
assertNotNull(session.getLocalAddress());
assertEquals(443, session.getLocalAddress().getPort());
@ -90,7 +90,7 @@ public class StandardWebSocketClientTests {
@Test
public void remoteAddress() throws Exception {
URI uri = new URI("wss://example.com/abc");
WebSocketSession session = this.wsClient.doHandshake(this.wsHandler, this.headers, uri);
WebSocketSession session = this.wsClient.doHandshake(this.wsHandler, this.headers, uri).get();
assertNotNull(session.getRemoteAddress());
assertEquals("example.com", session.getRemoteAddress().getHostName());
@ -105,7 +105,7 @@ public class StandardWebSocketClientTests {
this.headers.setSecWebSocketProtocol(protocols);
this.headers.add("foo", "bar");
WebSocketSession session = this.wsClient.doHandshake(this.wsHandler, this.headers, uri);
WebSocketSession session = this.wsClient.doHandshake(this.wsHandler, this.headers, uri).get();
assertEquals(Collections.singletonMap("foo", Arrays.asList("bar")), session.getHandshakeHeaders());
}
@ -118,7 +118,7 @@ public class StandardWebSocketClientTests {
this.headers.setSecWebSocketProtocol(protocols);
this.headers.add("foo", "bar");
this.wsClient.doHandshake(this.wsHandler, this.headers, uri);
this.wsClient.doHandshake(this.wsHandler, this.headers, uri).get();
ArgumentCaptor<Endpoint> arg1 = ArgumentCaptor.forClass(Endpoint.class);
ArgumentCaptor<ClientEndpointConfig> arg2 = ArgumentCaptor.forClass(ClientEndpointConfig.class);

View File

@ -83,7 +83,7 @@ public class JettyWebSocketClientTests {
HttpHeaders headers = new HttpHeaders();
headers.setSecWebSocketProtocol(Arrays.asList("echo"));
this.wsSession = this.client.doHandshake(new TextWebSocketHandlerAdapter(), headers, new URI(this.wsUrl));
this.wsSession = this.client.doHandshake(new TextWebSocketHandlerAdapter(), headers, new URI(this.wsUrl)).get();
assertEquals(this.wsUrl, this.wsSession.getUri().toString());
assertEquals("echo", this.wsSession.getAcceptedProtocol());

View File

@ -65,8 +65,8 @@ public class WebSocketConfigurationTests extends AbstractWebSocketIntegrationTes
@Test
public void registerWebSocketHandler() throws Exception {
WebSocketSession session =
this.webSocketClient.doHandshake(new WebSocketHandlerAdapter(), getWsBaseUrl() + "/ws");
WebSocketSession session = this.webSocketClient.doHandshake(
new WebSocketHandlerAdapter(), getWsBaseUrl() + "/ws").get();
TestWebSocketHandler serverHandler = this.wac.getBean(TestWebSocketHandler.class);
assertTrue(serverHandler.latch.await(2, TimeUnit.SECONDS));
@ -77,8 +77,8 @@ public class WebSocketConfigurationTests extends AbstractWebSocketIntegrationTes
@Test
public void registerWebSocketHandlerWithSockJS() throws Exception {
WebSocketSession session =
this.webSocketClient.doHandshake(new WebSocketHandlerAdapter(), getWsBaseUrl() + "/sockjs/websocket");
WebSocketSession session = this.webSocketClient.doHandshake(
new WebSocketHandlerAdapter(), getWsBaseUrl() + "/sockjs/websocket").get();
TestWebSocketHandler serverHandler = this.wac.getBean(TestWebSocketHandler.class);
assertTrue(serverHandler.latch.await(2, TimeUnit.SECONDS));