Polish WebSocketSession

Update methods available on WebSocketSession interface.
Introduce DelegatingWebSocketSession interface.
This commit is contained in:
Rossen Stoyanchev 2013-08-09 09:38:13 -04:00
parent 14ac023e01
commit 01feae0ad5
36 changed files with 735 additions and 590 deletions

View File

@ -30,6 +30,7 @@ import org.springframework.messaging.MessageHandler;
import org.springframework.messaging.MessagingException;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketMessage;
@ -133,7 +134,7 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler, MessageHan
protected final SubProtocolHandler getProtocolHandler(WebSocketSession session) {
SubProtocolHandler handler;
String protocol = session.getAcceptedProtocol();
if (protocol != null) {
if (!StringUtils.isEmpty(protocol)) {
handler = this.protocolHandlers.get(protocol);
Assert.state(handler != null,
"No handler for sub-protocol '" + protocol + "', handlers=" + this.protocolHandlers);

View File

@ -91,7 +91,18 @@ public class SubProtocolWebSocketHandlerTests {
}
@Test
public void noSubProtocol() throws Exception {
public void nullSubProtocol() throws Exception {
this.webSocketHandler.setDefaultProtocolHandler(defaultHandler);
this.webSocketHandler.afterConnectionEstablished(session);
verify(this.defaultHandler).afterSessionStarted(session, this.channel);
verify(this.stompHandler, times(0)).afterSessionStarted(session, this.channel);
verify(this.mqttHandler, times(0)).afterSessionStarted(session, this.channel);
}
@Test
public void emptySubProtocol() throws Exception {
this.session.setAcceptedProtocol("");
this.webSocketHandler.setDefaultProtocolHandler(defaultHandler);
this.webSocketHandler.afterConnectionEstablished(session);

View File

@ -16,6 +16,7 @@
package org.springframework.http.server;
import java.net.InetSocketAddress;
import java.security.Principal;
import java.util.Map;
@ -51,14 +52,14 @@ public interface ServerHttpRequest extends HttpRequest, HttpInputMessage {
Principal getPrincipal();
/**
* Return the host name of the endpoint on the other end.
* Return the address on which the request was received.
*/
String getRemoteHostName();
InetSocketAddress getLocalAddress();
/**
* Return the IP address of the endpoint on the other end.
* Return the address of the remote client.
*/
String getRemoteAddress();
InetSocketAddress getRemoteAddress();
/**
* Return a control that allows putting the request in asynchronous mode so the

View File

@ -22,6 +22,7 @@ import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStreamWriter;
import java.io.Writer;
import java.net.InetSocketAddress;
import java.net.URI;
import java.net.URISyntaxException;
import java.net.URLEncoder;
@ -147,13 +148,13 @@ public class ServletServerHttpRequest implements ServerHttpRequest {
}
@Override
public String getRemoteHostName() {
return this.servletRequest.getRemoteHost();
public InetSocketAddress getLocalAddress() {
return new InetSocketAddress(this.servletRequest.getLocalName(), this.servletRequest.getLocalPort());
}
@Override
public String getRemoteAddress() {
return this.servletRequest.getRemoteAddr();
public InetSocketAddress getRemoteAddress() {
return new InetSocketAddress(this.servletRequest.getRemoteHost(), this.servletRequest.getRemotePort());
}
@Override

View File

@ -17,9 +17,12 @@
package org.springframework.web.socket;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.URI;
import java.security.Principal;
import org.springframework.http.HttpHeaders;
/**
* A WebSocket session abstraction. Allows sending messages over a WebSocket connection
* and closing it.
@ -29,6 +32,7 @@ import java.security.Principal;
*/
public interface WebSocketSession {
/**
* Return a unique session identifier.
*/
@ -40,9 +44,9 @@ public interface WebSocketSession {
URI getUri();
/**
* Return whether the underlying socket is using a secure transport.
* Return the headers used in the handshake request.
*/
boolean isSecure();
HttpHeaders getHandshakeHeaders();
/**
* Return a {@link java.security.Principal} instance containing the name of the
@ -52,17 +56,18 @@ public interface WebSocketSession {
Principal getPrincipal();
/**
* Return the host name of the endpoint on the other end.
* Return the address on which the request was received.
*/
String getRemoteHostName();
InetSocketAddress getLocalAddress();
/**
* Return the IP address of the endpoint on the other end.
* Return the address of the remote client.
*/
String getRemoteAddress();
InetSocketAddress getRemoteAddress();
/**
* Return the negotiated sub-protocol or {@code null} if none was specified.
* Return the negotiated sub-protocol or {@code null} if none was specified or
* negotiated successfully.
*/
String getAcceptedProtocol();

View File

@ -32,19 +32,41 @@ import org.springframework.web.socket.WebSocketSession;
* @author Rossen Stoyanchev
* @since 4.0
*/
public abstract class AbstractWebSocketSesssionAdapter<T> implements ConfigurableWebSocketSession {
public abstract class AbstractWebSocketSesssion<T> implements DelegatingWebSocketSession<T> {
protected final Log logger = LogFactory.getLog(getClass());
private T delegateSession;
public abstract void initSession(T session);
/**
* @return the WebSocket session to delegate to
*/
public T getDelegateSession() {
return this.delegateSession;
}
@Override
public void afterSessionInitialized(T session) {
Assert.notNull(session, "session must not be null");
this.delegateSession = session;
}
protected final void checkDelegateSessionInitialized() {
Assert.state(this.delegateSession != null, "WebSocket session is not yet initialized");
}
@Override
public final void sendMessage(WebSocketMessage message) throws IOException {
checkDelegateSessionInitialized();
Assert.isTrue(isOpen(), "Cannot send message after connection closed.");
if (logger.isTraceEnabled()) {
logger.trace("Sending " + message + ", " + this);
}
Assert.isTrue(isOpen(), "Cannot send message after connection closed.");
if (message instanceof TextMessage) {
sendTextMessage((TextMessage) message);
}
@ -60,13 +82,15 @@ public abstract class AbstractWebSocketSesssionAdapter<T> implements Configurabl
protected abstract void sendBinaryMessage(BinaryMessage message) throws IOException ;
@Override
public void close() throws IOException {
public final void close() throws IOException {
close(CloseStatus.NORMAL);
}
@Override
public final void close(CloseStatus status) throws IOException {
checkDelegateSessionInitialized();
if (logger.isDebugEnabled()) {
logger.debug("Closing " + this);
}
@ -75,6 +99,7 @@ public abstract class AbstractWebSocketSesssionAdapter<T> implements Configurabl
protected abstract void closeInternal(CloseStatus status) throws IOException;
@Override
public String toString() {
return "WebSocket session id=" + getId();

View File

@ -16,34 +16,25 @@
package org.springframework.web.socket.adapter;
import java.net.URI;
import java.security.Principal;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.server.DefaultHandshakeHandler;
/**
* A WebSocketSession with configurable properties.
* A contract for {@link WebSocketSession} implementations that delegate to another
* WebSocket session (e.g. a native session).
*
* @param T the type of the delegate WebSocket session
*
* @author Rossen Stoyanchev
* @since 4.0
*/
public interface ConfigurableWebSocketSession extends WebSocketSession {
public interface DelegatingWebSocketSession<T> extends WebSocketSession {
void setUri(URI uri);
void setRemoteHostName(String name);
void setRemoteAddress(String address);
void setPrincipal(Principal principal);
/**
* Set the protocol accepted as part of the WebSocket handshake. This property can be
* used when the WebSocket handshake is performed through
* {@link DefaultHandshakeHandler} rather than the underlying WebSocket runtime, or
* when there is no WebSocket handshake (e.g. SockJS HTTP fallback options)
* Invoked when the delegate WebSocket session has been initialized.
*/
void setAcceptedProtocol(String protocol);
void afterSessionInitialized(T session);
}

View File

@ -28,21 +28,22 @@ import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.support.ExceptionWebSocketHandlerDecorator;
/**
* Adapts {@link WebSocketHandler} to the Jetty 9 {@link WebSocketListener}.
* Adapts {@link WebSocketHandler} to the Jetty 9 WebSocket API.
*
* @author Phillip Webb
* @author Rossen Stoyanchev
* @since 4.0
*/
public class JettyWebSocketListenerAdapter implements WebSocketListener {
public class JettyWebSocketHandlerAdapter implements WebSocketListener {
private static final Log logger = LogFactory.getLog(JettyWebSocketListenerAdapter.class);
private static final Log logger = LogFactory.getLog(JettyWebSocketHandlerAdapter.class);
private final WebSocketHandler webSocketHandler;
private final JettyWebSocketSessionAdapter wsSession;
private final JettyWebSocketSession wsSession;
public JettyWebSocketListenerAdapter(WebSocketHandler webSocketHandler, JettyWebSocketSessionAdapter wsSession) {
public JettyWebSocketHandlerAdapter(WebSocketHandler webSocketHandler, JettyWebSocketSession wsSession) {
Assert.notNull(webSocketHandler, "webSocketHandler must not be null");
Assert.notNull(wsSession, "wsSession must not be null");
this.webSocketHandler = webSocketHandler;
@ -52,8 +53,8 @@ public class JettyWebSocketListenerAdapter implements WebSocketListener {
@Override
public void onWebSocketConnect(Session session) {
this.wsSession.initSession(session);
try {
this.wsSession.afterSessionInitialized(session);
this.webSocketHandler.afterConnectionEstablished(this.wsSession);
}
catch (Throwable t) {

View File

@ -0,0 +1,121 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket.adapter;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.URI;
import java.security.Principal;
import org.springframework.http.HttpHeaders;
import org.springframework.util.ObjectUtils;
import org.springframework.web.socket.BinaryMessage;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
/**
* A {@link WebSocketSession} for use with the Jetty 9 WebSocket API.
*
* @author Phillip Webb
* @author Rossen Stoyanchev
* @since 4.0
*/
public class JettyWebSocketSession extends AbstractWebSocketSesssion<org.eclipse.jetty.websocket.api.Session> {
private HttpHeaders headers;
private final Principal principal;
/**
* Class constructor.
*
* @param principal the user associated with the session, or {@code null}
*/
public JettyWebSocketSession(Principal principal) {
this.principal = principal;
}
@Override
public String getId() {
checkDelegateSessionInitialized();
return ObjectUtils.getIdentityHexString(getDelegateSession());
}
@Override
public URI getUri() {
checkDelegateSessionInitialized();
return getDelegateSession().getUpgradeRequest().getRequestURI();
}
@Override
public HttpHeaders getHandshakeHeaders() {
checkDelegateSessionInitialized();
if (this.headers == null) {
this.headers = new HttpHeaders();
this.headers.putAll(getDelegateSession().getUpgradeRequest().getHeaders());
this.headers = HttpHeaders.readOnlyHttpHeaders(headers);
}
return this.headers;
}
@Override
public Principal getPrincipal() {
return this.principal;
}
@Override
public InetSocketAddress getLocalAddress() {
checkDelegateSessionInitialized();
return getDelegateSession().getLocalAddress();
}
@Override
public InetSocketAddress getRemoteAddress() {
checkDelegateSessionInitialized();
return getDelegateSession().getRemoteAddress();
}
@Override
public String getAcceptedProtocol() {
checkDelegateSessionInitialized();
return getDelegateSession().getUpgradeResponse().getAcceptedSubProtocol();
}
@Override
public boolean isOpen() {
return ((getDelegateSession() != null) && getDelegateSession().isOpen());
}
@Override
protected void sendTextMessage(TextMessage message) throws IOException {
getDelegateSession().getRemote().sendString(message.getPayload());
}
@Override
protected void sendBinaryMessage(BinaryMessage message) throws IOException {
getDelegateSession().getRemote().sendBytes(message.getPayload());
}
@Override
protected void closeInternal(CloseStatus status) throws IOException {
getDelegateSession().close(status.getCode(), status.getReason());
}
}

View File

@ -1,144 +0,0 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket.adapter;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.URI;
import java.security.Principal;
import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.api.UpgradeResponse;
import org.springframework.util.Assert;
import org.springframework.util.ObjectUtils;
import org.springframework.web.socket.BinaryMessage;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
/**
* Adapts a Jetty {@link org.eclipse.jetty.websocket.api.Session} to
* {@link WebSocketSession}.
*
* @author Phillip Webb
* @author Rossen Stoyanchev
* @since 4.0
*/
public class JettyWebSocketSessionAdapter
extends AbstractWebSocketSesssionAdapter<org.eclipse.jetty.websocket.api.Session> {
private Session session;
private Principal principal;
private String protocol;
@Override
public void initSession(Session session) {
Assert.notNull(session, "session must not be null");
this.session = session;
if (this.protocol == null) {
UpgradeResponse response = session.getUpgradeResponse();
if ((response != null) && response.getAcceptedSubProtocol() != null) {
this.protocol = response.getAcceptedSubProtocol();
}
}
}
@Override
public String getId() {
return ObjectUtils.getIdentityHexString(this.session);
}
@Override
public boolean isSecure() {
return this.session.isSecure();
}
@Override
public URI getUri() {
return this.session.getUpgradeRequest().getRequestURI();
}
@Override
public void setUri(URI uri) {
}
@Override
public Principal getPrincipal() {
return this.principal;
}
@Override
public void setPrincipal(Principal principal) {
this.principal = principal;
}
@Override
public String getRemoteHostName() {
return this.session.getRemoteAddress().getHostName();
}
@Override
public void setRemoteHostName(String address) {
// ignore
}
@Override
public String getRemoteAddress() {
InetSocketAddress address = this.session.getRemoteAddress();
return address.isUnresolved() ? null : address.getAddress().getHostAddress();
}
@Override
public void setRemoteAddress(String address) {
// ignore
}
@Override
public String getAcceptedProtocol() {
return this.protocol;
}
@Override
public void setAcceptedProtocol(String protocol) {
this.protocol = protocol;
}
@Override
public boolean isOpen() {
return this.session.isOpen();
}
@Override
protected void sendTextMessage(TextMessage message) throws IOException {
this.session.getRemote().sendString(message.getPayload());
}
@Override
protected void sendBinaryMessage(BinaryMessage message) throws IOException {
this.session.getRemote().sendBytes(message.getPayload());
}
@Override
protected void closeInternal(CloseStatus status) throws IOException {
this.session.close(status.getCode(), status.getReason());
}
}

View File

@ -33,21 +33,21 @@ import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.support.ExceptionWebSocketHandlerDecorator;
/**
* Adapts a {@link WebSocketHandler} to a standard {@link Endpoint}.
* Adapts a {@link WebSocketHandler} to the standard WebSocket for Java API.
*
* @author Rossen Stoyanchev
* @since 4.0
*/
public class StandardEndpointAdapter extends Endpoint {
public class StandardWebSocketHandlerAdapter extends Endpoint {
private static final Log logger = LogFactory.getLog(StandardEndpointAdapter.class);
private static final Log logger = LogFactory.getLog(StandardWebSocketHandlerAdapter.class);
private final WebSocketHandler handler;
private final StandardWebSocketSessionAdapter wsSession;
private final StandardWebSocketSession wsSession;
public StandardEndpointAdapter(WebSocketHandler handler, StandardWebSocketSessionAdapter wsSession) {
public StandardWebSocketHandlerAdapter(WebSocketHandler handler, StandardWebSocketSession wsSession) {
Assert.notNull(handler, "handler must not be null");
Assert.notNull(wsSession, "wsSession must not be null");
this.handler = handler;
@ -58,7 +58,7 @@ public class StandardEndpointAdapter extends Endpoint {
@Override
public void onOpen(final javax.websocket.Session session, EndpointConfig config) {
this.wsSession.initSession(session);
this.wsSession.afterSessionInitialized(session);
if (this.handler.supportsPartialMessages()) {
session.addMessageHandler(new MessageHandler.Partial<String>() {

View File

@ -0,0 +1,123 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket.adapter;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.URI;
import java.security.Principal;
import javax.websocket.CloseReason;
import javax.websocket.CloseReason.CloseCodes;
import org.springframework.http.HttpHeaders;
import org.springframework.util.StringUtils;
import org.springframework.web.socket.BinaryMessage;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
/**
* A {@link WebSocketSession} for use with the standard WebSocket for Java API.
*
* @author Rossen Stoyanchev
* @since 4.0
*/
public class StandardWebSocketSession extends AbstractWebSocketSesssion<javax.websocket.Session> {
private final HttpHeaders headers;
private final InetSocketAddress localAddress;
private final InetSocketAddress remoteAddress;
/**
* Class constructor.
*
* @param handshakeHeaders the headers of the handshake request
*/
public StandardWebSocketSession(HttpHeaders handshakeHeaders, InetSocketAddress localAddress,
InetSocketAddress remoteAddress) {
handshakeHeaders = (handshakeHeaders != null) ? handshakeHeaders : new HttpHeaders();
this.headers = HttpHeaders.readOnlyHttpHeaders(handshakeHeaders);
this.localAddress = localAddress;
this.remoteAddress = remoteAddress;
}
@Override
public String getId() {
checkDelegateSessionInitialized();
return getDelegateSession().getId();
}
@Override
public URI getUri() {
checkDelegateSessionInitialized();
return getDelegateSession().getRequestURI();
}
@Override
public HttpHeaders getHandshakeHeaders() {
return this.headers;
}
@Override
public Principal getPrincipal() {
checkDelegateSessionInitialized();
return getDelegateSession().getUserPrincipal();
}
@Override
public InetSocketAddress getLocalAddress() {
return this.localAddress;
}
@Override
public InetSocketAddress getRemoteAddress() {
return this.remoteAddress;
}
@Override
public String getAcceptedProtocol() {
checkDelegateSessionInitialized();
String protocol = getDelegateSession().getNegotiatedSubprotocol();
return StringUtils.isEmpty(protocol)? null : protocol;
}
@Override
public boolean isOpen() {
return ((getDelegateSession() != null) && getDelegateSession().isOpen());
}
@Override
protected void sendTextMessage(TextMessage message) throws IOException {
getDelegateSession().getBasicRemote().sendText(message.getPayload(), message.isLast());
}
@Override
protected void sendBinaryMessage(BinaryMessage message) throws IOException {
getDelegateSession().getBasicRemote().sendBinary(message.getPayload(), message.isLast());
}
@Override
protected void closeInternal(CloseStatus status) throws IOException {
getDelegateSession().close(new CloseReason(CloseCodes.getCloseCode(status.getCode()), status.getReason()));
}
}

View File

@ -1,145 +0,0 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket.adapter;
import java.io.IOException;
import java.net.URI;
import java.security.Principal;
import javax.websocket.CloseReason;
import javax.websocket.CloseReason.CloseCodes;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
import org.springframework.web.socket.BinaryMessage;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
/**
* Adapts a standard {@link javax.websocket.Session} to {@link WebSocketSession}.
*
* @author Rossen Stoyanchev
* @since 4.0
*/
public class StandardWebSocketSessionAdapter extends AbstractWebSocketSesssionAdapter<javax.websocket.Session> {
private javax.websocket.Session session;
private URI uri;
private String remoteHostName;
private String remoteAddress;
private String protocol;
@Override
public void initSession(javax.websocket.Session session) {
Assert.notNull(session, "session must not be null");
this.session = session;
if (this.protocol == null) {
if (StringUtils.hasText(session.getNegotiatedSubprotocol())) {
this.protocol = session.getNegotiatedSubprotocol();
}
}
}
@Override
public String getId() {
return this.session.getId();
}
@Override
public URI getUri() {
return this.uri;
}
@Override
public void setUri(URI uri) {
this.uri = uri;
}
@Override
public boolean isSecure() {
return this.session.isSecure();
}
@Override
public Principal getPrincipal() {
return this.session.getUserPrincipal();
}
@Override
public void setPrincipal(Principal principal) {
// ignore
}
@Override
public String getRemoteHostName() {
return this.remoteHostName;
}
@Override
public void setRemoteHostName(String name) {
this.remoteHostName = name;
}
@Override
public String getRemoteAddress() {
return this.remoteAddress;
}
@Override
public void setRemoteAddress(String address) {
this.remoteAddress = address;
}
@Override
public String getAcceptedProtocol() {
return this.protocol;
}
@Override
public void setAcceptedProtocol(String protocol) {
this.protocol = protocol;
}
@Override
public boolean isOpen() {
return this.session.isOpen();
}
@Override
protected void sendTextMessage(TextMessage message) throws IOException {
this.session.getBasicRemote().sendText(message.getPayload(), message.isLast());
}
@Override
protected void sendBinaryMessage(BinaryMessage message) throws IOException {
this.session.getBasicRemote().sendBinary(message.getPayload(), message.isLast());
}
@Override
protected void closeInternal(CloseStatus status) throws IOException {
this.session.close(new CloseReason(CloseCodes.getCloseCode(status.getCode()), status.getReason()));
}
}

View File

@ -73,6 +73,9 @@ public abstract class AbstractWebSocketClient implements WebSocketClient {
Assert.notNull(webSocketHandler, "webSocketHandler must not be null");
Assert.notNull(uri, "uri must not be null");
String scheme = uri.getScheme();
Assert.isTrue(((scheme != null) && ("ws".equals(scheme) || "wss".equals(scheme))), "Invalid scheme: " + scheme);
if (logger.isDebugEnabled()) {
logger.debug("Connecting to " + uri);
}

View File

@ -16,12 +16,10 @@
package org.springframework.web.socket.client;
import java.util.ArrayList;
import java.util.List;
import org.springframework.context.SmartLifecycle;
import org.springframework.http.HttpHeaders;
import org.springframework.util.CollectionUtils;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.support.LoggingWebSocketHandlerDecorator;
@ -43,9 +41,7 @@ public class WebSocketConnectionManager extends ConnectionManagerSupport {
private WebSocketSession webSocketSession;
private final List<String> protocols = new ArrayList<String>();
private HttpHeaders headers;
private HttpHeaders headers = new HttpHeaders();
private final boolean syncClientLifecycle;
@ -76,24 +72,36 @@ public class WebSocketConnectionManager extends ConnectionManagerSupport {
* any.
*/
public void setSubProtocols(List<String> protocols) {
this.protocols.clear();
if (!CollectionUtils.isEmpty(protocols)) {
this.protocols.addAll(protocols);
}
this.headers.setSecWebSocketProtocol(protocols);
}
/**
* Return the configured sub-protocols to use.
*/
public List<String> getSubProtocols() {
return this.protocols;
return this.headers.getSecWebSocketProtocol();
}
/**
* Set the origin to use.
*/
public void setOrigin(String origin) {
this.headers.setOrigin(origin);
}
/**
* @return the configured origin.
*/
public String getOrigin() {
return this.headers.getOrigin();
}
/**
* Provide default headers to add to the WebSocket handshake request.
*/
public void setHeaders(HttpHeaders headers) {
this.headers = headers;
this.headers.clear();
this.headers.putAll(headers);
}
/**
@ -122,14 +130,7 @@ public class WebSocketConnectionManager extends ConnectionManagerSupport {
@Override
protected void openConnection() throws Exception {
HttpHeaders headers = new HttpHeaders();
if (this.headers != null) {
headers.putAll(this.headers);
}
headers.setSecWebSocketProtocol(this.protocols);
this.webSocketSession = this.client.doHandshake(this.webSocketHandler, headers, getUri());
this.webSocketSession = this.client.doHandshake(this.webSocketHandler, this.headers, getUri());
}
@Override

View File

@ -16,8 +16,12 @@
package org.springframework.web.socket.client.endpoint;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.URI;
import java.net.UnknownHostException;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import javax.websocket.ClientEndpointConfig;
@ -31,8 +35,8 @@ import org.springframework.http.HttpHeaders;
import org.springframework.util.Assert;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.adapter.StandardEndpointAdapter;
import org.springframework.web.socket.adapter.StandardWebSocketSessionAdapter;
import org.springframework.web.socket.adapter.StandardWebSocketHandlerAdapter;
import org.springframework.web.socket.adapter.StandardWebSocketSession;
import org.springframework.web.socket.client.AbstractWebSocketClient;
import org.springframework.web.socket.client.WebSocketConnectFailureException;
@ -60,19 +64,21 @@ public class StandardWebSocketClient extends AbstractWebSocketClient {
@Override
protected WebSocketSession doHandshakeInternal(WebSocketHandler webSocketHandler,
HttpHeaders httpHeaders, URI uri, List<String> protocols) throws WebSocketConnectFailureException {
HttpHeaders headers, URI uri, List<String> protocols) throws WebSocketConnectFailureException {
StandardWebSocketSessionAdapter session = new StandardWebSocketSessionAdapter();
session.setUri(uri);
session.setRemoteHostName(uri.getHost());
int port = getPort(uri);
InetSocketAddress localAddress = new InetSocketAddress(getLocalHost(), port);
InetSocketAddress remoteAddress = new InetSocketAddress(uri.getHost(), port);
StandardWebSocketSession session = new StandardWebSocketSession(headers, localAddress, remoteAddress);
ClientEndpointConfig.Builder configBuidler = ClientEndpointConfig.Builder.create();
configBuidler.configurator(new StandardWebSocketClientConfigurator(httpHeaders));
configBuidler.configurator(new StandardWebSocketClientConfigurator(headers));
configBuidler.preferredSubprotocols(protocols);
try {
// TODO: do not block
Endpoint endpoint = new StandardEndpointAdapter(webSocketHandler, session);
Endpoint endpoint = new StandardWebSocketHandlerAdapter(webSocketHandler, session);
this.webSocketContainer.connectToServer(endpoint, configBuidler.build(), uri);
return session;
@ -82,21 +88,38 @@ public class StandardWebSocketClient extends AbstractWebSocketClient {
}
}
private InetAddress getLocalHost() {
try {
return InetAddress.getLocalHost();
}
catch (UnknownHostException e) {
return InetAddress.getLoopbackAddress();
}
}
private int getPort(URI uri) {
if (uri.getPort() == -1) {
String scheme = uri.getScheme().toLowerCase(Locale.ENGLISH);
return "wss".equals(scheme) ? 443 : 80;
}
return uri.getPort();
}
private class StandardWebSocketClientConfigurator extends Configurator {
private final HttpHeaders httpHeaders;
private final HttpHeaders headers;
public StandardWebSocketClientConfigurator(HttpHeaders httpHeaders) {
this.httpHeaders = httpHeaders;
public StandardWebSocketClientConfigurator(HttpHeaders headers) {
this.headers = headers;
}
@Override
public void beforeRequest(Map<String, List<String>> headers) {
headers.putAll(this.httpHeaders);
public void beforeRequest(Map<String, List<String>> requestHeaders) {
requestHeaders.putAll(this.headers);
if (logger.isDebugEnabled()) {
logger.debug("Handshake request headers: " + headers);
logger.debug("Handshake request headers: " + requestHeaders);
}
}
@Override

View File

@ -24,8 +24,8 @@ import org.springframework.context.SmartLifecycle;
import org.springframework.http.HttpHeaders;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.adapter.JettyWebSocketListenerAdapter;
import org.springframework.web.socket.adapter.JettyWebSocketSessionAdapter;
import org.springframework.web.socket.adapter.JettyWebSocketHandlerAdapter;
import org.springframework.web.socket.adapter.JettyWebSocketSession;
import org.springframework.web.socket.client.AbstractWebSocketClient;
import org.springframework.web.socket.client.WebSocketConnectFailureException;
import org.springframework.web.util.UriComponents;
@ -130,7 +130,7 @@ public class JettyWebSocketClient extends AbstractWebSocketClient implements Sma
}
@Override
public WebSocketSession doHandshakeInternal(WebSocketHandler webSocketHandler, HttpHeaders headers,
public WebSocketSession doHandshakeInternal(WebSocketHandler wsHandler, HttpHeaders headers,
URI uri, List<String> protocols) throws WebSocketConnectFailureException {
ClientUpgradeRequest request = new ClientUpgradeRequest();
@ -140,16 +140,13 @@ public class JettyWebSocketClient extends AbstractWebSocketClient implements Sma
request.setHeader(header, headers.get(header));
}
JettyWebSocketSessionAdapter session = new JettyWebSocketSessionAdapter();
session.setUri(uri);
session.setRemoteHostName(uri.getHost());
JettyWebSocketListenerAdapter listener = new JettyWebSocketListenerAdapter(webSocketHandler, session);
JettyWebSocketSession wsSession = new JettyWebSocketSession(null);
JettyWebSocketHandlerAdapter listener = new JettyWebSocketHandlerAdapter(wsHandler, wsSession);
try {
// TODO: do not block
this.client.connect(listener, uri, request).get();
return session;
return wsSession;
}
catch (Exception e) {
throw new WebSocketConnectFailureException("Failed to connect to " + uri, e);

View File

@ -201,10 +201,14 @@ public class DefaultHandshakeHandler implements HandshakeHandler {
protected String selectProtocol(List<String> requestedProtocols) {
if (requestedProtocols != null) {
if (logger.isDebugEnabled()) {
logger.debug("Requested sub-protocol(s): " + requestedProtocols
+ ", supported sub-protocol(s): " + this.supportedProtocols);
}
for (String protocol : requestedProtocols) {
if (this.supportedProtocols.contains(protocol.toLowerCase())) {
if (logger.isDebugEnabled()) {
logger.debug("Selected sub-protocol '" + protocol + "'");
logger.debug("Selected sub-protocol: '" + protocol + "'");
}
return protocol;
}

View File

@ -17,16 +17,18 @@
package org.springframework.web.socket.server.support;
import java.io.IOException;
import java.net.InetSocketAddress;
import javax.websocket.Endpoint;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.http.HttpHeaders;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.adapter.StandardEndpointAdapter;
import org.springframework.web.socket.adapter.StandardWebSocketSessionAdapter;
import org.springframework.web.socket.adapter.StandardWebSocketHandlerAdapter;
import org.springframework.web.socket.adapter.StandardWebSocketSession;
import org.springframework.web.socket.server.HandshakeFailureException;
import org.springframework.web.socket.server.RequestUpgradeStrategy;
@ -40,17 +42,19 @@ public abstract class AbstractStandardUpgradeStrategy implements RequestUpgradeS
protected final Log logger = LogFactory.getLog(getClass());
private final ServerWebSocketSessionInitializer wsSessionInitializer = new ServerWebSocketSessionInitializer();
@Override
public void upgrade(ServerHttpRequest request, ServerHttpResponse response,
String protocol, WebSocketHandler handler) throws IOException, HandshakeFailureException {
String acceptedProtocol, WebSocketHandler wsHandler) throws IOException, HandshakeFailureException {
StandardWebSocketSessionAdapter session = new StandardWebSocketSessionAdapter();
this.wsSessionInitializer.initialize(request, response, protocol, session);
StandardEndpointAdapter endpoint = new StandardEndpointAdapter(handler, session);
upgradeInternal(request, response, protocol, endpoint);
HttpHeaders headers = request.getHeaders();
InetSocketAddress localAddress = request.getLocalAddress();
InetSocketAddress remoteAddress = request.getRemoteAddress();
StandardWebSocketSession wsSession = new StandardWebSocketSession(headers, localAddress, remoteAddress);
StandardWebSocketHandlerAdapter endpoint = new StandardWebSocketHandlerAdapter(wsHandler, wsSession);
upgradeInternal(request, response, acceptedProtocol, endpoint);
}
protected abstract void upgradeInternal(ServerHttpRequest request, ServerHttpResponse response,

View File

@ -33,8 +33,8 @@ import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.http.server.ServletServerHttpResponse;
import org.springframework.util.Assert;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.adapter.JettyWebSocketListenerAdapter;
import org.springframework.web.socket.adapter.JettyWebSocketSessionAdapter;
import org.springframework.web.socket.adapter.JettyWebSocketHandlerAdapter;
import org.springframework.web.socket.adapter.JettyWebSocketSession;
import org.springframework.web.socket.server.HandshakeFailureException;
import org.springframework.web.socket.server.RequestUpgradeStrategy;
@ -59,8 +59,6 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy {
private WebSocketServerFactory factory;
private final ServerWebSocketSessionInitializer wsSessionInitializer = new ServerWebSocketSessionInitializer();
public JettyRequestUpgradeStrategy() {
this.factory = new WebSocketServerFactory();
@ -87,7 +85,7 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy {
@Override
public void upgrade(ServerHttpRequest request, ServerHttpResponse response,
String protocol, WebSocketHandler webSocketHandler) throws IOException {
String protocol, WebSocketHandler wsHandler) throws IOException {
Assert.isInstanceOf(ServletServerHttpRequest.class, request);
HttpServletRequest servletRequest = ((ServletServerHttpRequest) request).getServletRequest();
@ -100,14 +98,13 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy {
throw new HandshakeFailureException("Not a WebSocket request");
}
JettyWebSocketSessionAdapter session = new JettyWebSocketSessionAdapter();
this.wsSessionInitializer.initialize(request, response, protocol, session);
JettyWebSocketListenerAdapter listener = new JettyWebSocketListenerAdapter(webSocketHandler, session);
JettyWebSocketSession wsSession = new JettyWebSocketSession(request.getPrincipal());
JettyWebSocketHandlerAdapter wsListener = new JettyWebSocketHandlerAdapter(wsHandler, wsSession);
servletRequest.setAttribute(WEBSOCKET_LISTENER_ATTR_NAME, listener);
servletRequest.setAttribute(WEBSOCKET_LISTENER_ATTR_NAME, wsListener);
if (!this.factory.acceptWebSocket(servletRequest, servletResponse)) {
// should never happen
// should not happen
throw new HandshakeFailureException("WebSocket request not accepted by Jetty");
}
}

View File

@ -1,43 +0,0 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket.server.support;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.adapter.ConfigurableWebSocketSession;
/**
* Copies information from the handshake HTTP request and response to a given
* {@link WebSocketSession}.
*
* @author Rossen Stoyanchev
* @since 4.0
*/
public class ServerWebSocketSessionInitializer {
public void initialize(ServerHttpRequest request, ServerHttpResponse response,
String protocol, ConfigurableWebSocketSession session) {
session.setUri(request.getURI());
session.setRemoteHostName(request.getRemoteHostName());
session.setRemoteAddress(request.getRemoteAddress());
session.setPrincipal(request.getPrincipal());
session.setAcceptedProtocol(protocol);
}
}

View File

@ -52,7 +52,7 @@ public class TomcatRequestUpgradeStrategy extends AbstractStandardUpgradeStrateg
@Override
public void upgradeInternal(ServerHttpRequest request, ServerHttpResponse response,
String selectedProtocol, Endpoint endpoint) throws IOException {
String acceptedProtocol, Endpoint endpoint) throws IOException {
Assert.isTrue(request instanceof ServletServerHttpRequest);
HttpServletRequest servletRequest = ((ServletServerHttpRequest) request).getServletRequest();
@ -82,7 +82,7 @@ public class TomcatRequestUpgradeStrategy extends AbstractStandardUpgradeStrateg
ServerEndpointConfig endpointConfig = new ServerEndpointRegistration("/shouldntmatter", endpoint);
upgradeHandler.preInit(endpoint, endpointConfig, serverContainer, webSocketRequest,
selectedProtocol, Collections.<String, String> emptyMap(), servletRequest.isSecure());
acceptedProtocol, Collections.<String, String> emptyMap(), servletRequest.isSecure());
}
}

View File

@ -47,9 +47,6 @@ public abstract class AbstractHttpReceivingTransportHandler
public final void handleRequest(ServerHttpRequest request, ServerHttpResponse response,
WebSocketHandler wsHandler, WebSocketSession wsSession) throws SockJsException {
// TODO: check "Sec-WebSocket-Protocol" header
// https://github.com/sockjs/sockjs-client/issues/130
Assert.notNull(wsSession, "No session");
AbstractHttpSockJsSession sockJsSession = (AbstractHttpSockJsSession) wsSession;

View File

@ -43,10 +43,14 @@ public abstract class AbstractHttpSendingTransportHandler extends TransportHandl
public final void handleRequest(ServerHttpRequest request, ServerHttpResponse response,
WebSocketHandler wsHandler, WebSocketSession wsSession) throws SockJsException {
AbstractHttpSockJsSession sockJsSession = (AbstractHttpSockJsSession) wsSession;
String protocol = null; // TODO: https://github.com/sockjs/sockjs-client/issues/130
sockJsSession.setAcceptedProtocol(protocol);
// Set content type before writing
response.getHeaders().setContentType(getContentType());
AbstractHttpSockJsSession sockJsSession = (AbstractHttpSockJsSession) wsSession;
handleRequestInternal(request, response, sockJsSession);
}

View File

@ -42,7 +42,6 @@ import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.server.DefaultHandshakeHandler;
import org.springframework.web.socket.server.HandshakeHandler;
import org.springframework.web.socket.server.support.ServerWebSocketSessionInitializer;
import org.springframework.web.socket.sockjs.SockJsException;
import org.springframework.web.socket.sockjs.SockJsService;
import org.springframework.web.socket.sockjs.support.AbstractSockJsService;
@ -77,8 +76,6 @@ public class DefaultSockJsService extends AbstractSockJsService {
private final Map<String, AbstractSockJsSession> sessions = new ConcurrentHashMap<String, AbstractSockJsSession>();
private final ServerWebSocketSessionInitializer sessionInitializer = new ServerWebSocketSessionInitializer();
private ScheduledFuture sessionCleanupTask;
@ -279,8 +276,6 @@ public class DefaultSockJsService extends AbstractSockJsService {
}
logger.debug("Creating new session with session id \"" + sessionId + "\"");
session = sessionFactory.createSession(sessionId, handler);
String protocol = null; // TODO: https://github.com/sockjs/sockjs-client/issues/130
this.sessionInitializer.initialize(request, response, protocol, session);
this.sessions.put(sessionId, session);
return session;
}

View File

@ -69,7 +69,7 @@ public class SockJsWebSocketHandler extends TextWebSocketHandlerAdapter {
@Override
public void afterConnectionEstablished(WebSocketSession wsSession) throws Exception {
Assert.isTrue(this.sessionCount.compareAndSet(0, 1), "Unexpected connection");
this.sockJsSession.initWebSocketSession(wsSession);
this.sockJsSession.afterSessionInitialized(wsSession);
}
@Override

View File

@ -17,9 +17,12 @@
package org.springframework.web.socket.sockjs.transport.session;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.security.Principal;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import org.springframework.http.HttpHeaders;
import org.springframework.http.server.ServerHttpAsyncRequestControl;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
@ -51,12 +54,56 @@ public abstract class AbstractHttpSockJsSession extends AbstractSockJsSession {
private String protocol;
private HttpHeaders handshakeHeaders;
public AbstractHttpSockJsSession(String sessionId, SockJsServiceConfig config, WebSocketHandler handler) {
super(sessionId, config, handler);
private Principal principal;
private InetSocketAddress localAddress;
private InetSocketAddress remoteAddress;
public AbstractHttpSockJsSession(String id, SockJsServiceConfig config, WebSocketHandler wsHandler) {
super(id, config, wsHandler);
}
@Override
public HttpHeaders getHandshakeHeaders() {
return this.handshakeHeaders;
}
protected void setHandshakeHeaders(HttpHeaders handshakeHeaders) {
this.handshakeHeaders = handshakeHeaders;
}
@Override
public Principal getPrincipal() {
return this.principal;
}
protected void setPrincipal(Principal principal) {
this.principal = principal;
}
@Override
public InetSocketAddress getLocalAddress() {
return this.localAddress;
}
protected void setLocalAddress(InetSocketAddress localAddress) {
this.localAddress = localAddress;
}
@Override
public InetSocketAddress getRemoteAddress() {
return this.remoteAddress;
}
protected void setRemoteAddress(InetSocketAddress remoteAddress) {
this.remoteAddress = remoteAddress;
}
/**
* Unlike WebSocket where sub-protocol negotiation is part of the
* initial handshake, in HTTP transports the same negotiation must
@ -87,6 +134,12 @@ public abstract class AbstractHttpSockJsSession extends AbstractSockJsSession {
tryCloseWithSockJsTransportError(t, CloseStatus.SERVER_ERROR);
throw new SockJsTransportFailureException("Failed to send \"open\" frame", getId(), t);
}
this.handshakeHeaders = request.getHeaders();
this.principal = request.getPrincipal();
this.localAddress = request.getLocalAddress();
this.remoteAddress = request.getRemoteAddress();
try {
delegateConnectionEstablished();
}

View File

@ -35,7 +35,6 @@ import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.adapter.ConfigurableWebSocketSession;
import org.springframework.web.socket.sockjs.SockJsMessageDeliveryException;
import org.springframework.web.socket.sockjs.SockJsTransportFailureException;
import org.springframework.web.socket.sockjs.support.frame.SockJsFrame;
@ -46,7 +45,7 @@ import org.springframework.web.socket.sockjs.support.frame.SockJsFrame;
* @author Rossen Stoyanchev
* @since 4.0
*/
public abstract class AbstractSockJsSession implements ConfigurableWebSocketSession {
public abstract class AbstractSockJsSession implements WebSocketSession {
protected final Log logger = LogFactory.getLog(getClass());
@ -97,46 +96,6 @@ public abstract class AbstractSockJsSession implements ConfigurableWebSocketSess
return this.uri;
}
@Override
public void setUri(URI uri) {
this.uri = uri;
}
@Override
public boolean isSecure() {
return "wss".equals(this.uri.getSchemeSpecificPart());
}
@Override
public String getRemoteHostName() {
return this.remoteHostName;
}
@Override
public void setRemoteHostName(String remoteHostName) {
this.remoteHostName = remoteHostName;
}
@Override
public String getRemoteAddress() {
return this.remoteAddress;
}
@Override
public void setRemoteAddress(String remoteAddress) {
this.remoteAddress = remoteAddress;
}
@Override
public Principal getPrincipal() {
return this.principal;
}
@Override
public void setPrincipal(Principal principal) {
this.principal = principal;
}
public SockJsServiceConfig getSockJsServiceConfig() {
return this.sockJsServiceConfig;
}

View File

@ -17,12 +17,17 @@
package org.springframework.web.socket.sockjs.transport.session;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.security.Principal;
import org.springframework.http.HttpHeaders;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.adapter.DelegatingWebSocketSession;
import org.springframework.web.socket.sockjs.SockJsTransportFailureException;
import org.springframework.web.socket.sockjs.support.frame.SockJsFrame;
import org.springframework.web.socket.sockjs.support.frame.SockJsMessageCodec;
@ -33,47 +38,69 @@ import org.springframework.web.socket.sockjs.support.frame.SockJsMessageCodec;
* @author Rossen Stoyanchev
* @since 4.0
*/
public class WebSocketServerSockJsSession extends AbstractSockJsSession {
public class WebSocketServerSockJsSession extends AbstractSockJsSession
implements DelegatingWebSocketSession<WebSocketSession> {
private WebSocketSession webSocketSession;
private WebSocketSession wsSession;
public WebSocketServerSockJsSession(String sessionId, SockJsServiceConfig config, WebSocketHandler handler) {
super(sessionId, config, handler);
public WebSocketServerSockJsSession(String id, SockJsServiceConfig config, WebSocketHandler wsHandler) {
super(id, config, wsHandler);
}
@Override
public HttpHeaders getHandshakeHeaders() {
checkDelegateSessionInitialized();
return this.wsSession.getHandshakeHeaders();
}
@Override
public Principal getPrincipal() {
checkDelegateSessionInitialized();
return this.wsSession.getPrincipal();
}
@Override
public InetSocketAddress getLocalAddress() {
checkDelegateSessionInitialized();
return this.wsSession.getLocalAddress();
}
@Override
public InetSocketAddress getRemoteAddress() {
checkDelegateSessionInitialized();
return this.wsSession.getRemoteAddress();
}
@Override
public String getAcceptedProtocol() {
if (this.webSocketSession == null) {
logger.warn("getAcceptedProtocol() invoked before WebSocketSession has been initialized.");
return null;
}
return this.webSocketSession.getAcceptedProtocol();
checkDelegateSessionInitialized();
return this.wsSession.getAcceptedProtocol();
}
private void checkDelegateSessionInitialized() {
Assert.state(this.wsSession != null, "WebSocketSession not yet initialized");
}
@Override
public void setAcceptedProtocol(String protocol) {
// ignore, webSocketSession should have it
}
public void initWebSocketSession(WebSocketSession session) throws Exception {
this.webSocketSession = session;
public void afterSessionInitialized(WebSocketSession session) {
this.wsSession = session;
try {
TextMessage message = new TextMessage(SockJsFrame.openFrame().getContent());
this.webSocketSession.sendMessage(message);
this.wsSession.sendMessage(message);
scheduleHeartbeat();
delegateConnectionEstablished();
}
catch (IOException ex) {
catch (Exception ex) {
tryCloseWithSockJsTransportError(ex, CloseStatus.SERVER_ERROR);
return;
}
scheduleHeartbeat();
delegateConnectionEstablished();
}
@Override
public boolean isActive() {
return ((this.webSocketSession != null) && this.webSocketSession.isOpen());
return ((this.wsSession != null) && this.wsSession.isOpen());
}
public void handleMessage(TextMessage message, WebSocketSession wsSession) throws Exception {
@ -109,13 +136,13 @@ public class WebSocketServerSockJsSession extends AbstractSockJsSession {
logger.trace("Write " + frame);
}
TextMessage message = new TextMessage(frame.getContent());
this.webSocketSession.sendMessage(message);
this.wsSession.sendMessage(message);
}
@Override
protected void disconnect(CloseStatus status) throws IOException {
if (this.webSocketSession != null) {
this.webSocketSession.close(status);
if (this.wsSession != null) {
this.wsSession.close(status);
}
}

View File

@ -25,17 +25,17 @@ import org.springframework.web.socket.WebSocketHandler;
import static org.mockito.Mockito.*;
/**
* Test fixture for {@link JettyWebSocketListenerAdapter}.
* Test fixture for {@link JettyWebSocketHandlerAdapter}.
*
* @author Rossen Stoyanchev
*/
public class JettyWebSocketListenerAdapterTests {
public class JettyWebSocketHandlerAdapterTests {
private JettyWebSocketListenerAdapter adapter;
private JettyWebSocketHandlerAdapter adapter;
private WebSocketHandler webSocketHandler;
private JettyWebSocketSessionAdapter webSocketSession;
private JettyWebSocketSession webSocketSession;
private Session session;
@ -44,8 +44,8 @@ public class JettyWebSocketListenerAdapterTests {
public void setup() {
this.session = mock(Session.class);
this.webSocketHandler = mock(WebSocketHandler.class);
this.webSocketSession = new JettyWebSocketSessionAdapter();
this.adapter = new JettyWebSocketListenerAdapter(this.webSocketHandler, this.webSocketSession);
this.webSocketSession = new JettyWebSocketSession(null);
this.adapter = new JettyWebSocketHandlerAdapter(this.webSocketHandler, this.webSocketSession);
}
@Test

View File

@ -31,17 +31,17 @@ import static org.mockito.Matchers.*;
import static org.mockito.Mockito.*;
/**
* Test fixture for {@link StandardEndpointAdapter}.
* Test fixture for {@link StandardWebSocketHandlerAdapter}.
*
* @author Rossen Stoyanchev
*/
public class StandardEndpointAdapterTests {
public class StandardWebSocketHandlerAdapterTests {
private StandardEndpointAdapter adapter;
private StandardWebSocketHandlerAdapter adapter;
private WebSocketHandler webSocketHandler;
private StandardWebSocketSessionAdapter webSocketSession;
private StandardWebSocketSession webSocketSession;
private Session session;
@ -50,8 +50,8 @@ public class StandardEndpointAdapterTests {
public void setup() {
this.session = mock(Session.class);
this.webSocketHandler = mock(WebSocketHandler.class);
this.webSocketSession = new StandardWebSocketSessionAdapter();
this.adapter = new StandardEndpointAdapter(this.webSocketHandler, this.webSocketSession);
this.webSocketSession = new StandardWebSocketSession(null, null, null);
this.adapter = new StandardWebSocketHandlerAdapter(this.webSocketHandler, this.webSocketSession);
}
@Test

View File

@ -27,12 +27,12 @@ import javax.websocket.ClientEndpointConfig;
import javax.websocket.Endpoint;
import javax.websocket.WebSocketContainer;
import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.springframework.http.HttpHeaders;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.adapter.StandardEndpointAdapter;
import org.springframework.web.socket.adapter.WebSocketHandlerAdapter;
import static org.junit.Assert.*;
@ -45,40 +45,92 @@ import static org.mockito.Mockito.*;
*/
public class StandardWebSocketClientTests {
private StandardWebSocketClient wsClient;
private WebSocketContainer wsContainer;
private WebSocketHandler wsHandler;
private HttpHeaders headers;
@Before
public void setup() {
this.headers = new HttpHeaders();
this.wsHandler = new WebSocketHandlerAdapter();
this.wsContainer = mock(WebSocketContainer.class);
this.wsClient = new StandardWebSocketClient(this.wsContainer);
}
@Test
public void doHandshake() throws Exception {
public void localAddress() throws Exception {
URI uri = new URI("ws://example.com/abc");
WebSocketSession session = this.wsClient.doHandshake(this.wsHandler, this.headers, uri);
assertNotNull(session.getLocalAddress());
assertEquals(80, session.getLocalAddress().getPort());
}
@Test
public void localAddressWss() throws Exception {
URI uri = new URI("wss://example.com/abc");
WebSocketSession session = this.wsClient.doHandshake(this.wsHandler, this.headers, uri);
assertNotNull(session.getLocalAddress());
assertEquals(443, session.getLocalAddress().getPort());
}
@Test(expected=IllegalArgumentException.class)
public void localAddressNoScheme() throws Exception {
URI uri = new URI("example.com/abc");
this.wsClient.doHandshake(this.wsHandler, this.headers, uri);
}
@Test
public void remoteAddress() throws Exception {
URI uri = new URI("wss://example.com/abc");
WebSocketSession session = this.wsClient.doHandshake(this.wsHandler, this.headers, uri);
assertNotNull(session.getRemoteAddress());
assertEquals("example.com", session.getRemoteAddress().getHostName());
assertEquals(443, session.getLocalAddress().getPort());
}
@Test
public void headersWebSocketSession() throws Exception {
URI uri = new URI("ws://example.com/abc");
List<String> subprotocols = Arrays.asList("abc");
List<String> protocols = Arrays.asList("abc");
this.headers.setSecWebSocketProtocol(protocols);
this.headers.add("foo", "bar");
HttpHeaders headers = new HttpHeaders();
headers.setSecWebSocketProtocol(subprotocols);
headers.add("foo", "bar");
WebSocketSession session = this.wsClient.doHandshake(this.wsHandler, this.headers, uri);
WebSocketHandler handler = new WebSocketHandlerAdapter();
WebSocketContainer webSocketContainer = mock(WebSocketContainer.class);
StandardWebSocketClient client = new StandardWebSocketClient(webSocketContainer);
WebSocketSession session = client.doHandshake(handler, headers, uri);
assertEquals(Collections.singletonMap("foo", Arrays.asList("bar")), session.getHandshakeHeaders());
}
ArgumentCaptor<Endpoint> endpointArg = ArgumentCaptor.forClass(Endpoint.class);
ArgumentCaptor<ClientEndpointConfig> configArg = ArgumentCaptor.forClass(ClientEndpointConfig.class);
ArgumentCaptor<URI> uriArg = ArgumentCaptor.forClass(URI.class);
@Test
public void headersClientEndpointConfigurator() throws Exception {
verify(webSocketContainer).connectToServer(endpointArg.capture(), configArg.capture(), uriArg.capture());
URI uri = new URI("ws://example.com/abc");
List<String> protocols = Arrays.asList("abc");
this.headers.setSecWebSocketProtocol(protocols);
this.headers.add("foo", "bar");
assertNotNull(endpointArg.getValue());
assertEquals(StandardEndpointAdapter.class, endpointArg.getValue().getClass());
this.wsClient.doHandshake(this.wsHandler, this.headers, uri);
ClientEndpointConfig config = configArg.getValue();
assertEquals(subprotocols, config.getPreferredSubprotocols());
ArgumentCaptor<Endpoint> arg1 = ArgumentCaptor.forClass(Endpoint.class);
ArgumentCaptor<ClientEndpointConfig> arg2 = ArgumentCaptor.forClass(ClientEndpointConfig.class);
ArgumentCaptor<URI> arg3 = ArgumentCaptor.forClass(URI.class);
verify(this.wsContainer).connectToServer(arg1.capture(), arg2.capture(), arg3.capture());
ClientEndpointConfig endpointConfig = arg2.getValue();
assertEquals(protocols, endpointConfig.getPreferredSubprotocols());
Map<String, List<String>> map = new HashMap<>();
config.getConfigurator().beforeRequest(map);
endpointConfig.getConfigurator().beforeRequest(map);
assertEquals(Collections.singletonMap("foo", Arrays.asList("bar")), map);
assertEquals(uri, uriArg.getValue());
assertEquals(uri, session.getUri());
assertEquals("example.com", session.getRemoteHostName());
}
}

View File

@ -33,8 +33,8 @@ import org.springframework.util.CollectionUtils;
import org.springframework.util.SocketUtils;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.adapter.JettyWebSocketListenerAdapter;
import org.springframework.web.socket.adapter.JettyWebSocketSessionAdapter;
import org.springframework.web.socket.adapter.JettyWebSocketHandlerAdapter;
import org.springframework.web.socket.adapter.JettyWebSocketSession;
import org.springframework.web.socket.adapter.TextWebSocketHandlerAdapter;
import static org.junit.Assert.*;
@ -113,8 +113,8 @@ public class JettyWebSocketClientTests {
resp.setAcceptedSubProtocol(req.getSubProtocols().get(0));
}
JettyWebSocketSessionAdapter session = new JettyWebSocketSessionAdapter();
return new JettyWebSocketListenerAdapter(webSocketHandler, session);
JettyWebSocketSession session = new JettyWebSocketSession(null);
return new JettyWebSocketHandlerAdapter(webSocketHandler, session);
}
});
}

View File

@ -17,9 +17,12 @@
package org.springframework.web.socket.sockjs.transport.session;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.security.Principal;
import java.util.ArrayList;
import java.util.List;
import org.springframework.http.HttpHeaders;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.sockjs.support.frame.SockJsFrame;
@ -29,6 +32,14 @@ import org.springframework.web.socket.sockjs.support.frame.SockJsFrame;
*/
public class TestSockJsSession extends AbstractSockJsSession {
private HttpHeaders headers;
private Principal principal;
private InetSocketAddress localAddress;
private InetSocketAddress remoteAddress;
private boolean active;
private final List<SockJsFrame> sockJsFrames = new ArrayList<>();
@ -48,12 +59,76 @@ public class TestSockJsSession extends AbstractSockJsSession {
super(sessionId, config, handler);
}
@Override
public HttpHeaders getHandshakeHeaders() {
return this.headers;
}
/**
* @return the headers
*/
public HttpHeaders getHeaders() {
return this.headers;
}
/**
* @param headers the headers to set
*/
public void setHeaders(HttpHeaders headers) {
this.headers = headers;
}
/**
* @return the principal
*/
@Override
public Principal getPrincipal() {
return this.principal;
}
/**
* @param principal the principal to set
*/
public void setPrincipal(Principal principal) {
this.principal = principal;
}
/**
* @return the localAddress
*/
@Override
public InetSocketAddress getLocalAddress() {
return this.localAddress;
}
/**
* @param remoteAddress the remoteAddress to set
*/
public void setLocalAddress(InetSocketAddress localAddress) {
this.localAddress = localAddress;
}
/**
* @return the remoteAddress
*/
@Override
public InetSocketAddress getRemoteAddress() {
return this.remoteAddress;
}
/**
* @param remoteAddress the remoteAddress to set
*/
public void setRemoteAddress(InetSocketAddress remoteAddress) {
this.remoteAddress = remoteAddress;
}
@Override
public String getAcceptedProtocol() {
return this.subProtocol;
}
@Override
public void setAcceptedProtocol(String protocol) {
this.subProtocol = protocol;
}

View File

@ -27,7 +27,6 @@ import org.junit.Test;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.sockjs.transport.session.WebSocketServerSockJsSession;
import org.springframework.web.socket.sockjs.transport.session.WebSocketServerSockJsSessionTests.TestWebSocketServerSockJsSession;
import org.springframework.web.socket.support.TestWebSocketSession;
@ -61,7 +60,7 @@ public class WebSocketServerSockJsSessionTests extends BaseAbstractSockJsSession
public void isActive() throws Exception {
assertFalse(this.session.isActive());
this.session.initWebSocketSession(this.webSocketSession);
this.session.afterSessionInitialized(this.webSocketSession);
assertTrue(this.session.isActive());
this.webSocketSession.setOpen(false);
@ -69,9 +68,9 @@ public class WebSocketServerSockJsSessionTests extends BaseAbstractSockJsSession
}
@Test
public void initWebSocketSession() throws Exception {
public void afterSessionInitialized() throws Exception {
this.session.initWebSocketSession(this.webSocketSession);
this.session.afterSessionInitialized(this.webSocketSession);
assertEquals("Open frame not sent",
Collections.singletonList(new TextMessage("o")), this.webSocketSession.getSentMessages());
@ -110,7 +109,7 @@ public class WebSocketServerSockJsSessionTests extends BaseAbstractSockJsSession
@Test
public void sendMessageInternal() throws Exception {
this.session.initWebSocketSession(this.webSocketSession);
this.session.afterSessionInitialized(this.webSocketSession);
this.session.sendMessageInternal("x");
assertEquals(Arrays.asList(new TextMessage("o"), new TextMessage("a[\"x\"]")),
@ -122,7 +121,7 @@ public class WebSocketServerSockJsSessionTests extends BaseAbstractSockJsSession
@Test
public void disconnect() throws Exception {
this.session.initWebSocketSession(this.webSocketSession);
this.session.afterSessionInitialized(this.webSocketSession);
this.session.close(CloseStatus.NOT_ACCEPTABLE);
assertEquals(CloseStatus.NOT_ACCEPTABLE, this.webSocketSession.getCloseStatus());

View File

@ -17,11 +17,13 @@
package org.springframework.web.socket.support;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.URI;
import java.security.Principal;
import java.util.ArrayList;
import java.util.List;
import org.springframework.http.HttpHeaders;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.WebSocketMessage;
import org.springframework.web.socket.WebSocketSession;
@ -37,13 +39,11 @@ public class TestWebSocketSession implements WebSocketSession {
private URI uri;
private boolean secure;
private Principal principal;
private String remoteHostName;
private InetSocketAddress localAddress;
private String remoteAddress;
private InetSocketAddress remoteAddress;
private String protocol;
@ -53,6 +53,8 @@ public class TestWebSocketSession implements WebSocketSession {
private CloseStatus status;
private HttpHeaders headers;
/**
* @return the id
@ -84,19 +86,24 @@ public class TestWebSocketSession implements WebSocketSession {
this.uri = uri;
}
/**
* @return the secure
*/
@Override
public boolean isSecure() {
return this.secure;
public HttpHeaders getHandshakeHeaders() {
return this.headers;
}
/**
* @param secure the secure to set
* @return the headers
*/
public void setSecure(boolean secure) {
this.secure = secure;
public HttpHeaders getHeaders() {
return this.headers;
}
/**
* @param headers the headers to set
*/
public void setHeaders(HttpHeaders headers) {
this.headers = headers;
}
/**
@ -115,32 +122,32 @@ public class TestWebSocketSession implements WebSocketSession {
}
/**
* @return the remoteHostName
* @return the localAddress
*/
@Override
public String getRemoteHostName() {
return this.remoteHostName;
public InetSocketAddress getLocalAddress() {
return this.localAddress;
}
/**
* @param remoteHostName the remoteHostName to set
* @param remoteAddress the remoteAddress to set
*/
public void setRemoteHostName(String remoteHostName) {
this.remoteHostName = remoteHostName;
public void setLocalAddress(InetSocketAddress localAddress) {
this.localAddress = localAddress;
}
/**
* @return the remoteAddress
*/
@Override
public String getRemoteAddress() {
public InetSocketAddress getRemoteAddress() {
return this.remoteAddress;
}
/**
* @param remoteAddress the remoteAddress to set
*/
public void setRemoteAddress(String remoteAddress) {
public void setRemoteAddress(InetSocketAddress remoteAddress) {
this.remoteAddress = remoteAddress;
}