From b7bdd724b2eab6959eade7868cfb1ecf21aa118b Mon Sep 17 00:00:00 2001 From: Rossen Stoyanchev Date: Tue, 28 Jul 2015 14:17:57 -0400 Subject: [PATCH] Simplify use of headers for SockJsClient requests Before this change, XhrTransport implementations had to be configured with the headers to use for HTTP requests other than the initial handshake. After this change the handshake headers passed to SockJsClient by default are used for all other HTTP requests related to the SockJS connection (e.g. info request, xhr send/receive). A property on SockJsClient allows restricting the headers to use for other HTTP requests to a subset of the handshake headers. Issue: SPR-13254 --- .../sockjs/client/AbstractXhrTransport.java | 31 ++++---- .../client/DefaultTransportRequest.java | 11 ++- .../socket/sockjs/client/InfoReceiver.java | 7 +- .../sockjs/client/JettyXhrTransport.java | 9 ++- .../client/RestTemplateXhrTransport.java | 9 ++- .../socket/sockjs/client/SockJsClient.java | 50 ++++++++++++- .../sockjs/client/TransportRequest.java | 7 ++ .../sockjs/client/UndertowXhrTransport.java | 32 ++++---- .../sockjs/client/XhrClientSockJsSession.java | 26 +++++-- .../socket/sockjs/client/XhrTransport.java | 8 +- .../AbstractSockJsIntegrationTests.java | 43 ++++++++--- .../client/ClientSockJsSessionTests.java | 4 +- .../client/DefaultTransportRequestTests.java | 2 +- .../client/RestTemplateXhrTransportTests.java | 5 +- .../sockjs/client/SockJsClientTests.java | 75 ++++++++++++++++--- .../socket/sockjs/client/TestTransport.java | 16 ++-- .../sockjs/client/XhrTransportTests.java | 16 ++-- 17 files changed, 264 insertions(+), 87 deletions(-) diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/AbstractXhrTransport.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/AbstractXhrTransport.java index a8e12ac3b57..ce77aeb2798 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/AbstractXhrTransport.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/AbstractXhrTransport.java @@ -26,7 +26,6 @@ import org.apache.commons.logging.LogFactory; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpStatus; -import org.springframework.http.MediaType; import org.springframework.http.ResponseEntity; import org.springframework.util.concurrent.ListenableFuture; import org.springframework.util.concurrent.SettableListenableFuture; @@ -61,8 +60,6 @@ public abstract class AbstractXhrTransport implements XhrTransport { private HttpHeaders requestHeaders = new HttpHeaders(); - private HttpHeaders xhrSendRequestHeaders = new HttpHeaders(); - @Override public List getTransportTypes() { @@ -97,17 +94,17 @@ public abstract class AbstractXhrTransport implements XhrTransport { /** * Configure headers to be added to every executed HTTP request. * @param requestHeaders the headers to add to requests + * @deprecated as of 4.2 in favor of {@link SockJsClient#setHttpHeaderNames}. */ + @Deprecated public void setRequestHeaders(HttpHeaders requestHeaders) { this.requestHeaders.clear(); - this.xhrSendRequestHeaders.clear(); if (requestHeaders != null) { this.requestHeaders.putAll(requestHeaders); - this.xhrSendRequestHeaders.putAll(requestHeaders); - this.xhrSendRequestHeaders.setContentType(MediaType.APPLICATION_JSON); } } + @Deprecated public HttpHeaders getRequestHeaders() { return this.requestHeaders; } @@ -115,6 +112,7 @@ public abstract class AbstractXhrTransport implements XhrTransport { // Transport methods + @SuppressWarnings("deprecation") @Override public ListenableFuture connect(TransportRequest request, WebSocketHandler handler) { SettableListenableFuture connectFuture = new SettableListenableFuture(); @@ -128,8 +126,8 @@ public abstract class AbstractXhrTransport implements XhrTransport { } HttpHeaders handshakeHeaders = new HttpHeaders(); - handshakeHeaders.putAll(request.getHandshakeHeaders()); handshakeHeaders.putAll(getRequestHeaders()); + handshakeHeaders.putAll(request.getHandshakeHeaders()); connectInternal(request, handler, receiveUrl, handshakeHeaders, session, connectFuture); return connectFuture; @@ -142,11 +140,17 @@ public abstract class AbstractXhrTransport implements XhrTransport { // InfoReceiver methods @Override - public String executeInfoRequest(URI infoUrl) { + @SuppressWarnings("deprecation") + public String executeInfoRequest(URI infoUrl, HttpHeaders headers) { if (logger.isDebugEnabled()) { logger.debug("Executing SockJS Info request, url=" + infoUrl); } - ResponseEntity response = executeInfoRequestInternal(infoUrl); + HttpHeaders infoRequestHeaders = new HttpHeaders(); + infoRequestHeaders.putAll(getRequestHeaders()); + if (headers != null) { + infoRequestHeaders.putAll(headers); + } + ResponseEntity response = executeInfoRequestInternal(infoUrl, infoRequestHeaders); if (response.getStatusCode() != HttpStatus.OK) { if (logger.isErrorEnabled()) { logger.error("SockJS Info request (url=" + infoUrl + ") failed: " + response); @@ -159,16 +163,16 @@ public abstract class AbstractXhrTransport implements XhrTransport { return response.getBody(); } - protected abstract ResponseEntity executeInfoRequestInternal(URI infoUrl); + protected abstract ResponseEntity executeInfoRequestInternal(URI infoUrl, HttpHeaders headers); // XhrTransport methods @Override - public void executeSendRequest(URI url, TextMessage message) { + public void executeSendRequest(URI url, HttpHeaders headers, TextMessage message) { if (logger.isTraceEnabled()) { logger.trace("Starting XHR send, url=" + url); } - ResponseEntity response = executeSendRequestInternal(url, this.xhrSendRequestHeaders, message); + ResponseEntity response = executeSendRequestInternal(url, headers, message); if (response.getStatusCode() != HttpStatus.NO_CONTENT) { if (logger.isErrorEnabled()) { logger.error("XHR send request (url=" + url + ") failed: " + response); @@ -180,7 +184,8 @@ public abstract class AbstractXhrTransport implements XhrTransport { } } - protected abstract ResponseEntity executeSendRequestInternal(URI url, HttpHeaders headers, TextMessage message); + protected abstract ResponseEntity executeSendRequestInternal(URI url, + HttpHeaders headers, TextMessage message); @Override diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/DefaultTransportRequest.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/DefaultTransportRequest.java index bf1eacdcd78..06563ad6339 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/DefaultTransportRequest.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/DefaultTransportRequest.java @@ -52,6 +52,8 @@ class DefaultTransportRequest implements TransportRequest { private final HttpHeaders handshakeHeaders; + private final HttpHeaders httpRequestHeaders; + private final Transport transport; private final TransportType serverTransportType; @@ -69,7 +71,8 @@ class DefaultTransportRequest implements TransportRequest { private DefaultTransportRequest fallbackRequest; - public DefaultTransportRequest(SockJsUrlInfo sockJsUrlInfo, HttpHeaders handshakeHeaders, + public DefaultTransportRequest(SockJsUrlInfo sockJsUrlInfo, + HttpHeaders handshakeHeaders, HttpHeaders httpRequestHeaders, Transport transport, TransportType serverTransportType, SockJsMessageCodec codec) { Assert.notNull(sockJsUrlInfo, "'sockJsUrlInfo' is required"); @@ -78,6 +81,7 @@ class DefaultTransportRequest implements TransportRequest { Assert.notNull(codec, "'codec' is required"); this.sockJsUrlInfo = sockJsUrlInfo; this.handshakeHeaders = (handshakeHeaders != null ? handshakeHeaders : new HttpHeaders()); + this.httpRequestHeaders = (httpRequestHeaders != null ? httpRequestHeaders : new HttpHeaders()); this.transport = transport; this.serverTransportType = serverTransportType; this.codec = codec; @@ -94,6 +98,11 @@ class DefaultTransportRequest implements TransportRequest { return this.handshakeHeaders; } + @Override + public HttpHeaders getHttpRequestHeaders() { + return this.httpRequestHeaders; + } + @Override public URI getTransportUrl() { return this.sockJsUrlInfo.getTransportUrl(this.serverTransportType); diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/InfoReceiver.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/InfoReceiver.java index e921c97d720..b039c4dcca4 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/InfoReceiver.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/InfoReceiver.java @@ -17,6 +17,8 @@ package org.springframework.web.socket.sockjs.client; import java.net.URI; +import org.springframework.http.HttpHeaders; + /** * A component that can execute the SockJS "Info" request that needs to be * performed before the SockJS session starts in order to check server endpoint @@ -34,10 +36,11 @@ public interface InfoReceiver { /** * Perform an HTTP request to the SockJS "Info" URL. * and return the resulting JSON response content, or raise an exception. - * + *

Note that as of 4.2 this method accepts a {@code headers} parameter. * @param infoUrl the URL to obtain SockJS server information from + * @param headers the headers to use for the request * @return the body of the response */ - String executeInfoRequest(URI infoUrl); + String executeInfoRequest(URI infoUrl, HttpHeaders headers); } \ No newline at end of file diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/JettyXhrTransport.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/JettyXhrTransport.java index 969c71e2781..89088fbdbef 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/JettyXhrTransport.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/JettyXhrTransport.java @@ -106,11 +106,12 @@ public class JettyXhrTransport extends AbstractXhrTransport implements XhrTransp @Override - protected void connectInternal(TransportRequest request, WebSocketHandler handler, + protected void connectInternal(TransportRequest transportRequest, WebSocketHandler handler, URI url, HttpHeaders handshakeHeaders, XhrClientSockJsSession session, SettableListenableFuture connectFuture) { - SockJsResponseListener listener = new SockJsResponseListener(url, getRequestHeaders(), session, connectFuture); + HttpHeaders httpHeaders = transportRequest.getHttpRequestHeaders(); + SockJsResponseListener listener = new SockJsResponseListener(url, httpHeaders, session, connectFuture); executeReceiveRequest(url, handshakeHeaders, listener); } @@ -124,8 +125,8 @@ public class JettyXhrTransport extends AbstractXhrTransport implements XhrTransp } @Override - protected ResponseEntity executeInfoRequestInternal(URI infoUrl) { - return executeRequest(infoUrl, HttpMethod.GET, getRequestHeaders(), null); + protected ResponseEntity executeInfoRequestInternal(URI infoUrl, HttpHeaders headers) { + return executeRequest(infoUrl, HttpMethod.GET, headers, null); } @Override diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/RestTemplateXhrTransport.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/RestTemplateXhrTransport.java index 5d6d25cf3cd..62f913cc149 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/RestTemplateXhrTransport.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/RestTemplateXhrTransport.java @@ -94,15 +94,16 @@ public class RestTemplateXhrTransport extends AbstractXhrTransport implements Xh @Override - protected void connectInternal(final TransportRequest request, final WebSocketHandler handler, + protected void connectInternal(final TransportRequest transportRequest, final WebSocketHandler handler, final URI receiveUrl, final HttpHeaders handshakeHeaders, final XhrClientSockJsSession session, final SettableListenableFuture connectFuture) { getTaskExecutor().execute(new Runnable() { @Override public void run() { + HttpHeaders httpHeaders = transportRequest.getHttpRequestHeaders(); XhrRequestCallback requestCallback = new XhrRequestCallback(handshakeHeaders); - XhrRequestCallback requestCallbackAfterHandshake = new XhrRequestCallback(getRequestHeaders()); + XhrRequestCallback requestCallbackAfterHandshake = new XhrRequestCallback(httpHeaders); XhrReceiveExtractor responseExtractor = new XhrReceiveExtractor(session); while (true) { if (session.isDisconnected()) { @@ -132,8 +133,8 @@ public class RestTemplateXhrTransport extends AbstractXhrTransport implements Xh } @Override - public ResponseEntity executeInfoRequestInternal(URI infoUrl) { - RequestCallback requestCallback = new XhrRequestCallback(getRequestHeaders()); + protected ResponseEntity executeInfoRequestInternal(URI infoUrl, HttpHeaders headers) { + RequestCallback requestCallback = new XhrRequestCallback(headers); return this.restTemplate.execute(infoUrl, HttpMethod.GET, requestCallback, textResponseExtractor); } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/SockJsClient.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/SockJsClient.java index e73c912d417..2d34c44b489 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/SockJsClient.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/SockJsClient.java @@ -78,6 +78,8 @@ public class SockJsClient implements WebSocketClient, Lifecycle { private final List transports; + private String[] httpHeaderNames; + private InfoReceiver infoReceiver; private SockJsMessageCodec messageCodec; @@ -116,6 +118,30 @@ public class SockJsClient implements WebSocketClient, Lifecycle { } + /** + * The names of HTTP headers that should be copied from the handshake headers + * of each call to {@link SockJsClient#doHandshake(WebSocketHandler, WebSocketHttpHeaders, URI)} + * and also used with other HTTP requests issued as part of that SockJS + * connection, e.g. the initial info request, XHR send or receive requests. + * + *

By default if this property is not set, all handshake headers are also + * used for other HTTP requests. Set it if you want only a subset of handshake + * headers (e.g. auth headers) to be used for other HTTP requests. + * + * @param httpHeaderNames HTTP header names + */ + public void setHttpHeaderNames(String... httpHeaderNames) { + this.httpHeaderNames = httpHeaderNames; + } + + /** + * The configured HTTP header names to be copied from the handshake + * headers and also included in other HTTP requests. + */ + public String[] getHttpHeaderNames() { + return this.httpHeaderNames; + } + /** * Configure the {@code InfoReceiver} to use to perform the SockJS "Info" * request before the SockJS session starts. @@ -225,7 +251,7 @@ public class SockJsClient implements WebSocketClient, Lifecycle { SettableListenableFuture connectFuture = new SettableListenableFuture(); try { SockJsUrlInfo sockJsUrlInfo = new SockJsUrlInfo(url); - ServerInfo serverInfo = getServerInfo(sockJsUrlInfo); + ServerInfo serverInfo = getServerInfo(sockJsUrlInfo, getHttpRequestHeaders(headers)); createRequest(sockJsUrlInfo, headers, serverInfo).connect(handler, connectFuture); } catch (Throwable exception) { @@ -237,12 +263,27 @@ public class SockJsClient implements WebSocketClient, Lifecycle { return connectFuture; } - private ServerInfo getServerInfo(SockJsUrlInfo sockJsUrlInfo) { + private HttpHeaders getHttpRequestHeaders(HttpHeaders webSocketHttpHeaders) { + if (getHttpHeaderNames() == null) { + return webSocketHttpHeaders; + } + else { + HttpHeaders httpHeaders = new HttpHeaders(); + for (String name : getHttpHeaderNames()) { + if (webSocketHttpHeaders.containsKey(name)) { + httpHeaders.put(name, webSocketHttpHeaders.get(name)); + } + } + return httpHeaders; + } + } + + private ServerInfo getServerInfo(SockJsUrlInfo sockJsUrlInfo, HttpHeaders headers) { URI infoUrl = sockJsUrlInfo.getInfoUrl(); ServerInfo info = this.serverInfoCache.get(infoUrl); if (info == null) { long start = System.currentTimeMillis(); - String response = this.infoReceiver.executeInfoRequest(infoUrl); + String response = this.infoReceiver.executeInfoRequest(infoUrl, headers); long infoRequestTime = System.currentTimeMillis() - start; info = new ServerInfo(response, infoRequestTime); this.serverInfoCache.put(infoUrl, info); @@ -255,7 +296,8 @@ public class SockJsClient implements WebSocketClient, Lifecycle { for (Transport transport : this.transports) { for (TransportType type : transport.getTransportTypes()) { if (serverInfo.isWebSocketEnabled() || !TransportType.WEBSOCKET.equals(type)) { - requests.add(new DefaultTransportRequest(urlInfo, headers, transport, type, getMessageCodec())); + requests.add(new DefaultTransportRequest(urlInfo, headers, getHttpRequestHeaders(headers), + transport, type, getMessageCodec())); } } } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/TransportRequest.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/TransportRequest.java index d0fc7df3195..94bd3c65c67 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/TransportRequest.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/TransportRequest.java @@ -47,6 +47,13 @@ public interface TransportRequest { */ HttpHeaders getHandshakeHeaders(); + /** + * Return the headers to add to all other HTTP requests besides the handshake + * request such XHR receive and send requests. + * @since 4.2 + */ + HttpHeaders getHttpRequestHeaders(); + /** * Return the transport URL for the given transport. * For an {@link XhrTransport} this is the URL for receiving messages. diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/UndertowXhrTransport.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/UndertowXhrTransport.java index 68c064fa722..8f3757833ca 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/UndertowXhrTransport.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/UndertowXhrTransport.java @@ -134,11 +134,11 @@ public class UndertowXhrTransport extends AbstractXhrTransport implements XhrTra HttpHeaders handshakeHeaders, XhrClientSockJsSession session, SettableListenableFuture connectFuture) { - executeReceiveRequest(receiveUrl, handshakeHeaders, session, connectFuture); + executeReceiveRequest(request, receiveUrl, handshakeHeaders, session, connectFuture); } - private void executeReceiveRequest(final URI url, final HttpHeaders headers, - final XhrClientSockJsSession session, + private void executeReceiveRequest(final TransportRequest transportRequest, + final URI url, final HttpHeaders headers, final XhrClientSockJsSession session, final SettableListenableFuture connectFuture) { if (logger.isTraceEnabled()) { @@ -154,8 +154,9 @@ public class UndertowXhrTransport extends AbstractXhrTransport implements XhrTra HttpString headerName = HttpString.tryFromString(HttpHeaders.HOST); request.getRequestHeaders().add(headerName, url.getHost()); addHttpHeaders(request, headers); - connection.sendRequest(request, createReceiveCallback(url, - getRequestHeaders(), session, connectFuture)); + HttpHeaders httpHeaders = transportRequest.getHttpRequestHeaders(); + connection.sendRequest(request, createReceiveCallback(transportRequest, + url, httpHeaders, session, connectFuture)); } @Override @@ -175,8 +176,8 @@ public class UndertowXhrTransport extends AbstractXhrTransport implements XhrTra } } - private ClientCallback createReceiveCallback(final URI url, final HttpHeaders headers, - final XhrClientSockJsSession sockJsSession, + private ClientCallback createReceiveCallback(final TransportRequest transportRequest, + final URI url, final HttpHeaders headers, final XhrClientSockJsSession sockJsSession, final SettableListenableFuture connectFuture) { return new ClientCallback() { @@ -194,8 +195,9 @@ public class UndertowXhrTransport extends AbstractXhrTransport implements XhrTra onFailure(new HttpServerErrorException(status, "Unexpected XHR receive status")); } else { - SockJsResponseListener listener = new SockJsResponseListener(result.getConnection(), - url, headers, sockJsSession, connectFuture); + SockJsResponseListener listener = new SockJsResponseListener( + transportRequest, result.getConnection(), url, headers, + sockJsSession, connectFuture); listener.setup(result.getResponseChannel()); } if (logger.isTraceEnabled()) { @@ -254,8 +256,8 @@ public class UndertowXhrTransport extends AbstractXhrTransport implements XhrTra } @Override - protected ResponseEntity executeInfoRequestInternal(URI infoUrl) { - return executeRequest(infoUrl, Methods.GET, getRequestHeaders(), null); + protected ResponseEntity executeInfoRequestInternal(URI infoUrl, HttpHeaders headers) { + return executeRequest(infoUrl, Methods.GET, headers, null); } @Override @@ -360,6 +362,8 @@ public class UndertowXhrTransport extends AbstractXhrTransport implements XhrTra private class SockJsResponseListener implements ChannelListener { + private final TransportRequest request; + private final ClientConnection connection; private final URI url; @@ -372,10 +376,12 @@ public class UndertowXhrTransport extends AbstractXhrTransport implements XhrTra private final ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); - public SockJsResponseListener(ClientConnection connection, URI url, + + public SockJsResponseListener(TransportRequest request, ClientConnection connection, URI url, HttpHeaders headers, XhrClientSockJsSession sockJsSession, SettableListenableFuture connectFuture) { + this.request = request; this.connection = connection; this.url = url; this.headers = headers; @@ -455,7 +461,7 @@ public class UndertowXhrTransport extends AbstractXhrTransport implements XhrTra logger.trace("XHR receive request completed."); } IoUtils.safeClose(this.connection); - executeReceiveRequest(this.url, this.headers, this.session, this.connectFuture); + executeReceiveRequest(this.request, this.url, this.headers, this.session, this.connectFuture); } public void onFailure(Throwable failure) { diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/XhrClientSockJsSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/XhrClientSockJsSession.java index 1381f33ee4f..59c7d25b479 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/XhrClientSockJsSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/XhrClientSockJsSession.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2014 the original author or authors. + * Copyright 2002-2015 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,6 +20,8 @@ import java.net.InetSocketAddress; import java.net.URI; import java.util.List; +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; import org.springframework.util.Assert; import org.springframework.util.concurrent.SettableListenableFuture; import org.springframework.web.socket.CloseStatus; @@ -39,10 +41,14 @@ import org.springframework.web.socket.sockjs.transport.TransportType; */ public class XhrClientSockJsSession extends AbstractClientSockJsSession { - private final URI sendUrl; - private final XhrTransport transport; + private HttpHeaders headers; + + private HttpHeaders sendHeaders; + + private final URI sendUrl; + private int textMessageSizeLimit = -1; private int binaryMessageSizeLimit = -1; @@ -53,11 +59,21 @@ public class XhrClientSockJsSession extends AbstractClientSockJsSession { super(request, handler, connectFuture); Assert.notNull(transport, "'restTemplate' is required"); - this.sendUrl = request.getSockJsUrlInfo().getTransportUrl(TransportType.XHR_SEND); this.transport = transport; + this.headers = request.getHttpRequestHeaders(); + this.sendHeaders = new HttpHeaders(); + if (this.headers != null) { + this.sendHeaders.putAll(this.headers); + } + this.sendHeaders.setContentType(MediaType.APPLICATION_JSON); + this.sendUrl = request.getSockJsUrlInfo().getTransportUrl(TransportType.XHR_SEND); } + public HttpHeaders getHeaders() { + return this.headers; + } + @Override public InetSocketAddress getLocalAddress() { return null; @@ -100,7 +116,7 @@ public class XhrClientSockJsSession extends AbstractClientSockJsSession { @Override protected void sendInternal(TextMessage message) { - this.transport.executeSendRequest(this.sendUrl, message); + this.transport.executeSendRequest(this.sendUrl, this.sendHeaders, message); } @Override diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/XhrTransport.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/XhrTransport.java index 6fcf7f16518..d5725ed54ab 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/XhrTransport.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/XhrTransport.java @@ -17,14 +17,14 @@ package org.springframework.web.socket.sockjs.client; import java.net.URI; +import org.springframework.http.HttpHeaders; import org.springframework.web.socket.TextMessage; /** * A SockJS {@link Transport} that uses HTTP requests to simulate a WebSocket * interaction. The {@code connect} method of the base {@code Transport} interface * is used to receive messages from the server while the - * {@link #executeSendRequest(java.net.URI, org.springframework.web.socket.TextMessage) - * executeSendRequest(URI, TextMessage)} method here is used to send messages. + * {@link #executeSendRequest} method here is used to send messages. * * @author Rossen Stoyanchev * @since 4.1 @@ -35,7 +35,6 @@ public interface XhrTransport extends Transport, InfoReceiver { * An {@code XhrTransport} supports both the "xhr_streaming" and "xhr" SockJS * server transports. From a client perspective there is no implementation * difference. - * *

By default an {@code XhrTransport} will be used with "xhr_streaming" * first and then with "xhr", if the streaming fails to connect. In some * cases it may be useful to suppress streaming so that only "xhr" is used. @@ -44,9 +43,10 @@ public interface XhrTransport extends Transport, InfoReceiver { /** * Execute a request to send the message to the server. + *

Note that as of 4.2 this method accepts a {@code headers} parameter. * @param transportUrl the URL for sending messages. * @param message the message to send */ - void executeSendRequest(URI transportUrl, TextMessage message); + void executeSendRequest(URI transportUrl, HttpHeaders headers, TextMessage message); } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/AbstractSockJsIntegrationTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/AbstractSockJsIntegrationTests.java index 0696bfe4be3..4902fd5a7b7 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/AbstractSockJsIntegrationTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/AbstractSockJsIntegrationTests.java @@ -49,6 +49,8 @@ import org.junit.rules.TestName; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; +import org.springframework.http.HttpHeaders; +import org.springframework.http.server.ServletServerHttpRequest; import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler; import org.springframework.tests.Assume; import org.springframework.tests.TestGroup; @@ -100,7 +102,7 @@ public abstract class AbstractSockJsIntegrationTests { @BeforeClass public static void performanceTestGroupAssumption() throws Exception { - Assume.group(TestGroup.PERFORMANCE); +// Assume.group(TestGroup.PERFORMANCE); } @@ -164,19 +166,36 @@ public abstract class AbstractSockJsIntegrationTests { @Test public void echoWebSocket() throws Exception { - testEcho(100, createWebSocketTransport()); + testEcho(100, createWebSocketTransport(), null); } @Test public void echoXhrStreaming() throws Exception { - testEcho(100, createXhrTransport()); + testEcho(100, createXhrTransport(), null); } @Test public void echoXhr() throws Exception { AbstractXhrTransport xhrTransport = createXhrTransport(); xhrTransport.setXhrStreamingDisabled(true); - testEcho(100, xhrTransport); + testEcho(100, xhrTransport, null); + } + + // SPR-13254 + + @Test + public void echoXhrWithHeaders() throws Exception { + AbstractXhrTransport xhrTransport = createXhrTransport(); + xhrTransport.setXhrStreamingDisabled(true); + + WebSocketHttpHeaders headers = new WebSocketHttpHeaders(); + headers.add("auth", "123"); + testEcho(10, xhrTransport, headers); + + for (Map.Entry entry : this.testFilter.requests.entrySet()) { + HttpHeaders httpHeaders = entry.getValue(); + assertEquals("No auth header for: " + entry.getKey(), "123", httpHeaders.getFirst("auth")); + } } @Test @@ -246,14 +265,15 @@ public abstract class AbstractSockJsIntegrationTests { } - private void testEcho(int messageCount, Transport transport) throws Exception { + private void testEcho(int messageCount, Transport transport, WebSocketHttpHeaders headers) throws Exception { List messages = new ArrayList<>(); for (int i = 0; i < messageCount; i++) { messages.add(new TextMessage("m" + i)); } TestClientHandler handler = new TestClientHandler(); initSockJsClient(transport); - WebSocketSession session = this.sockJsClient.doHandshake(handler, this.baseUrl + "/echo").get(); + URI url = new URI(this.baseUrl + "/echo"); + WebSocketSession session = this.sockJsClient.doHandshake(handler, headers, url).get(); for (TextMessage message : messages) { session.sendMessage(message); } @@ -386,7 +406,7 @@ public abstract class AbstractSockJsIntegrationTests { private static class TestFilter implements Filter { - private final List requests = new ArrayList<>(); + private final Map requests = new HashMap<>(); private final Map sleepDelayMap = new HashMap<>(); @@ -397,10 +417,13 @@ public abstract class AbstractSockJsIntegrationTests { public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException { - this.requests.add(request); + HttpServletRequest httpRequest = (HttpServletRequest) request; + String uri = httpRequest.getRequestURI(); + HttpHeaders headers = new ServletServerHttpRequest(httpRequest).getHeaders(); + this.requests.put(uri, headers); for (String suffix : this.sleepDelayMap.keySet()) { - if (((HttpServletRequest) request).getRequestURI().endsWith(suffix)) { + if ((httpRequest).getRequestURI().endsWith(suffix)) { try { Thread.sleep(this.sleepDelayMap.get(suffix)); break; @@ -411,7 +434,7 @@ public abstract class AbstractSockJsIntegrationTests { } } for (String suffix : this.sendErrorMap.keySet()) { - if (((HttpServletRequest) request).getRequestURI().endsWith(suffix)) { + if ((httpRequest).getRequestURI().endsWith(suffix)) { ((HttpServletResponse) response).sendError(this.sendErrorMap.get(suffix)); return; } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/ClientSockJsSessionTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/ClientSockJsSessionTests.java index 8e153335bcd..791aa387af5 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/ClientSockJsSessionTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/ClientSockJsSessionTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2014 the original author or authors. + * Copyright 2002-2015 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -64,7 +64,7 @@ public class ClientSockJsSessionTests { public void setup() throws Exception { SockJsUrlInfo urlInfo = new SockJsUrlInfo(new URI("http://example.com")); Transport transport = mock(Transport.class); - TransportRequest request = new DefaultTransportRequest(urlInfo, null, transport, TransportType.XHR, CODEC); + TransportRequest request = new DefaultTransportRequest(urlInfo, null, null, transport, TransportType.XHR, CODEC); this.handler = mock(WebSocketHandler.class); this.connectFuture = new SettableListenableFuture<>(); this.session = new TestClientSockJsSession(request, this.handler, this.connectFuture); diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/DefaultTransportRequestTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/DefaultTransportRequestTests.java index d9849ca0ced..8875081a992 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/DefaultTransportRequestTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/DefaultTransportRequestTests.java @@ -127,7 +127,7 @@ public class DefaultTransportRequestTests { protected DefaultTransportRequest createTransportRequest(Transport transport, TransportType type) throws Exception { SockJsUrlInfo urlInfo = new SockJsUrlInfo(new URI("http://example.com")); - return new DefaultTransportRequest(urlInfo, new HttpHeaders(), transport, type, CODEC); + return new DefaultTransportRequest(urlInfo, new HttpHeaders(), new HttpHeaders(), transport, type, CODEC); } } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/RestTemplateXhrTransportTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/RestTemplateXhrTransportTests.java index 74fa763d5c9..ca8f2c83f62 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/RestTemplateXhrTransportTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/RestTemplateXhrTransportTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2014 the original author or authors. + * Copyright 2002-2015 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -182,7 +182,8 @@ public class RestTemplateXhrTransportTests { SockJsUrlInfo urlInfo = new SockJsUrlInfo(new URI("http://example.com")); HttpHeaders headers = new HttpHeaders(); headers.add("h-foo", "h-bar"); - TransportRequest request = new DefaultTransportRequest(urlInfo, headers, transport, TransportType.XHR, CODEC); + TransportRequest request = new DefaultTransportRequest(urlInfo, headers, headers, + transport, TransportType.XHR, CODEC); return transport.connect(request, this.webSocketHandler); } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/SockJsClientTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/SockJsClientTests.java index 918edddbbdb..7db25a4ea93 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/SockJsClientTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/SockJsClientTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2014 the original author or authors. + * Copyright 2002-2015 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,22 +16,34 @@ package org.springframework.web.socket.sockjs.client; +import java.net.URI; import java.net.URISyntaxException; import java.util.ArrayList; import java.util.List; import org.junit.Before; import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.springframework.http.HttpHeaders; import org.springframework.http.HttpStatus; import org.springframework.util.concurrent.ListenableFutureCallback; import org.springframework.web.client.HttpServerErrorException; import org.springframework.web.socket.WebSocketHandler; +import org.springframework.web.socket.WebSocketHttpHeaders; import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.sockjs.client.TestTransport.XhrTestTransport; -import static org.junit.Assert.*; -import static org.mockito.BDDMockito.*; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.mockito.BDDMockito.any; +import static org.mockito.BDDMockito.given; +import static org.mockito.BDDMockito.mock; +import static org.mockito.BDDMockito.times; +import static org.mockito.BDDMockito.verify; +import static org.mockito.BDDMockito.verifyNoMoreInteractions; +import static org.mockito.BDDMockito.when; /** * Unit tests for {@link org.springframework.web.socket.sockjs.client.SockJsClient}. @@ -102,11 +114,51 @@ public class SockJsClientTests { assertTrue(this.xhrTransport.getRequest().getTransportUrl().toString().endsWith("xhr")); } + // SPR-13254 + + @Test + public void connectWithHandshakeHeaders() throws Exception { + ArgumentCaptor headersCaptor = setupInfoRequest(false); + this.xhrTransport.setStreamingDisabled(true); + + WebSocketHttpHeaders headers = new WebSocketHttpHeaders(); + headers.set("foo", "bar"); + headers.set("auth", "123"); + this.sockJsClient.doHandshake(handler, headers, new URI(URL)).addCallback(this.connectCallback); + + HttpHeaders httpHeaders = headersCaptor.getValue(); + assertEquals(2, httpHeaders.size()); + assertEquals("bar", httpHeaders.getFirst("foo")); + assertEquals("123", httpHeaders.getFirst("auth")); + + httpHeaders = this.xhrTransport.getRequest().getHttpRequestHeaders(); + assertEquals(2, httpHeaders.size()); + assertEquals("bar", httpHeaders.getFirst("foo")); + assertEquals("123", httpHeaders.getFirst("auth")); + } + + @Test + public void connectAndUseSubsetOfHandshakeHeadersForHttpRequests() throws Exception { + ArgumentCaptor headersCaptor = setupInfoRequest(false); + this.xhrTransport.setStreamingDisabled(true); + + WebSocketHttpHeaders headers = new WebSocketHttpHeaders(); + headers.set("foo", "bar"); + headers.set("auth", "123"); + this.sockJsClient.setHttpHeaderNames("auth"); + this.sockJsClient.doHandshake(handler, headers, new URI(URL)).addCallback(this.connectCallback); + + assertEquals(1, headersCaptor.getValue().size()); + assertEquals("123", headersCaptor.getValue().getFirst("auth")); + assertEquals(1, this.xhrTransport.getRequest().getHttpRequestHeaders().size()); + assertEquals("123", this.xhrTransport.getRequest().getHttpRequestHeaders().getFirst("auth")); + } + @Test public void connectSockJsInfo() throws Exception { setupInfoRequest(true); this.sockJsClient.doHandshake(handler, URL); - verify(this.infoReceiver, times(1)).executeInfoRequest(any()); + verify(this.infoReceiver, times(1)).executeInfoRequest(any(), any()); } @Test @@ -115,22 +167,27 @@ public class SockJsClientTests { this.sockJsClient.doHandshake(handler, URL); this.sockJsClient.doHandshake(handler, URL); this.sockJsClient.doHandshake(handler, URL); - verify(this.infoReceiver, times(1)).executeInfoRequest(any()); + verify(this.infoReceiver, times(1)).executeInfoRequest(any(), any()); } @Test public void connectInfoRequestFailure() throws URISyntaxException { HttpServerErrorException exception = new HttpServerErrorException(HttpStatus.SERVICE_UNAVAILABLE); - given(this.infoReceiver.executeInfoRequest(any())).willThrow(exception); + given(this.infoReceiver.executeInfoRequest(any(), any())).willThrow(exception); this.sockJsClient.doHandshake(handler, URL).addCallback(this.connectCallback); verify(this.connectCallback).onFailure(exception); assertFalse(this.webSocketTransport.invoked()); assertFalse(this.xhrTransport.invoked()); } - private void setupInfoRequest(boolean webSocketEnabled) { - given(this.infoReceiver.executeInfoRequest(any())).willReturn("{\"entropy\":123," + - "\"origins\":[\"*:*\"],\"cookie_needed\":true,\"websocket\":" + webSocketEnabled + "}"); + private ArgumentCaptor setupInfoRequest(boolean webSocketEnabled) { + ArgumentCaptor headersCaptor = ArgumentCaptor.forClass(HttpHeaders.class); + when(this.infoReceiver.executeInfoRequest(any(), headersCaptor.capture())).thenReturn( + "{\"entropy\":123," + + "\"origins\":[\"*:*\"]," + + "\"cookie_needed\":true," + + "\"websocket\":" + webSocketEnabled + "}"); + return headersCaptor; } } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/TestTransport.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/TestTransport.java index 449f429e1a1..329b72c6990 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/TestTransport.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/TestTransport.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2014 the original author or authors. + * Copyright 2002-2015 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,9 +18,12 @@ package org.springframework.web.socket.sockjs.client; import java.net.URI; import java.util.Arrays; +import java.util.Collections; import java.util.List; import org.mockito.ArgumentCaptor; + +import org.springframework.http.HttpHeaders; import org.springframework.util.concurrent.ListenableFuture; import org.springframework.util.concurrent.ListenableFutureCallback; import org.springframework.web.socket.TextMessage; @@ -28,7 +31,8 @@ import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.sockjs.transport.TransportType; -import static org.mockito.Mockito.*; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; /** * Test SockJS Transport. @@ -51,7 +55,7 @@ class TestTransport implements Transport { @Override public List getTransportTypes() { - return Arrays.asList(TransportType.WEBSOCKET); + return Collections.singletonList(TransportType.WEBSOCKET); } public TransportRequest getRequest() { @@ -95,7 +99,7 @@ class TestTransport implements Transport { @Override public List getTransportTypes() { return (isXhrStreamingDisabled() ? - Arrays.asList(TransportType.XHR) : + Collections.singletonList(TransportType.XHR) : Arrays.asList(TransportType.XHR_STREAMING, TransportType.XHR)); } @@ -109,11 +113,11 @@ class TestTransport implements Transport { } @Override - public void executeSendRequest(URI transportUrl, TextMessage message) { + public void executeSendRequest(URI transportUrl, HttpHeaders headers, TextMessage message) { } @Override - public String executeInfoRequest(URI infoUrl) { + public String executeInfoRequest(URI infoUrl, HttpHeaders headers) { return null; } } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/XhrTransportTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/XhrTransportTests.java index b54c93079a0..778831a1ff0 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/XhrTransportTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/XhrTransportTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2014 the original author or authors. + * Copyright 2002-2015 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -46,25 +46,25 @@ public class XhrTransportTests { public void infoResponse() throws Exception { TestXhrTransport transport = new TestXhrTransport(); transport.infoResponseToReturn = new ResponseEntity<>("body", HttpStatus.OK); - assertEquals("body", transport.executeInfoRequest(new URI("http://example.com/info"))); + assertEquals("body", transport.executeInfoRequest(new URI("http://example.com/info"), null)); } @Test(expected = HttpServerErrorException.class) public void infoResponseError() throws Exception { TestXhrTransport transport = new TestXhrTransport(); transport.infoResponseToReturn = new ResponseEntity<>("body", HttpStatus.BAD_REQUEST); - assertEquals("body", transport.executeInfoRequest(new URI("http://example.com/info"))); + assertEquals("body", transport.executeInfoRequest(new URI("http://example.com/info"), null)); } @Test public void sendMessage() throws Exception { HttpHeaders requestHeaders = new HttpHeaders(); requestHeaders.set("foo", "bar"); + requestHeaders.setContentType(MediaType.APPLICATION_JSON); TestXhrTransport transport = new TestXhrTransport(); - transport.setRequestHeaders(requestHeaders); transport.sendMessageResponseToReturn = new ResponseEntity<>(HttpStatus.NO_CONTENT); URI url = new URI("http://example.com"); - transport.executeSendRequest(url, new TextMessage("payload")); + transport.executeSendRequest(url, requestHeaders, new TextMessage("payload")); assertEquals(2, transport.actualSendRequestHeaders.size()); assertEquals("bar", transport.actualSendRequestHeaders.getFirst("foo")); assertEquals(MediaType.APPLICATION_JSON, transport.actualSendRequestHeaders.getContentType()); @@ -75,9 +75,10 @@ public class XhrTransportTests { TestXhrTransport transport = new TestXhrTransport(); transport.sendMessageResponseToReturn = new ResponseEntity<>(HttpStatus.BAD_REQUEST); URI url = new URI("http://example.com"); - transport.executeSendRequest(url, new TextMessage("payload")); + transport.executeSendRequest(url, null, new TextMessage("payload")); } + @SuppressWarnings("deprecation") @Test public void connect() throws Exception { HttpHeaders handshakeHeaders = new HttpHeaders(); @@ -101,6 +102,7 @@ public class XhrTransportTests { verify(request).addTimeoutTask(captor.capture()); verify(request).getTransportUrl(); verify(request).getHandshakeHeaders(); + verify(request).getHttpRequestHeaders(); verifyNoMoreInteractions(request); assertEquals(2, transport.actualHandshakeHeaders.size()); @@ -127,7 +129,7 @@ public class XhrTransportTests { @Override - protected ResponseEntity executeInfoRequestInternal(URI infoUrl) { + protected ResponseEntity executeInfoRequestInternal(URI infoUrl, HttpHeaders headers) { return this.infoResponseToReturn; }