From 9ca4672300aa59949ac1073c56feedee783af6ec Mon Sep 17 00:00:00 2001 From: Rossen Stoyanchev Date: Wed, 1 May 2013 14:18:25 -0400 Subject: [PATCH] Fix handshake handling issue --- .../websocket/CloseStatus.java | 4 ++++ .../endpoint/StandardWebSocketClient.java | 20 ++++++++++++++++++- .../client/jetty/JettyWebSocketClient.java | 2 ++ .../server/DefaultHandshakeHandler.java | 9 ++++++--- 4 files changed, 31 insertions(+), 4 deletions(-) diff --git a/spring-websocket/src/main/java/org/springframework/websocket/CloseStatus.java b/spring-websocket/src/main/java/org/springframework/websocket/CloseStatus.java index dafcb988e7b..8774fe46f95 100644 --- a/spring-websocket/src/main/java/org/springframework/websocket/CloseStatus.java +++ b/spring-websocket/src/main/java/org/springframework/websocket/CloseStatus.java @@ -198,6 +198,10 @@ public final class CloseStatus { return (this.code == otherStatus.code && ObjectUtils.nullSafeEquals(this.reason, otherStatus.reason)); } + public boolean equalsCode(CloseStatus other) { + return this.code == other.code; + } + @Override public String toString() { return "CloseStatus [code=" + this.code + ", reason=" + this.reason + "]"; diff --git a/spring-websocket/src/main/java/org/springframework/websocket/client/endpoint/StandardWebSocketClient.java b/spring-websocket/src/main/java/org/springframework/websocket/client/endpoint/StandardWebSocketClient.java index ffb043c340b..100ef641649 100644 --- a/spring-websocket/src/main/java/org/springframework/websocket/client/endpoint/StandardWebSocketClient.java +++ b/spring-websocket/src/main/java/org/springframework/websocket/client/endpoint/StandardWebSocketClient.java @@ -27,9 +27,12 @@ import javax.websocket.ClientEndpointConfig; import javax.websocket.ClientEndpointConfig.Configurator; import javax.websocket.ContainerProvider; import javax.websocket.Endpoint; +import javax.websocket.HandshakeResponse; import javax.websocket.Session; import javax.websocket.WebSocketContainer; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; import org.springframework.http.HttpHeaders; import org.springframework.web.util.UriComponentsBuilder; import org.springframework.websocket.WebSocketHandler; @@ -47,6 +50,8 @@ import org.springframework.websocket.client.WebSocketConnectFailureException; */ public class StandardWebSocketClient implements WebSocketClient { + private static final Log logger = LogFactory.getLog(StandardWebSocketClient.class); + private static final Set EXCLUDED_HEADERS = new HashSet( Arrays.asList("Sec-WebSocket-Accept", "Sec-WebSocket-Extensions", "Sec-WebSocket-Key", "Sec-WebSocket-Protocol", "Sec-WebSocket-Version")); @@ -83,9 +88,22 @@ public class StandardWebSocketClient implements WebSocketClient { public void beforeRequest(Map> headers) { for (String headerName : httpHeaders.keySet()) { if (!EXCLUDED_HEADERS.contains(headerName)) { - headers.put(headerName, httpHeaders.get(headerName)); + List value = httpHeaders.get(headerName); + if (logger.isTraceEnabled()) { + logger.trace("Adding header [" + headerName + "=" + value + "]"); + } + headers.put(headerName, value); } } + if (logger.isTraceEnabled()) { + logger.trace("Handshake request headers: " + headers); + } + } + @Override + public void afterResponse(HandshakeResponse handshakeResponse) { + if (logger.isTraceEnabled()) { + logger.trace("Handshake response headers: " + handshakeResponse.getHeaders()); + } } }); } diff --git a/spring-websocket/src/main/java/org/springframework/websocket/client/jetty/JettyWebSocketClient.java b/spring-websocket/src/main/java/org/springframework/websocket/client/jetty/JettyWebSocketClient.java index 210e0dfc131..ae1cc25cc93 100644 --- a/spring-websocket/src/main/java/org/springframework/websocket/client/jetty/JettyWebSocketClient.java +++ b/spring-websocket/src/main/java/org/springframework/websocket/client/jetty/JettyWebSocketClient.java @@ -134,6 +134,8 @@ public class JettyWebSocketClient implements WebSocketClient, SmartLifecycle { public WebSocketSession doHandshake(WebSocketHandler webSocketHandler, HttpHeaders headers, URI uri) throws WebSocketConnectFailureException { + // TODO: populate headers + JettyWebSocketListenerAdapter listener = new JettyWebSocketListenerAdapter(webSocketHandler); try { diff --git a/spring-websocket/src/main/java/org/springframework/websocket/server/DefaultHandshakeHandler.java b/spring-websocket/src/main/java/org/springframework/websocket/server/DefaultHandshakeHandler.java index e557a0cc120..b369aca0850 100644 --- a/spring-websocket/src/main/java/org/springframework/websocket/server/DefaultHandshakeHandler.java +++ b/spring-websocket/src/main/java/org/springframework/websocket/server/DefaultHandshakeHandler.java @@ -20,6 +20,7 @@ import java.io.IOException; import java.nio.charset.Charset; import java.security.MessageDigest; import java.security.NoSuchAlgorithmException; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; @@ -34,6 +35,7 @@ import org.springframework.http.HttpStatus; import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; import org.springframework.util.ClassUtils; +import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; import org.springframework.websocket.WebSocketHandler; @@ -53,7 +55,7 @@ public class DefaultHandshakeHandler implements HandshakeHandler { protected Log logger = LogFactory.getLog(getClass()); - private List supportedProtocols; + private List supportedProtocols = new ArrayList(); private RequestUpgradeStrategy requestUpgradeStrategy; @@ -101,7 +103,8 @@ public class DefaultHandshakeHandler implements HandshakeHandler { handleInvalidUpgradeHeader(request, response); return false; } - if (!request.getHeaders().getConnection().contains("Upgrade")) { + if (!request.getHeaders().getConnection().contains("Upgrade") && + !request.getHeaders().getConnection().contains("upgrade")) { handleInvalidConnectHeader(request, response); return false; } @@ -188,7 +191,7 @@ public class DefaultHandshakeHandler implements HandshakeHandler { } protected String selectProtocol(List requestedProtocols) { - if (requestedProtocols != null) { + if (CollectionUtils.isEmpty(requestedProtocols)) { for (String protocol : requestedProtocols) { if (this.supportedProtocols.contains(protocol)) { return protocol;