Simplify use of headers for SockJsClient requests

Before this change, XhrTransport implementations had to be configured
with the headers to use for HTTP requests other than the initial
handshake.

After this change the handshake headers passed to SockJsClient by
default are used for all other HTTP requests related to the SockJS
connection (e.g. info request, xhr send/receive). A property on
SockJsClient allows restricting the headers to use for other HTTP
requests to a subset of the handshake headers.

Issue: SPR-13254
This commit is contained in:
Rossen Stoyanchev 2015-07-28 14:17:57 -04:00
parent 9f557cf930
commit b7bdd724b2
17 changed files with 264 additions and 87 deletions

View File

@ -26,7 +26,6 @@ import org.apache.commons.logging.LogFactory;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.util.concurrent.ListenableFuture;
import org.springframework.util.concurrent.SettableListenableFuture;
@ -61,8 +60,6 @@ public abstract class AbstractXhrTransport implements XhrTransport {
private HttpHeaders requestHeaders = new HttpHeaders();
private HttpHeaders xhrSendRequestHeaders = new HttpHeaders();
@Override
public List<TransportType> getTransportTypes() {
@ -97,17 +94,17 @@ public abstract class AbstractXhrTransport implements XhrTransport {
/**
* Configure headers to be added to every executed HTTP request.
* @param requestHeaders the headers to add to requests
* @deprecated as of 4.2 in favor of {@link SockJsClient#setHttpHeaderNames}.
*/
@Deprecated
public void setRequestHeaders(HttpHeaders requestHeaders) {
this.requestHeaders.clear();
this.xhrSendRequestHeaders.clear();
if (requestHeaders != null) {
this.requestHeaders.putAll(requestHeaders);
this.xhrSendRequestHeaders.putAll(requestHeaders);
this.xhrSendRequestHeaders.setContentType(MediaType.APPLICATION_JSON);
}
}
@Deprecated
public HttpHeaders getRequestHeaders() {
return this.requestHeaders;
}
@ -115,6 +112,7 @@ public abstract class AbstractXhrTransport implements XhrTransport {
// Transport methods
@SuppressWarnings("deprecation")
@Override
public ListenableFuture<WebSocketSession> connect(TransportRequest request, WebSocketHandler handler) {
SettableListenableFuture<WebSocketSession> connectFuture = new SettableListenableFuture<WebSocketSession>();
@ -128,8 +126,8 @@ public abstract class AbstractXhrTransport implements XhrTransport {
}
HttpHeaders handshakeHeaders = new HttpHeaders();
handshakeHeaders.putAll(request.getHandshakeHeaders());
handshakeHeaders.putAll(getRequestHeaders());
handshakeHeaders.putAll(request.getHandshakeHeaders());
connectInternal(request, handler, receiveUrl, handshakeHeaders, session, connectFuture);
return connectFuture;
@ -142,11 +140,17 @@ public abstract class AbstractXhrTransport implements XhrTransport {
// InfoReceiver methods
@Override
public String executeInfoRequest(URI infoUrl) {
@SuppressWarnings("deprecation")
public String executeInfoRequest(URI infoUrl, HttpHeaders headers) {
if (logger.isDebugEnabled()) {
logger.debug("Executing SockJS Info request, url=" + infoUrl);
}
ResponseEntity<String> response = executeInfoRequestInternal(infoUrl);
HttpHeaders infoRequestHeaders = new HttpHeaders();
infoRequestHeaders.putAll(getRequestHeaders());
if (headers != null) {
infoRequestHeaders.putAll(headers);
}
ResponseEntity<String> response = executeInfoRequestInternal(infoUrl, infoRequestHeaders);
if (response.getStatusCode() != HttpStatus.OK) {
if (logger.isErrorEnabled()) {
logger.error("SockJS Info request (url=" + infoUrl + ") failed: " + response);
@ -159,16 +163,16 @@ public abstract class AbstractXhrTransport implements XhrTransport {
return response.getBody();
}
protected abstract ResponseEntity<String> executeInfoRequestInternal(URI infoUrl);
protected abstract ResponseEntity<String> executeInfoRequestInternal(URI infoUrl, HttpHeaders headers);
// XhrTransport methods
@Override
public void executeSendRequest(URI url, TextMessage message) {
public void executeSendRequest(URI url, HttpHeaders headers, TextMessage message) {
if (logger.isTraceEnabled()) {
logger.trace("Starting XHR send, url=" + url);
}
ResponseEntity<String> response = executeSendRequestInternal(url, this.xhrSendRequestHeaders, message);
ResponseEntity<String> response = executeSendRequestInternal(url, headers, message);
if (response.getStatusCode() != HttpStatus.NO_CONTENT) {
if (logger.isErrorEnabled()) {
logger.error("XHR send request (url=" + url + ") failed: " + response);
@ -180,7 +184,8 @@ public abstract class AbstractXhrTransport implements XhrTransport {
}
}
protected abstract ResponseEntity<String> executeSendRequestInternal(URI url, HttpHeaders headers, TextMessage message);
protected abstract ResponseEntity<String> executeSendRequestInternal(URI url,
HttpHeaders headers, TextMessage message);
@Override

View File

@ -52,6 +52,8 @@ class DefaultTransportRequest implements TransportRequest {
private final HttpHeaders handshakeHeaders;
private final HttpHeaders httpRequestHeaders;
private final Transport transport;
private final TransportType serverTransportType;
@ -69,7 +71,8 @@ class DefaultTransportRequest implements TransportRequest {
private DefaultTransportRequest fallbackRequest;
public DefaultTransportRequest(SockJsUrlInfo sockJsUrlInfo, HttpHeaders handshakeHeaders,
public DefaultTransportRequest(SockJsUrlInfo sockJsUrlInfo,
HttpHeaders handshakeHeaders, HttpHeaders httpRequestHeaders,
Transport transport, TransportType serverTransportType, SockJsMessageCodec codec) {
Assert.notNull(sockJsUrlInfo, "'sockJsUrlInfo' is required");
@ -78,6 +81,7 @@ class DefaultTransportRequest implements TransportRequest {
Assert.notNull(codec, "'codec' is required");
this.sockJsUrlInfo = sockJsUrlInfo;
this.handshakeHeaders = (handshakeHeaders != null ? handshakeHeaders : new HttpHeaders());
this.httpRequestHeaders = (httpRequestHeaders != null ? httpRequestHeaders : new HttpHeaders());
this.transport = transport;
this.serverTransportType = serverTransportType;
this.codec = codec;
@ -94,6 +98,11 @@ class DefaultTransportRequest implements TransportRequest {
return this.handshakeHeaders;
}
@Override
public HttpHeaders getHttpRequestHeaders() {
return this.httpRequestHeaders;
}
@Override
public URI getTransportUrl() {
return this.sockJsUrlInfo.getTransportUrl(this.serverTransportType);

View File

@ -17,6 +17,8 @@ package org.springframework.web.socket.sockjs.client;
import java.net.URI;
import org.springframework.http.HttpHeaders;
/**
* A component that can execute the SockJS "Info" request that needs to be
* performed before the SockJS session starts in order to check server endpoint
@ -34,10 +36,11 @@ public interface InfoReceiver {
/**
* Perform an HTTP request to the SockJS "Info" URL.
* and return the resulting JSON response content, or raise an exception.
*
* <p>Note that as of 4.2 this method accepts a {@code headers} parameter.
* @param infoUrl the URL to obtain SockJS server information from
* @param headers the headers to use for the request
* @return the body of the response
*/
String executeInfoRequest(URI infoUrl);
String executeInfoRequest(URI infoUrl, HttpHeaders headers);
}

View File

@ -106,11 +106,12 @@ public class JettyXhrTransport extends AbstractXhrTransport implements XhrTransp
@Override
protected void connectInternal(TransportRequest request, WebSocketHandler handler,
protected void connectInternal(TransportRequest transportRequest, WebSocketHandler handler,
URI url, HttpHeaders handshakeHeaders, XhrClientSockJsSession session,
SettableListenableFuture<WebSocketSession> connectFuture) {
SockJsResponseListener listener = new SockJsResponseListener(url, getRequestHeaders(), session, connectFuture);
HttpHeaders httpHeaders = transportRequest.getHttpRequestHeaders();
SockJsResponseListener listener = new SockJsResponseListener(url, httpHeaders, session, connectFuture);
executeReceiveRequest(url, handshakeHeaders, listener);
}
@ -124,8 +125,8 @@ public class JettyXhrTransport extends AbstractXhrTransport implements XhrTransp
}
@Override
protected ResponseEntity<String> executeInfoRequestInternal(URI infoUrl) {
return executeRequest(infoUrl, HttpMethod.GET, getRequestHeaders(), null);
protected ResponseEntity<String> executeInfoRequestInternal(URI infoUrl, HttpHeaders headers) {
return executeRequest(infoUrl, HttpMethod.GET, headers, null);
}
@Override

View File

@ -94,15 +94,16 @@ public class RestTemplateXhrTransport extends AbstractXhrTransport implements Xh
@Override
protected void connectInternal(final TransportRequest request, final WebSocketHandler handler,
protected void connectInternal(final TransportRequest transportRequest, final WebSocketHandler handler,
final URI receiveUrl, final HttpHeaders handshakeHeaders, final XhrClientSockJsSession session,
final SettableListenableFuture<WebSocketSession> connectFuture) {
getTaskExecutor().execute(new Runnable() {
@Override
public void run() {
HttpHeaders httpHeaders = transportRequest.getHttpRequestHeaders();
XhrRequestCallback requestCallback = new XhrRequestCallback(handshakeHeaders);
XhrRequestCallback requestCallbackAfterHandshake = new XhrRequestCallback(getRequestHeaders());
XhrRequestCallback requestCallbackAfterHandshake = new XhrRequestCallback(httpHeaders);
XhrReceiveExtractor responseExtractor = new XhrReceiveExtractor(session);
while (true) {
if (session.isDisconnected()) {
@ -132,8 +133,8 @@ public class RestTemplateXhrTransport extends AbstractXhrTransport implements Xh
}
@Override
public ResponseEntity<String> executeInfoRequestInternal(URI infoUrl) {
RequestCallback requestCallback = new XhrRequestCallback(getRequestHeaders());
protected ResponseEntity<String> executeInfoRequestInternal(URI infoUrl, HttpHeaders headers) {
RequestCallback requestCallback = new XhrRequestCallback(headers);
return this.restTemplate.execute(infoUrl, HttpMethod.GET, requestCallback, textResponseExtractor);
}

View File

@ -78,6 +78,8 @@ public class SockJsClient implements WebSocketClient, Lifecycle {
private final List<Transport> transports;
private String[] httpHeaderNames;
private InfoReceiver infoReceiver;
private SockJsMessageCodec messageCodec;
@ -116,6 +118,30 @@ public class SockJsClient implements WebSocketClient, Lifecycle {
}
/**
* The names of HTTP headers that should be copied from the handshake headers
* of each call to {@link SockJsClient#doHandshake(WebSocketHandler, WebSocketHttpHeaders, URI)}
* and also used with other HTTP requests issued as part of that SockJS
* connection, e.g. the initial info request, XHR send or receive requests.
*
* <p>By default if this property is not set, all handshake headers are also
* used for other HTTP requests. Set it if you want only a subset of handshake
* headers (e.g. auth headers) to be used for other HTTP requests.
*
* @param httpHeaderNames HTTP header names
*/
public void setHttpHeaderNames(String... httpHeaderNames) {
this.httpHeaderNames = httpHeaderNames;
}
/**
* The configured HTTP header names to be copied from the handshake
* headers and also included in other HTTP requests.
*/
public String[] getHttpHeaderNames() {
return this.httpHeaderNames;
}
/**
* Configure the {@code InfoReceiver} to use to perform the SockJS "Info"
* request before the SockJS session starts.
@ -225,7 +251,7 @@ public class SockJsClient implements WebSocketClient, Lifecycle {
SettableListenableFuture<WebSocketSession> connectFuture = new SettableListenableFuture<WebSocketSession>();
try {
SockJsUrlInfo sockJsUrlInfo = new SockJsUrlInfo(url);
ServerInfo serverInfo = getServerInfo(sockJsUrlInfo);
ServerInfo serverInfo = getServerInfo(sockJsUrlInfo, getHttpRequestHeaders(headers));
createRequest(sockJsUrlInfo, headers, serverInfo).connect(handler, connectFuture);
}
catch (Throwable exception) {
@ -237,12 +263,27 @@ public class SockJsClient implements WebSocketClient, Lifecycle {
return connectFuture;
}
private ServerInfo getServerInfo(SockJsUrlInfo sockJsUrlInfo) {
private HttpHeaders getHttpRequestHeaders(HttpHeaders webSocketHttpHeaders) {
if (getHttpHeaderNames() == null) {
return webSocketHttpHeaders;
}
else {
HttpHeaders httpHeaders = new HttpHeaders();
for (String name : getHttpHeaderNames()) {
if (webSocketHttpHeaders.containsKey(name)) {
httpHeaders.put(name, webSocketHttpHeaders.get(name));
}
}
return httpHeaders;
}
}
private ServerInfo getServerInfo(SockJsUrlInfo sockJsUrlInfo, HttpHeaders headers) {
URI infoUrl = sockJsUrlInfo.getInfoUrl();
ServerInfo info = this.serverInfoCache.get(infoUrl);
if (info == null) {
long start = System.currentTimeMillis();
String response = this.infoReceiver.executeInfoRequest(infoUrl);
String response = this.infoReceiver.executeInfoRequest(infoUrl, headers);
long infoRequestTime = System.currentTimeMillis() - start;
info = new ServerInfo(response, infoRequestTime);
this.serverInfoCache.put(infoUrl, info);
@ -255,7 +296,8 @@ public class SockJsClient implements WebSocketClient, Lifecycle {
for (Transport transport : this.transports) {
for (TransportType type : transport.getTransportTypes()) {
if (serverInfo.isWebSocketEnabled() || !TransportType.WEBSOCKET.equals(type)) {
requests.add(new DefaultTransportRequest(urlInfo, headers, transport, type, getMessageCodec()));
requests.add(new DefaultTransportRequest(urlInfo, headers, getHttpRequestHeaders(headers),
transport, type, getMessageCodec()));
}
}
}

View File

@ -47,6 +47,13 @@ public interface TransportRequest {
*/
HttpHeaders getHandshakeHeaders();
/**
* Return the headers to add to all other HTTP requests besides the handshake
* request such XHR receive and send requests.
* @since 4.2
*/
HttpHeaders getHttpRequestHeaders();
/**
* Return the transport URL for the given transport.
* For an {@link XhrTransport} this is the URL for receiving messages.

View File

@ -134,11 +134,11 @@ public class UndertowXhrTransport extends AbstractXhrTransport implements XhrTra
HttpHeaders handshakeHeaders, XhrClientSockJsSession session,
SettableListenableFuture<WebSocketSession> connectFuture) {
executeReceiveRequest(receiveUrl, handshakeHeaders, session, connectFuture);
executeReceiveRequest(request, receiveUrl, handshakeHeaders, session, connectFuture);
}
private void executeReceiveRequest(final URI url, final HttpHeaders headers,
final XhrClientSockJsSession session,
private void executeReceiveRequest(final TransportRequest transportRequest,
final URI url, final HttpHeaders headers, final XhrClientSockJsSession session,
final SettableListenableFuture<WebSocketSession> connectFuture) {
if (logger.isTraceEnabled()) {
@ -154,8 +154,9 @@ public class UndertowXhrTransport extends AbstractXhrTransport implements XhrTra
HttpString headerName = HttpString.tryFromString(HttpHeaders.HOST);
request.getRequestHeaders().add(headerName, url.getHost());
addHttpHeaders(request, headers);
connection.sendRequest(request, createReceiveCallback(url,
getRequestHeaders(), session, connectFuture));
HttpHeaders httpHeaders = transportRequest.getHttpRequestHeaders();
connection.sendRequest(request, createReceiveCallback(transportRequest,
url, httpHeaders, session, connectFuture));
}
@Override
@ -175,8 +176,8 @@ public class UndertowXhrTransport extends AbstractXhrTransport implements XhrTra
}
}
private ClientCallback<ClientExchange> createReceiveCallback(final URI url, final HttpHeaders headers,
final XhrClientSockJsSession sockJsSession,
private ClientCallback<ClientExchange> createReceiveCallback(final TransportRequest transportRequest,
final URI url, final HttpHeaders headers, final XhrClientSockJsSession sockJsSession,
final SettableListenableFuture<WebSocketSession> connectFuture) {
return new ClientCallback<ClientExchange>() {
@ -194,8 +195,9 @@ public class UndertowXhrTransport extends AbstractXhrTransport implements XhrTra
onFailure(new HttpServerErrorException(status, "Unexpected XHR receive status"));
}
else {
SockJsResponseListener listener = new SockJsResponseListener(result.getConnection(),
url, headers, sockJsSession, connectFuture);
SockJsResponseListener listener = new SockJsResponseListener(
transportRequest, result.getConnection(), url, headers,
sockJsSession, connectFuture);
listener.setup(result.getResponseChannel());
}
if (logger.isTraceEnabled()) {
@ -254,8 +256,8 @@ public class UndertowXhrTransport extends AbstractXhrTransport implements XhrTra
}
@Override
protected ResponseEntity<String> executeInfoRequestInternal(URI infoUrl) {
return executeRequest(infoUrl, Methods.GET, getRequestHeaders(), null);
protected ResponseEntity<String> executeInfoRequestInternal(URI infoUrl, HttpHeaders headers) {
return executeRequest(infoUrl, Methods.GET, headers, null);
}
@Override
@ -360,6 +362,8 @@ public class UndertowXhrTransport extends AbstractXhrTransport implements XhrTra
private class SockJsResponseListener implements ChannelListener<StreamSourceChannel> {
private final TransportRequest request;
private final ClientConnection connection;
private final URI url;
@ -372,10 +376,12 @@ public class UndertowXhrTransport extends AbstractXhrTransport implements XhrTra
private final ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
public SockJsResponseListener(ClientConnection connection, URI url,
public SockJsResponseListener(TransportRequest request, ClientConnection connection, URI url,
HttpHeaders headers, XhrClientSockJsSession sockJsSession,
SettableListenableFuture<WebSocketSession> connectFuture) {
this.request = request;
this.connection = connection;
this.url = url;
this.headers = headers;
@ -455,7 +461,7 @@ public class UndertowXhrTransport extends AbstractXhrTransport implements XhrTra
logger.trace("XHR receive request completed.");
}
IoUtils.safeClose(this.connection);
executeReceiveRequest(this.url, this.headers, this.session, this.connectFuture);
executeReceiveRequest(this.request, this.url, this.headers, this.session, this.connectFuture);
}
public void onFailure(Throwable failure) {

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2014 the original author or authors.
* Copyright 2002-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -20,6 +20,8 @@ import java.net.InetSocketAddress;
import java.net.URI;
import java.util.List;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.util.Assert;
import org.springframework.util.concurrent.SettableListenableFuture;
import org.springframework.web.socket.CloseStatus;
@ -39,10 +41,14 @@ import org.springframework.web.socket.sockjs.transport.TransportType;
*/
public class XhrClientSockJsSession extends AbstractClientSockJsSession {
private final URI sendUrl;
private final XhrTransport transport;
private HttpHeaders headers;
private HttpHeaders sendHeaders;
private final URI sendUrl;
private int textMessageSizeLimit = -1;
private int binaryMessageSizeLimit = -1;
@ -53,11 +59,21 @@ public class XhrClientSockJsSession extends AbstractClientSockJsSession {
super(request, handler, connectFuture);
Assert.notNull(transport, "'restTemplate' is required");
this.sendUrl = request.getSockJsUrlInfo().getTransportUrl(TransportType.XHR_SEND);
this.transport = transport;
this.headers = request.getHttpRequestHeaders();
this.sendHeaders = new HttpHeaders();
if (this.headers != null) {
this.sendHeaders.putAll(this.headers);
}
this.sendHeaders.setContentType(MediaType.APPLICATION_JSON);
this.sendUrl = request.getSockJsUrlInfo().getTransportUrl(TransportType.XHR_SEND);
}
public HttpHeaders getHeaders() {
return this.headers;
}
@Override
public InetSocketAddress getLocalAddress() {
return null;
@ -100,7 +116,7 @@ public class XhrClientSockJsSession extends AbstractClientSockJsSession {
@Override
protected void sendInternal(TextMessage message) {
this.transport.executeSendRequest(this.sendUrl, message);
this.transport.executeSendRequest(this.sendUrl, this.sendHeaders, message);
}
@Override

View File

@ -17,14 +17,14 @@ package org.springframework.web.socket.sockjs.client;
import java.net.URI;
import org.springframework.http.HttpHeaders;
import org.springframework.web.socket.TextMessage;
/**
* A SockJS {@link Transport} that uses HTTP requests to simulate a WebSocket
* interaction. The {@code connect} method of the base {@code Transport} interface
* is used to receive messages from the server while the
* {@link #executeSendRequest(java.net.URI, org.springframework.web.socket.TextMessage)
* executeSendRequest(URI, TextMessage)} method here is used to send messages.
* {@link #executeSendRequest} method here is used to send messages.
*
* @author Rossen Stoyanchev
* @since 4.1
@ -35,7 +35,6 @@ public interface XhrTransport extends Transport, InfoReceiver {
* An {@code XhrTransport} supports both the "xhr_streaming" and "xhr" SockJS
* server transports. From a client perspective there is no implementation
* difference.
*
* <p>By default an {@code XhrTransport} will be used with "xhr_streaming"
* first and then with "xhr", if the streaming fails to connect. In some
* cases it may be useful to suppress streaming so that only "xhr" is used.
@ -44,9 +43,10 @@ public interface XhrTransport extends Transport, InfoReceiver {
/**
* Execute a request to send the message to the server.
* <p>Note that as of 4.2 this method accepts a {@code headers} parameter.
* @param transportUrl the URL for sending messages.
* @param message the message to send
*/
void executeSendRequest(URI transportUrl, TextMessage message);
void executeSendRequest(URI transportUrl, HttpHeaders headers, TextMessage message);
}

View File

@ -49,6 +49,8 @@ import org.junit.rules.TestName;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.http.HttpHeaders;
import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler;
import org.springframework.tests.Assume;
import org.springframework.tests.TestGroup;
@ -100,7 +102,7 @@ public abstract class AbstractSockJsIntegrationTests {
@BeforeClass
public static void performanceTestGroupAssumption() throws Exception {
Assume.group(TestGroup.PERFORMANCE);
// Assume.group(TestGroup.PERFORMANCE);
}
@ -164,19 +166,36 @@ public abstract class AbstractSockJsIntegrationTests {
@Test
public void echoWebSocket() throws Exception {
testEcho(100, createWebSocketTransport());
testEcho(100, createWebSocketTransport(), null);
}
@Test
public void echoXhrStreaming() throws Exception {
testEcho(100, createXhrTransport());
testEcho(100, createXhrTransport(), null);
}
@Test
public void echoXhr() throws Exception {
AbstractXhrTransport xhrTransport = createXhrTransport();
xhrTransport.setXhrStreamingDisabled(true);
testEcho(100, xhrTransport);
testEcho(100, xhrTransport, null);
}
// SPR-13254
@Test
public void echoXhrWithHeaders() throws Exception {
AbstractXhrTransport xhrTransport = createXhrTransport();
xhrTransport.setXhrStreamingDisabled(true);
WebSocketHttpHeaders headers = new WebSocketHttpHeaders();
headers.add("auth", "123");
testEcho(10, xhrTransport, headers);
for (Map.Entry<String, HttpHeaders> entry : this.testFilter.requests.entrySet()) {
HttpHeaders httpHeaders = entry.getValue();
assertEquals("No auth header for: " + entry.getKey(), "123", httpHeaders.getFirst("auth"));
}
}
@Test
@ -246,14 +265,15 @@ public abstract class AbstractSockJsIntegrationTests {
}
private void testEcho(int messageCount, Transport transport) throws Exception {
private void testEcho(int messageCount, Transport transport, WebSocketHttpHeaders headers) throws Exception {
List<TextMessage> messages = new ArrayList<>();
for (int i = 0; i < messageCount; i++) {
messages.add(new TextMessage("m" + i));
}
TestClientHandler handler = new TestClientHandler();
initSockJsClient(transport);
WebSocketSession session = this.sockJsClient.doHandshake(handler, this.baseUrl + "/echo").get();
URI url = new URI(this.baseUrl + "/echo");
WebSocketSession session = this.sockJsClient.doHandshake(handler, headers, url).get();
for (TextMessage message : messages) {
session.sendMessage(message);
}
@ -386,7 +406,7 @@ public abstract class AbstractSockJsIntegrationTests {
private static class TestFilter implements Filter {
private final List<ServletRequest> requests = new ArrayList<>();
private final Map<String, HttpHeaders> requests = new HashMap<>();
private final Map<String, Long> sleepDelayMap = new HashMap<>();
@ -397,10 +417,13 @@ public abstract class AbstractSockJsIntegrationTests {
public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
throws IOException, ServletException {
this.requests.add(request);
HttpServletRequest httpRequest = (HttpServletRequest) request;
String uri = httpRequest.getRequestURI();
HttpHeaders headers = new ServletServerHttpRequest(httpRequest).getHeaders();
this.requests.put(uri, headers);
for (String suffix : this.sleepDelayMap.keySet()) {
if (((HttpServletRequest) request).getRequestURI().endsWith(suffix)) {
if ((httpRequest).getRequestURI().endsWith(suffix)) {
try {
Thread.sleep(this.sleepDelayMap.get(suffix));
break;
@ -411,7 +434,7 @@ public abstract class AbstractSockJsIntegrationTests {
}
}
for (String suffix : this.sendErrorMap.keySet()) {
if (((HttpServletRequest) request).getRequestURI().endsWith(suffix)) {
if ((httpRequest).getRequestURI().endsWith(suffix)) {
((HttpServletResponse) response).sendError(this.sendErrorMap.get(suffix));
return;
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2014 the original author or authors.
* Copyright 2002-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -64,7 +64,7 @@ public class ClientSockJsSessionTests {
public void setup() throws Exception {
SockJsUrlInfo urlInfo = new SockJsUrlInfo(new URI("http://example.com"));
Transport transport = mock(Transport.class);
TransportRequest request = new DefaultTransportRequest(urlInfo, null, transport, TransportType.XHR, CODEC);
TransportRequest request = new DefaultTransportRequest(urlInfo, null, null, transport, TransportType.XHR, CODEC);
this.handler = mock(WebSocketHandler.class);
this.connectFuture = new SettableListenableFuture<>();
this.session = new TestClientSockJsSession(request, this.handler, this.connectFuture);

View File

@ -127,7 +127,7 @@ public class DefaultTransportRequestTests {
protected DefaultTransportRequest createTransportRequest(Transport transport, TransportType type) throws Exception {
SockJsUrlInfo urlInfo = new SockJsUrlInfo(new URI("http://example.com"));
return new DefaultTransportRequest(urlInfo, new HttpHeaders(), transport, type, CODEC);
return new DefaultTransportRequest(urlInfo, new HttpHeaders(), new HttpHeaders(), transport, type, CODEC);
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2014 the original author or authors.
* Copyright 2002-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -182,7 +182,8 @@ public class RestTemplateXhrTransportTests {
SockJsUrlInfo urlInfo = new SockJsUrlInfo(new URI("http://example.com"));
HttpHeaders headers = new HttpHeaders();
headers.add("h-foo", "h-bar");
TransportRequest request = new DefaultTransportRequest(urlInfo, headers, transport, TransportType.XHR, CODEC);
TransportRequest request = new DefaultTransportRequest(urlInfo, headers, headers,
transport, TransportType.XHR, CODEC);
return transport.connect(request, this.webSocketHandler);
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2014 the original author or authors.
* Copyright 2002-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,22 +16,34 @@
package org.springframework.web.socket.sockjs.client;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.ArrayList;
import java.util.List;
import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.util.concurrent.ListenableFutureCallback;
import org.springframework.web.client.HttpServerErrorException;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketHttpHeaders;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.sockjs.client.TestTransport.XhrTestTransport;
import static org.junit.Assert.*;
import static org.mockito.BDDMockito.*;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.mockito.BDDMockito.any;
import static org.mockito.BDDMockito.given;
import static org.mockito.BDDMockito.mock;
import static org.mockito.BDDMockito.times;
import static org.mockito.BDDMockito.verify;
import static org.mockito.BDDMockito.verifyNoMoreInteractions;
import static org.mockito.BDDMockito.when;
/**
* Unit tests for {@link org.springframework.web.socket.sockjs.client.SockJsClient}.
@ -102,11 +114,51 @@ public class SockJsClientTests {
assertTrue(this.xhrTransport.getRequest().getTransportUrl().toString().endsWith("xhr"));
}
// SPR-13254
@Test
public void connectWithHandshakeHeaders() throws Exception {
ArgumentCaptor<HttpHeaders> headersCaptor = setupInfoRequest(false);
this.xhrTransport.setStreamingDisabled(true);
WebSocketHttpHeaders headers = new WebSocketHttpHeaders();
headers.set("foo", "bar");
headers.set("auth", "123");
this.sockJsClient.doHandshake(handler, headers, new URI(URL)).addCallback(this.connectCallback);
HttpHeaders httpHeaders = headersCaptor.getValue();
assertEquals(2, httpHeaders.size());
assertEquals("bar", httpHeaders.getFirst("foo"));
assertEquals("123", httpHeaders.getFirst("auth"));
httpHeaders = this.xhrTransport.getRequest().getHttpRequestHeaders();
assertEquals(2, httpHeaders.size());
assertEquals("bar", httpHeaders.getFirst("foo"));
assertEquals("123", httpHeaders.getFirst("auth"));
}
@Test
public void connectAndUseSubsetOfHandshakeHeadersForHttpRequests() throws Exception {
ArgumentCaptor<HttpHeaders> headersCaptor = setupInfoRequest(false);
this.xhrTransport.setStreamingDisabled(true);
WebSocketHttpHeaders headers = new WebSocketHttpHeaders();
headers.set("foo", "bar");
headers.set("auth", "123");
this.sockJsClient.setHttpHeaderNames("auth");
this.sockJsClient.doHandshake(handler, headers, new URI(URL)).addCallback(this.connectCallback);
assertEquals(1, headersCaptor.getValue().size());
assertEquals("123", headersCaptor.getValue().getFirst("auth"));
assertEquals(1, this.xhrTransport.getRequest().getHttpRequestHeaders().size());
assertEquals("123", this.xhrTransport.getRequest().getHttpRequestHeaders().getFirst("auth"));
}
@Test
public void connectSockJsInfo() throws Exception {
setupInfoRequest(true);
this.sockJsClient.doHandshake(handler, URL);
verify(this.infoReceiver, times(1)).executeInfoRequest(any());
verify(this.infoReceiver, times(1)).executeInfoRequest(any(), any());
}
@Test
@ -115,22 +167,27 @@ public class SockJsClientTests {
this.sockJsClient.doHandshake(handler, URL);
this.sockJsClient.doHandshake(handler, URL);
this.sockJsClient.doHandshake(handler, URL);
verify(this.infoReceiver, times(1)).executeInfoRequest(any());
verify(this.infoReceiver, times(1)).executeInfoRequest(any(), any());
}
@Test
public void connectInfoRequestFailure() throws URISyntaxException {
HttpServerErrorException exception = new HttpServerErrorException(HttpStatus.SERVICE_UNAVAILABLE);
given(this.infoReceiver.executeInfoRequest(any())).willThrow(exception);
given(this.infoReceiver.executeInfoRequest(any(), any())).willThrow(exception);
this.sockJsClient.doHandshake(handler, URL).addCallback(this.connectCallback);
verify(this.connectCallback).onFailure(exception);
assertFalse(this.webSocketTransport.invoked());
assertFalse(this.xhrTransport.invoked());
}
private void setupInfoRequest(boolean webSocketEnabled) {
given(this.infoReceiver.executeInfoRequest(any())).willReturn("{\"entropy\":123," +
"\"origins\":[\"*:*\"],\"cookie_needed\":true,\"websocket\":" + webSocketEnabled + "}");
private ArgumentCaptor<HttpHeaders> setupInfoRequest(boolean webSocketEnabled) {
ArgumentCaptor<HttpHeaders> headersCaptor = ArgumentCaptor.forClass(HttpHeaders.class);
when(this.infoReceiver.executeInfoRequest(any(), headersCaptor.capture())).thenReturn(
"{\"entropy\":123," +
"\"origins\":[\"*:*\"]," +
"\"cookie_needed\":true," +
"\"websocket\":" + webSocketEnabled + "}");
return headersCaptor;
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2014 the original author or authors.
* Copyright 2002-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -18,9 +18,12 @@ package org.springframework.web.socket.sockjs.client;
import java.net.URI;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import org.mockito.ArgumentCaptor;
import org.springframework.http.HttpHeaders;
import org.springframework.util.concurrent.ListenableFuture;
import org.springframework.util.concurrent.ListenableFutureCallback;
import org.springframework.web.socket.TextMessage;
@ -28,7 +31,8 @@ import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.sockjs.transport.TransportType;
import static org.mockito.Mockito.*;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
/**
* Test SockJS Transport.
@ -51,7 +55,7 @@ class TestTransport implements Transport {
@Override
public List<TransportType> getTransportTypes() {
return Arrays.asList(TransportType.WEBSOCKET);
return Collections.singletonList(TransportType.WEBSOCKET);
}
public TransportRequest getRequest() {
@ -95,7 +99,7 @@ class TestTransport implements Transport {
@Override
public List<TransportType> getTransportTypes() {
return (isXhrStreamingDisabled() ?
Arrays.asList(TransportType.XHR) :
Collections.singletonList(TransportType.XHR) :
Arrays.asList(TransportType.XHR_STREAMING, TransportType.XHR));
}
@ -109,11 +113,11 @@ class TestTransport implements Transport {
}
@Override
public void executeSendRequest(URI transportUrl, TextMessage message) {
public void executeSendRequest(URI transportUrl, HttpHeaders headers, TextMessage message) {
}
@Override
public String executeInfoRequest(URI infoUrl) {
public String executeInfoRequest(URI infoUrl, HttpHeaders headers) {
return null;
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2014 the original author or authors.
* Copyright 2002-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -46,25 +46,25 @@ public class XhrTransportTests {
public void infoResponse() throws Exception {
TestXhrTransport transport = new TestXhrTransport();
transport.infoResponseToReturn = new ResponseEntity<>("body", HttpStatus.OK);
assertEquals("body", transport.executeInfoRequest(new URI("http://example.com/info")));
assertEquals("body", transport.executeInfoRequest(new URI("http://example.com/info"), null));
}
@Test(expected = HttpServerErrorException.class)
public void infoResponseError() throws Exception {
TestXhrTransport transport = new TestXhrTransport();
transport.infoResponseToReturn = new ResponseEntity<>("body", HttpStatus.BAD_REQUEST);
assertEquals("body", transport.executeInfoRequest(new URI("http://example.com/info")));
assertEquals("body", transport.executeInfoRequest(new URI("http://example.com/info"), null));
}
@Test
public void sendMessage() throws Exception {
HttpHeaders requestHeaders = new HttpHeaders();
requestHeaders.set("foo", "bar");
requestHeaders.setContentType(MediaType.APPLICATION_JSON);
TestXhrTransport transport = new TestXhrTransport();
transport.setRequestHeaders(requestHeaders);
transport.sendMessageResponseToReturn = new ResponseEntity<>(HttpStatus.NO_CONTENT);
URI url = new URI("http://example.com");
transport.executeSendRequest(url, new TextMessage("payload"));
transport.executeSendRequest(url, requestHeaders, new TextMessage("payload"));
assertEquals(2, transport.actualSendRequestHeaders.size());
assertEquals("bar", transport.actualSendRequestHeaders.getFirst("foo"));
assertEquals(MediaType.APPLICATION_JSON, transport.actualSendRequestHeaders.getContentType());
@ -75,9 +75,10 @@ public class XhrTransportTests {
TestXhrTransport transport = new TestXhrTransport();
transport.sendMessageResponseToReturn = new ResponseEntity<>(HttpStatus.BAD_REQUEST);
URI url = new URI("http://example.com");
transport.executeSendRequest(url, new TextMessage("payload"));
transport.executeSendRequest(url, null, new TextMessage("payload"));
}
@SuppressWarnings("deprecation")
@Test
public void connect() throws Exception {
HttpHeaders handshakeHeaders = new HttpHeaders();
@ -101,6 +102,7 @@ public class XhrTransportTests {
verify(request).addTimeoutTask(captor.capture());
verify(request).getTransportUrl();
verify(request).getHandshakeHeaders();
verify(request).getHttpRequestHeaders();
verifyNoMoreInteractions(request);
assertEquals(2, transport.actualHandshakeHeaders.size());
@ -127,7 +129,7 @@ public class XhrTransportTests {
@Override
protected ResponseEntity<String> executeInfoRequestInternal(URI infoUrl) {
protected ResponseEntity<String> executeInfoRequestInternal(URI infoUrl, HttpHeaders headers) {
return this.infoResponseToReturn;
}