Introduce ListenableFuture to WebSocketClient
Issue: SPR-10888
This commit is contained in:
parent
71e76196fe
commit
62921683fd
|
|
@ -83,7 +83,7 @@ public class AnnotationMethodIntegrationTests extends AbstractWebSocketIntegrati
|
|||
public void simpleController() throws Exception {
|
||||
|
||||
TextMessage message = create(StompCommand.SEND).headers("destination:/app/simple").build();
|
||||
WebSocketSession session = doHandshake(new TestClientWebSocketHandler(0, message), "/ws");
|
||||
WebSocketSession session = doHandshake(new TestClientWebSocketHandler(0, message), "/ws").get();
|
||||
|
||||
SimpleController controller = this.wac.getBean(SimpleController.class);
|
||||
try {
|
||||
|
|
@ -104,7 +104,7 @@ public class AnnotationMethodIntegrationTests extends AbstractWebSocketIntegrati
|
|||
"destination:/app/topic/increment").body("5").build();
|
||||
|
||||
TestClientWebSocketHandler clientHandler = new TestClientWebSocketHandler(1, message1, message2);
|
||||
WebSocketSession session = doHandshake(clientHandler, "/ws");
|
||||
WebSocketSession session = doHandshake(clientHandler, "/ws").get();
|
||||
|
||||
try {
|
||||
assertTrue(clientHandler.latch.await(2, TimeUnit.SECONDS));
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@ import org.apache.commons.logging.Log;
|
|||
import org.apache.commons.logging.LogFactory;
|
||||
import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.util.Assert;
|
||||
import org.springframework.util.concurrent.ListenableFuture;
|
||||
import org.springframework.web.socket.WebSocketHandler;
|
||||
import org.springframework.web.socket.WebSocketSession;
|
||||
import org.springframework.web.util.UriComponentsBuilder;
|
||||
|
|
@ -60,8 +61,8 @@ public abstract class AbstractWebSocketClient implements WebSocketClient {
|
|||
|
||||
|
||||
@Override
|
||||
public WebSocketSession doHandshake(WebSocketHandler webSocketHandler, String uriTemplate,
|
||||
Object... uriVars) throws WebSocketConnectFailureException {
|
||||
public ListenableFuture<WebSocketSession> doHandshake(WebSocketHandler webSocketHandler,
|
||||
String uriTemplate, Object... uriVars) {
|
||||
|
||||
Assert.notNull(uriTemplate, "uriTemplate must not be null");
|
||||
URI uri = UriComponentsBuilder.fromUriString(uriTemplate).buildAndExpand(uriVars).encode().toUri();
|
||||
|
|
@ -69,8 +70,8 @@ public abstract class AbstractWebSocketClient implements WebSocketClient {
|
|||
}
|
||||
|
||||
@Override
|
||||
public final WebSocketSession doHandshake(WebSocketHandler webSocketHandler,
|
||||
HttpHeaders headers, URI uri) throws WebSocketConnectFailureException {
|
||||
public final ListenableFuture<WebSocketSession> doHandshake(WebSocketHandler webSocketHandler,
|
||||
HttpHeaders headers, URI uri) {
|
||||
|
||||
Assert.notNull(webSocketHandler, "webSocketHandler must not be null");
|
||||
Assert.notNull(uri, "uri must not be null");
|
||||
|
|
@ -111,12 +112,9 @@ public abstract class AbstractWebSocketClient implements WebSocketClient {
|
|||
* @param handshakeAttributes attributes to make available via
|
||||
* {@link WebSocketSession#getHandshakeAttributes()}; currently always an empty map.
|
||||
*
|
||||
* @return the established WebSocket session
|
||||
*
|
||||
* @throws WebSocketConnectFailureException
|
||||
* @return the established WebSocket session wrapped in a ListenableFuture.
|
||||
*/
|
||||
protected abstract WebSocketSession doHandshakeInternal(WebSocketHandler webSocketHandler,
|
||||
HttpHeaders headers, URI uri, List<String> subProtocols,
|
||||
Map<String, Object> handshakeAttributes) throws WebSocketConnectFailureException;
|
||||
protected abstract ListenableFuture<WebSocketSession> doHandshakeInternal(WebSocketHandler webSocketHandler,
|
||||
HttpHeaders headers, URI uri, List<String> subProtocols, Map<String, Object> handshakeAttributes);
|
||||
|
||||
}
|
||||
|
|
|
|||
|
|
@ -21,8 +21,6 @@ import java.net.URI;
|
|||
import org.apache.commons.logging.Log;
|
||||
import org.apache.commons.logging.LogFactory;
|
||||
import org.springframework.context.SmartLifecycle;
|
||||
import org.springframework.core.task.SimpleAsyncTaskExecutor;
|
||||
import org.springframework.core.task.TaskExecutor;
|
||||
import org.springframework.web.util.UriComponentsBuilder;
|
||||
|
||||
/**
|
||||
|
|
@ -48,8 +46,6 @@ public abstract class ConnectionManagerSupport implements SmartLifecycle {
|
|||
|
||||
private int phase = Integer.MAX_VALUE;
|
||||
|
||||
private final TaskExecutor taskExecutor = new SimpleAsyncTaskExecutor("EndpointConnectionManager-");
|
||||
|
||||
private final Object lifecycleMonitor = new Object();
|
||||
|
||||
|
||||
|
|
@ -126,28 +122,16 @@ public abstract class ConnectionManagerSupport implements SmartLifecycle {
|
|||
}
|
||||
|
||||
protected void startInternal() {
|
||||
if (logger.isDebugEnabled()) {
|
||||
logger.debug("Starting " + this.getClass().getSimpleName());
|
||||
}
|
||||
this.isRunning = true;
|
||||
this.taskExecutor.execute(new Runnable() {
|
||||
@Override
|
||||
public void run() {
|
||||
synchronized (lifecycleMonitor) {
|
||||
try {
|
||||
logger.info("Connecting to WebSocket at " + uri);
|
||||
openConnection();
|
||||
logger.info("Successfully connected");
|
||||
}
|
||||
catch (Throwable ex) {
|
||||
logger.error("Failed to connect", ex);
|
||||
}
|
||||
}
|
||||
synchronized (lifecycleMonitor) {
|
||||
if (logger.isDebugEnabled()) {
|
||||
logger.debug("Starting " + this.getClass().getSimpleName());
|
||||
}
|
||||
});
|
||||
this.isRunning = true;
|
||||
openConnection();
|
||||
}
|
||||
}
|
||||
|
||||
protected abstract void openConnection() throws Exception;
|
||||
protected abstract void openConnection();
|
||||
|
||||
@Override
|
||||
public final void stop() {
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ package org.springframework.web.socket.client;
|
|||
import java.net.URI;
|
||||
|
||||
import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.util.concurrent.ListenableFuture;
|
||||
import org.springframework.web.socket.WebSocketHandler;
|
||||
import org.springframework.web.socket.WebSocketSession;
|
||||
|
||||
|
|
@ -34,10 +35,9 @@ import org.springframework.web.socket.WebSocketSession;
|
|||
*/
|
||||
public interface WebSocketClient {
|
||||
|
||||
WebSocketSession doHandshake(WebSocketHandler webSocketHandler,
|
||||
String uriTemplate, Object... uriVariables) throws WebSocketConnectFailureException;
|
||||
ListenableFuture<WebSocketSession> doHandshake(WebSocketHandler webSocketHandler,
|
||||
String uriTemplate, Object... uriVariables);
|
||||
|
||||
WebSocketSession doHandshake(WebSocketHandler webSocketHandler, HttpHeaders headers, URI uri)
|
||||
throws WebSocketConnectFailureException;
|
||||
ListenableFuture<WebSocketSession> doHandshake(WebSocketHandler webSocketHandler, HttpHeaders headers, URI uri);
|
||||
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,38 +0,0 @@
|
|||
/*
|
||||
* Copyright 2002-2013 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.web.socket.client;
|
||||
|
||||
import org.springframework.core.NestedRuntimeException;
|
||||
|
||||
/**
|
||||
* Thrown when a WebSocket connection to a server could not be established.
|
||||
*
|
||||
* @author Rossen Stoyanchev
|
||||
* @since 4.0
|
||||
*/
|
||||
@SuppressWarnings("serial")
|
||||
public class WebSocketConnectFailureException extends NestedRuntimeException {
|
||||
|
||||
public WebSocketConnectFailureException(String msg, Throwable cause) {
|
||||
super(msg, cause);
|
||||
}
|
||||
|
||||
public WebSocketConnectFailureException(String msg) {
|
||||
super(msg);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
@ -20,6 +20,8 @@ import java.util.List;
|
|||
|
||||
import org.springframework.context.SmartLifecycle;
|
||||
import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.util.concurrent.ListenableFuture;
|
||||
import org.springframework.util.concurrent.ListenableFutureCallback;
|
||||
import org.springframework.web.socket.WebSocketHandler;
|
||||
import org.springframework.web.socket.WebSocketSession;
|
||||
import org.springframework.web.socket.support.LoggingWebSocketHandlerDecorator;
|
||||
|
|
@ -129,8 +131,24 @@ public class WebSocketConnectionManager extends ConnectionManagerSupport {
|
|||
}
|
||||
|
||||
@Override
|
||||
protected void openConnection() throws Exception {
|
||||
this.webSocketSession = this.client.doHandshake(this.webSocketHandler, this.headers, getUri());
|
||||
protected void openConnection() {
|
||||
|
||||
logger.info("Connecting to WebSocket at " + getUri());
|
||||
|
||||
ListenableFuture<WebSocketSession> future =
|
||||
this.client.doHandshake(this.webSocketHandler, this.headers, getUri());
|
||||
|
||||
future.addCallback(new ListenableFutureCallback<WebSocketSession>() {
|
||||
@Override
|
||||
public void onSuccess(WebSocketSession result) {
|
||||
webSocketSession = result;
|
||||
logger.info("Successfully connected");
|
||||
}
|
||||
@Override
|
||||
public void onFailure(Throwable t) {
|
||||
logger.error("Failed to connect", t);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
|||
|
|
@ -24,6 +24,9 @@ import javax.websocket.server.ServerEndpoint;
|
|||
import org.springframework.beans.BeansException;
|
||||
import org.springframework.beans.factory.BeanFactory;
|
||||
import org.springframework.beans.factory.BeanFactoryAware;
|
||||
import org.springframework.core.task.SimpleAsyncTaskExecutor;
|
||||
import org.springframework.core.task.TaskExecutor;
|
||||
import org.springframework.util.Assert;
|
||||
import org.springframework.web.socket.client.ConnectionManagerSupport;
|
||||
import org.springframework.web.socket.support.BeanCreatingHandlerProvider;
|
||||
|
||||
|
|
@ -46,6 +49,8 @@ public class AnnotatedEndpointConnectionManager extends ConnectionManagerSupport
|
|||
|
||||
private Session session;
|
||||
|
||||
private TaskExecutor taskExecutor = new SimpleAsyncTaskExecutor("AnnotatedEndpointConnectionManager-");
|
||||
|
||||
|
||||
public AnnotatedEndpointConnectionManager(Object endpoint, String uriTemplate, Object... uriVariables) {
|
||||
super(uriTemplate, uriVariables);
|
||||
|
|
@ -75,10 +80,38 @@ public class AnnotatedEndpointConnectionManager extends ConnectionManagerSupport
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Set a {@link TaskExecutor} to use to open the connection.
|
||||
* By default {@link SimpleAsyncTaskExecutor} is used.
|
||||
*/
|
||||
public void setTaskExecutor(TaskExecutor taskExecutor) {
|
||||
Assert.notNull(taskExecutor, "taskExecutor is required");
|
||||
this.taskExecutor = taskExecutor;
|
||||
}
|
||||
|
||||
/**
|
||||
* Return the configured {@link TaskExecutor}.
|
||||
*/
|
||||
public TaskExecutor getTaskExecutor() {
|
||||
return this.taskExecutor;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void openConnection() throws Exception {
|
||||
Object endpoint = (this.endpoint != null) ? this.endpoint : this.endpointProvider.getHandler();
|
||||
this.session = this.webSocketContainer.connectToServer(endpoint, getUri());
|
||||
protected void openConnection() {
|
||||
this.taskExecutor.execute(new Runnable() {
|
||||
@Override
|
||||
public void run() {
|
||||
try {
|
||||
logger.info("Connecting to WebSocket at " + getUri());
|
||||
Object endpointToUse = (endpoint != null) ? endpoint : endpointProvider.getHandler();
|
||||
session = webSocketContainer.connectToServer(endpointToUse, getUri());
|
||||
logger.info("Successfully connected");
|
||||
}
|
||||
catch (Throwable ex) {
|
||||
logger.error("Failed to connect", ex);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
|||
|
|
@ -32,6 +32,8 @@ import javax.websocket.WebSocketContainer;
|
|||
import org.springframework.beans.BeansException;
|
||||
import org.springframework.beans.factory.BeanFactory;
|
||||
import org.springframework.beans.factory.BeanFactoryAware;
|
||||
import org.springframework.core.task.SimpleAsyncTaskExecutor;
|
||||
import org.springframework.core.task.TaskExecutor;
|
||||
import org.springframework.util.Assert;
|
||||
import org.springframework.web.socket.client.ConnectionManagerSupport;
|
||||
import org.springframework.web.socket.support.BeanCreatingHandlerProvider;
|
||||
|
|
@ -58,6 +60,8 @@ public class EndpointConnectionManager extends ConnectionManagerSupport implemen
|
|||
|
||||
private Session session;
|
||||
|
||||
private TaskExecutor taskExecutor = new SimpleAsyncTaskExecutor("EndpointConnectionManager-");
|
||||
|
||||
|
||||
public EndpointConnectionManager(Endpoint endpoint, String uriTemplate, Object... uriVariables) {
|
||||
super(uriTemplate, uriVariables);
|
||||
|
|
@ -109,11 +113,40 @@ public class EndpointConnectionManager extends ConnectionManagerSupport implemen
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Set a {@link TaskExecutor} to use to open connections.
|
||||
* By default {@link SimpleAsyncTaskExecutor} is used.
|
||||
*/
|
||||
public void setTaskExecutor(TaskExecutor taskExecutor) {
|
||||
Assert.notNull(taskExecutor, "taskExecutor is required");
|
||||
this.taskExecutor = taskExecutor;
|
||||
}
|
||||
|
||||
/**
|
||||
* Return the configured {@link TaskExecutor}.
|
||||
*/
|
||||
public TaskExecutor getTaskExecutor() {
|
||||
return this.taskExecutor;
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
protected void openConnection() throws Exception {
|
||||
Endpoint endpoint = (this.endpoint != null) ? this.endpoint : this.endpointProvider.getHandler();
|
||||
ClientEndpointConfig endpointConfig = this.configBuilder.build();
|
||||
this.session = getWebSocketContainer().connectToServer(endpoint, endpointConfig, getUri());
|
||||
protected void openConnection() {
|
||||
this.taskExecutor.execute(new Runnable() {
|
||||
@Override
|
||||
public void run() {
|
||||
try {
|
||||
logger.info("Connecting to WebSocket at " + getUri());
|
||||
Endpoint endpointToUse = (endpoint != null) ? endpoint : endpointProvider.getHandler();
|
||||
ClientEndpointConfig endpointConfig = configBuilder.build();
|
||||
session = getWebSocketContainer().connectToServer(endpointToUse, endpointConfig, getUri());
|
||||
logger.info("Successfully connected");
|
||||
}
|
||||
catch (Throwable ex) {
|
||||
logger.error("Failed to connect", ex);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ import java.net.UnknownHostException;
|
|||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.Callable;
|
||||
|
||||
import javax.websocket.ClientEndpointConfig;
|
||||
import javax.websocket.ClientEndpointConfig.Configurator;
|
||||
|
|
@ -31,14 +32,17 @@ import javax.websocket.Endpoint;
|
|||
import javax.websocket.HandshakeResponse;
|
||||
import javax.websocket.WebSocketContainer;
|
||||
|
||||
import org.springframework.core.task.AsyncListenableTaskExecutor;
|
||||
import org.springframework.core.task.SimpleAsyncTaskExecutor;
|
||||
import org.springframework.core.task.TaskExecutor;
|
||||
import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.util.Assert;
|
||||
import org.springframework.util.concurrent.ListenableFuture;
|
||||
import org.springframework.web.socket.WebSocketHandler;
|
||||
import org.springframework.web.socket.WebSocketSession;
|
||||
import org.springframework.web.socket.adapter.StandardWebSocketHandlerAdapter;
|
||||
import org.springframework.web.socket.adapter.StandardWebSocketSession;
|
||||
import org.springframework.web.socket.client.AbstractWebSocketClient;
|
||||
import org.springframework.web.socket.client.WebSocketConnectFailureException;
|
||||
|
||||
/**
|
||||
* Initiates WebSocket requests to a WebSocket server programatically through the standard
|
||||
|
|
@ -51,6 +55,9 @@ public class StandardWebSocketClient extends AbstractWebSocketClient {
|
|||
|
||||
private final WebSocketContainer webSocketContainer;
|
||||
|
||||
private AsyncListenableTaskExecutor taskExecutor =
|
||||
new SimpleAsyncTaskExecutor("WebSocketClient-");
|
||||
|
||||
|
||||
/**
|
||||
* Default constructor that calls {@code ContainerProvider.getWebSocketContainer()} to
|
||||
|
|
@ -71,31 +78,45 @@ public class StandardWebSocketClient extends AbstractWebSocketClient {
|
|||
this.webSocketContainer = webSocketContainer;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set a {@link TaskExecutor} to use to open the connection.
|
||||
* By default {@link SimpleAsyncTaskExecutor} is used.
|
||||
*/
|
||||
public void setTaskExecutor(AsyncListenableTaskExecutor taskExecutor) {
|
||||
Assert.notNull(taskExecutor, "taskExecutor is required");
|
||||
this.taskExecutor = taskExecutor;
|
||||
}
|
||||
|
||||
/**
|
||||
* Return the configured {@link TaskExecutor}.
|
||||
*/
|
||||
public AsyncListenableTaskExecutor getTaskExecutor() {
|
||||
return this.taskExecutor;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected WebSocketSession doHandshakeInternal(WebSocketHandler webSocketHandler,
|
||||
HttpHeaders headers, URI uri, List<String> protocols,
|
||||
Map<String, Object> handshakeAttributes) throws WebSocketConnectFailureException {
|
||||
protected ListenableFuture<WebSocketSession> doHandshakeInternal(WebSocketHandler webSocketHandler,
|
||||
HttpHeaders headers, final URI uri, List<String> protocols, Map<String, Object> handshakeAttributes) {
|
||||
|
||||
int port = getPort(uri);
|
||||
InetSocketAddress localAddress = new InetSocketAddress(getLocalHost(), port);
|
||||
InetSocketAddress remoteAddress = new InetSocketAddress(uri.getHost(), port);
|
||||
|
||||
StandardWebSocketSession session = new StandardWebSocketSession(headers,
|
||||
final StandardWebSocketSession session = new StandardWebSocketSession(headers,
|
||||
handshakeAttributes, localAddress, remoteAddress);
|
||||
|
||||
ClientEndpointConfig.Builder configBuidler = ClientEndpointConfig.Builder.create();
|
||||
final ClientEndpointConfig.Builder configBuidler = ClientEndpointConfig.Builder.create();
|
||||
configBuidler.configurator(new StandardWebSocketClientConfigurator(headers));
|
||||
configBuidler.preferredSubprotocols(protocols);
|
||||
final Endpoint endpoint = new StandardWebSocketHandlerAdapter(webSocketHandler, session);
|
||||
|
||||
try {
|
||||
Endpoint endpoint = new StandardWebSocketHandlerAdapter(webSocketHandler, session);
|
||||
this.webSocketContainer.connectToServer(endpoint, configBuidler.build(), uri);
|
||||
return session;
|
||||
}
|
||||
catch (Exception e) {
|
||||
throw new WebSocketConnectFailureException("Failed to connect to " + uri, e);
|
||||
}
|
||||
return this.taskExecutor.submitListenable(new Callable<WebSocketSession>() {
|
||||
@Override
|
||||
public WebSocketSession call() throws Exception {
|
||||
webSocketContainer.connectToServer(endpoint, configBuidler.build(), uri);
|
||||
return session;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
private InetAddress getLocalHost() {
|
||||
|
|
|
|||
|
|
@ -20,19 +20,24 @@ import java.net.URI;
|
|||
import java.security.Principal;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.Callable;
|
||||
import java.util.concurrent.Future;
|
||||
|
||||
import org.eclipse.jetty.websocket.api.Session;
|
||||
import org.eclipse.jetty.websocket.client.ClientUpgradeRequest;
|
||||
import org.eclipse.jetty.websocket.client.WebSocketClient;
|
||||
import org.springframework.context.SmartLifecycle;
|
||||
import org.springframework.core.task.AsyncListenableTaskExecutor;
|
||||
import org.springframework.core.task.SimpleAsyncTaskExecutor;
|
||||
import org.springframework.core.task.TaskExecutor;
|
||||
import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.util.Assert;
|
||||
import org.springframework.util.concurrent.ListenableFuture;
|
||||
import org.springframework.web.socket.WebSocketHandler;
|
||||
import org.springframework.web.socket.WebSocketSession;
|
||||
import org.springframework.web.socket.adapter.JettyWebSocketHandlerAdapter;
|
||||
import org.springframework.web.socket.adapter.JettyWebSocketSession;
|
||||
import org.springframework.web.socket.client.AbstractWebSocketClient;
|
||||
import org.springframework.web.socket.client.WebSocketConnectFailureException;
|
||||
import org.springframework.web.util.UriComponents;
|
||||
import org.springframework.web.util.UriComponentsBuilder;
|
||||
|
||||
|
|
@ -53,6 +58,8 @@ public class JettyWebSocketClient extends AbstractWebSocketClient implements Sma
|
|||
|
||||
private final Object lifecycleMonitor = new Object();
|
||||
|
||||
private AsyncListenableTaskExecutor taskExecutor = new SimpleAsyncTaskExecutor("WebSocketClient-");
|
||||
|
||||
|
||||
/**
|
||||
* Default constructor that creates an instance of
|
||||
|
|
@ -71,6 +78,22 @@ public class JettyWebSocketClient extends AbstractWebSocketClient implements Sma
|
|||
}
|
||||
|
||||
|
||||
/**
|
||||
* Set a {@link TaskExecutor} to use to open the connection.
|
||||
* By default {@link SimpleAsyncTaskExecutor} is used.
|
||||
*/
|
||||
public void setTaskExecutor(AsyncListenableTaskExecutor taskExecutor) {
|
||||
Assert.notNull(taskExecutor, "taskExecutor is required");
|
||||
this.taskExecutor = taskExecutor;
|
||||
}
|
||||
|
||||
/**
|
||||
* Return the configured {@link TaskExecutor}.
|
||||
*/
|
||||
public AsyncListenableTaskExecutor getTaskExecutor() {
|
||||
return this.taskExecutor;
|
||||
}
|
||||
|
||||
public void setAutoStartup(boolean autoStartup) {
|
||||
this.autoStartup = autoStartup;
|
||||
}
|
||||
|
|
@ -137,39 +160,37 @@ public class JettyWebSocketClient extends AbstractWebSocketClient implements Sma
|
|||
}
|
||||
|
||||
@Override
|
||||
public WebSocketSession doHandshake(WebSocketHandler webSocketHandler, String uriTemplate, Object... uriVars)
|
||||
throws WebSocketConnectFailureException {
|
||||
public ListenableFuture<WebSocketSession> doHandshake(WebSocketHandler webSocketHandler,
|
||||
String uriTemplate, Object... uriVars) {
|
||||
|
||||
UriComponents uriComponents = UriComponentsBuilder.fromUriString(uriTemplate).buildAndExpand(uriVars).encode();
|
||||
return doHandshake(webSocketHandler, null, uriComponents.toUri());
|
||||
}
|
||||
|
||||
@Override
|
||||
public WebSocketSession doHandshakeInternal(WebSocketHandler wsHandler, HttpHeaders headers,
|
||||
URI uri, List<String> protocols, Map<String, Object> handshakeAttributes)
|
||||
throws WebSocketConnectFailureException {
|
||||
public ListenableFuture<WebSocketSession> doHandshakeInternal(WebSocketHandler wsHandler,
|
||||
HttpHeaders headers, final URI uri, List<String> protocols, Map<String, Object> handshakeAttributes) {
|
||||
|
||||
ClientUpgradeRequest request = new ClientUpgradeRequest();
|
||||
final ClientUpgradeRequest request = new ClientUpgradeRequest();
|
||||
request.setSubProtocols(protocols);
|
||||
for (String header : headers.keySet()) {
|
||||
request.setHeader(header, headers.get(header));
|
||||
}
|
||||
|
||||
Principal user = getUser();
|
||||
JettyWebSocketSession wsSession = new JettyWebSocketSession(user, handshakeAttributes);
|
||||
JettyWebSocketHandlerAdapter listener = new JettyWebSocketHandlerAdapter(wsHandler, wsSession);
|
||||
final JettyWebSocketSession wsSession = new JettyWebSocketSession(user, handshakeAttributes);
|
||||
final JettyWebSocketHandlerAdapter listener = new JettyWebSocketHandlerAdapter(wsHandler, wsSession);
|
||||
|
||||
try {
|
||||
Future<Session> future = this.client.connect(listener, uri, request);
|
||||
future.get();
|
||||
return wsSession;
|
||||
}
|
||||
catch (Exception e) {
|
||||
throw new WebSocketConnectFailureException("Failed to connect to " + uri, e);
|
||||
}
|
||||
return this.taskExecutor.submitListenable(new Callable<WebSocketSession>() {
|
||||
@Override
|
||||
public WebSocketSession call() throws Exception {
|
||||
Future<Session> future = client.connect(listener, uri, request);
|
||||
future.get();
|
||||
return wsSession;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* @return the user to make available through {@link WebSocketSession#getPrincipal()};
|
||||
* by default this method returns {@code null}
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ import org.junit.runners.Parameterized.Parameter;
|
|||
import org.springframework.context.Lifecycle;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
import org.springframework.util.concurrent.ListenableFuture;
|
||||
import org.springframework.web.context.support.AnnotationConfigWebApplicationContext;
|
||||
import org.springframework.web.socket.client.WebSocketClient;
|
||||
import org.springframework.web.socket.server.DefaultHandshakeHandler;
|
||||
|
|
@ -108,7 +109,7 @@ public abstract class AbstractWebSocketIntegrationTests {
|
|||
return "ws://localhost:" + this.server.getPort();
|
||||
}
|
||||
|
||||
protected WebSocketSession doHandshake(WebSocketHandler clientHandler, String endpointPath) {
|
||||
protected ListenableFuture<WebSocketSession> doHandshake(WebSocketHandler clientHandler, String endpointPath) {
|
||||
return this.webSocketClient.doHandshake(clientHandler, getWsBaseUrl() + endpointPath);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -25,7 +25,6 @@ import org.apache.catalina.startup.Tomcat;
|
|||
import org.apache.coyote.http11.Http11NioProtocol;
|
||||
import org.apache.tomcat.util.descriptor.web.ApplicationListener;
|
||||
import org.apache.tomcat.websocket.server.WsListener;
|
||||
import org.springframework.core.NestedRuntimeException;
|
||||
import org.springframework.util.SocketUtils;
|
||||
import org.springframework.web.context.WebApplicationContext;
|
||||
import org.springframework.web.servlet.DispatcherServlet;
|
||||
|
|
@ -74,7 +73,7 @@ public class TomcatWebSocketTestServer implements WebSocketTestServer {
|
|||
return tempFolder;
|
||||
}
|
||||
catch (IOException ex) {
|
||||
throw new NestedRuntimeException("Unable to create temp directory", ex) {};
|
||||
throw new RuntimeException("Unable to create temp directory", ex);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -19,19 +19,21 @@ package org.springframework.web.socket.client;
|
|||
import java.net.URI;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.Callable;
|
||||
|
||||
import org.junit.Test;
|
||||
import org.mockito.ArgumentCaptor;
|
||||
import org.springframework.context.SmartLifecycle;
|
||||
import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.util.concurrent.ListenableFuture;
|
||||
import org.springframework.util.concurrent.ListenableFutureTask;
|
||||
import org.springframework.web.socket.WebSocketHandler;
|
||||
import org.springframework.web.socket.WebSocketSession;
|
||||
import org.springframework.web.socket.adapter.WebSocketHandlerAdapter;
|
||||
import org.springframework.web.socket.support.LoggingWebSocketHandlerDecorator;
|
||||
import org.springframework.web.socket.support.WebSocketHandlerDecorator;
|
||||
import org.springframework.web.util.UriComponentsBuilder;
|
||||
|
||||
import static org.junit.Assert.*;
|
||||
import static org.mockito.Mockito.*;
|
||||
|
||||
/**
|
||||
* Test fixture for {@link WebSocketConnectionManager}.
|
||||
|
|
@ -45,26 +47,20 @@ public class WebSocketConnectionManagerTests {
|
|||
|
||||
List<String> subprotocols = Arrays.asList("abc");
|
||||
|
||||
WebSocketClient client = mock(WebSocketClient.class);
|
||||
TestLifecycleWebSocketClient client = new TestLifecycleWebSocketClient(false);
|
||||
WebSocketHandler handler = new WebSocketHandlerAdapter();
|
||||
|
||||
WebSocketConnectionManager manager = new WebSocketConnectionManager(client, handler , "/path/{id}", "123");
|
||||
manager.setSubProtocols(subprotocols);
|
||||
manager.openConnection();
|
||||
|
||||
ArgumentCaptor<WebSocketHandlerDecorator> captor = ArgumentCaptor.forClass(WebSocketHandlerDecorator.class);
|
||||
ArgumentCaptor<HttpHeaders> headersCaptor = ArgumentCaptor.forClass(HttpHeaders.class);
|
||||
ArgumentCaptor<URI> uriCaptor = ArgumentCaptor.forClass(URI.class);
|
||||
|
||||
verify(client).doHandshake(captor.capture(), headersCaptor.capture(), uriCaptor.capture());
|
||||
|
||||
HttpHeaders expectedHeaders = new HttpHeaders();
|
||||
expectedHeaders.setSecWebSocketProtocol(subprotocols);
|
||||
|
||||
assertEquals(expectedHeaders, headersCaptor.getValue());
|
||||
assertEquals(new URI("/path/123"), uriCaptor.getValue());
|
||||
assertEquals(expectedHeaders, client.headers);
|
||||
assertEquals(new URI("/path/123"), client.uri);
|
||||
|
||||
WebSocketHandlerDecorator loggingHandler = captor.getValue();
|
||||
WebSocketHandlerDecorator loggingHandler = (WebSocketHandlerDecorator) client.webSocketHandler;
|
||||
assertEquals(LoggingWebSocketHandlerDecorator.class, loggingHandler.getClass());
|
||||
|
||||
assertSame(handler, loggingHandler.getDelegate());
|
||||
|
|
@ -103,6 +99,13 @@ public class WebSocketConnectionManagerTests {
|
|||
|
||||
private boolean running;
|
||||
|
||||
private WebSocketHandler webSocketHandler;
|
||||
|
||||
private HttpHeaders headers;
|
||||
|
||||
private URI uri;
|
||||
|
||||
|
||||
public TestLifecycleWebSocketClient(boolean running) {
|
||||
this.running = running;
|
||||
}
|
||||
|
|
@ -138,15 +141,27 @@ public class WebSocketConnectionManagerTests {
|
|||
}
|
||||
|
||||
@Override
|
||||
public WebSocketSession doHandshake(WebSocketHandler webSocketHandler, String uriTemplate, Object... uriVariables)
|
||||
throws WebSocketConnectFailureException {
|
||||
return null;
|
||||
public ListenableFuture<WebSocketSession> doHandshake(WebSocketHandler webSocketHandler,
|
||||
String uriTemplate, Object... uriVars) {
|
||||
|
||||
URI uri = UriComponentsBuilder.fromUriString(uriTemplate).buildAndExpand(uriVars).encode().toUri();
|
||||
return doHandshake(webSocketHandler, null, uri);
|
||||
}
|
||||
|
||||
@Override
|
||||
public WebSocketSession doHandshake(WebSocketHandler webSocketHandler, HttpHeaders headers, URI uri)
|
||||
throws WebSocketConnectFailureException {
|
||||
return null;
|
||||
public ListenableFuture<WebSocketSession> doHandshake(WebSocketHandler webSocketHandler,
|
||||
HttpHeaders headers, URI uri) {
|
||||
|
||||
this.webSocketHandler = webSocketHandler;
|
||||
this.headers = headers;
|
||||
this.uri = uri;
|
||||
|
||||
return new ListenableFutureTask<WebSocketSession>(new Callable<WebSocketSession>() {
|
||||
@Override
|
||||
public WebSocketSession call() throws Exception {
|
||||
return null;
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -66,7 +66,7 @@ public class StandardWebSocketClientTests {
|
|||
@Test
|
||||
public void localAddress() throws Exception {
|
||||
URI uri = new URI("ws://example.com/abc");
|
||||
WebSocketSession session = this.wsClient.doHandshake(this.wsHandler, this.headers, uri);
|
||||
WebSocketSession session = this.wsClient.doHandshake(this.wsHandler, this.headers, uri).get();
|
||||
|
||||
assertNotNull(session.getLocalAddress());
|
||||
assertEquals(80, session.getLocalAddress().getPort());
|
||||
|
|
@ -75,7 +75,7 @@ public class StandardWebSocketClientTests {
|
|||
@Test
|
||||
public void localAddressWss() throws Exception {
|
||||
URI uri = new URI("wss://example.com/abc");
|
||||
WebSocketSession session = this.wsClient.doHandshake(this.wsHandler, this.headers, uri);
|
||||
WebSocketSession session = this.wsClient.doHandshake(this.wsHandler, this.headers, uri).get();
|
||||
|
||||
assertNotNull(session.getLocalAddress());
|
||||
assertEquals(443, session.getLocalAddress().getPort());
|
||||
|
|
@ -90,7 +90,7 @@ public class StandardWebSocketClientTests {
|
|||
@Test
|
||||
public void remoteAddress() throws Exception {
|
||||
URI uri = new URI("wss://example.com/abc");
|
||||
WebSocketSession session = this.wsClient.doHandshake(this.wsHandler, this.headers, uri);
|
||||
WebSocketSession session = this.wsClient.doHandshake(this.wsHandler, this.headers, uri).get();
|
||||
|
||||
assertNotNull(session.getRemoteAddress());
|
||||
assertEquals("example.com", session.getRemoteAddress().getHostName());
|
||||
|
|
@ -105,7 +105,7 @@ public class StandardWebSocketClientTests {
|
|||
this.headers.setSecWebSocketProtocol(protocols);
|
||||
this.headers.add("foo", "bar");
|
||||
|
||||
WebSocketSession session = this.wsClient.doHandshake(this.wsHandler, this.headers, uri);
|
||||
WebSocketSession session = this.wsClient.doHandshake(this.wsHandler, this.headers, uri).get();
|
||||
|
||||
assertEquals(Collections.singletonMap("foo", Arrays.asList("bar")), session.getHandshakeHeaders());
|
||||
}
|
||||
|
|
@ -118,7 +118,7 @@ public class StandardWebSocketClientTests {
|
|||
this.headers.setSecWebSocketProtocol(protocols);
|
||||
this.headers.add("foo", "bar");
|
||||
|
||||
this.wsClient.doHandshake(this.wsHandler, this.headers, uri);
|
||||
this.wsClient.doHandshake(this.wsHandler, this.headers, uri).get();
|
||||
|
||||
ArgumentCaptor<Endpoint> arg1 = ArgumentCaptor.forClass(Endpoint.class);
|
||||
ArgumentCaptor<ClientEndpointConfig> arg2 = ArgumentCaptor.forClass(ClientEndpointConfig.class);
|
||||
|
|
|
|||
|
|
@ -83,7 +83,7 @@ public class JettyWebSocketClientTests {
|
|||
HttpHeaders headers = new HttpHeaders();
|
||||
headers.setSecWebSocketProtocol(Arrays.asList("echo"));
|
||||
|
||||
this.wsSession = this.client.doHandshake(new TextWebSocketHandlerAdapter(), headers, new URI(this.wsUrl));
|
||||
this.wsSession = this.client.doHandshake(new TextWebSocketHandlerAdapter(), headers, new URI(this.wsUrl)).get();
|
||||
|
||||
assertEquals(this.wsUrl, this.wsSession.getUri().toString());
|
||||
assertEquals("echo", this.wsSession.getAcceptedProtocol());
|
||||
|
|
|
|||
|
|
@ -65,8 +65,8 @@ public class WebSocketConfigurationTests extends AbstractWebSocketIntegrationTes
|
|||
@Test
|
||||
public void registerWebSocketHandler() throws Exception {
|
||||
|
||||
WebSocketSession session =
|
||||
this.webSocketClient.doHandshake(new WebSocketHandlerAdapter(), getWsBaseUrl() + "/ws");
|
||||
WebSocketSession session = this.webSocketClient.doHandshake(
|
||||
new WebSocketHandlerAdapter(), getWsBaseUrl() + "/ws").get();
|
||||
|
||||
TestWebSocketHandler serverHandler = this.wac.getBean(TestWebSocketHandler.class);
|
||||
assertTrue(serverHandler.latch.await(2, TimeUnit.SECONDS));
|
||||
|
|
@ -77,8 +77,8 @@ public class WebSocketConfigurationTests extends AbstractWebSocketIntegrationTes
|
|||
@Test
|
||||
public void registerWebSocketHandlerWithSockJS() throws Exception {
|
||||
|
||||
WebSocketSession session =
|
||||
this.webSocketClient.doHandshake(new WebSocketHandlerAdapter(), getWsBaseUrl() + "/sockjs/websocket");
|
||||
WebSocketSession session = this.webSocketClient.doHandshake(
|
||||
new WebSocketHandlerAdapter(), getWsBaseUrl() + "/sockjs/websocket").get();
|
||||
|
||||
TestWebSocketHandler serverHandler = this.wac.getBean(TestWebSocketHandler.class);
|
||||
assertTrue(serverHandler.latch.await(2, TimeUnit.SECONDS));
|
||||
|
|
|
|||
Loading…
Reference in New Issue