diff --git a/spring-messaging/src/main/java/org/springframework/messaging/handler/websocket/SubProtocolWebSocketHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/handler/websocket/SubProtocolWebSocketHandler.java index 42125d78d23..3541b1b27be 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/handler/websocket/SubProtocolWebSocketHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/handler/websocket/SubProtocolWebSocketHandler.java @@ -30,6 +30,7 @@ import org.springframework.messaging.MessageHandler; import org.springframework.messaging.MessagingException; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; +import org.springframework.util.StringUtils; import org.springframework.web.socket.CloseStatus; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.WebSocketMessage; @@ -133,7 +134,7 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler, MessageHan protected final SubProtocolHandler getProtocolHandler(WebSocketSession session) { SubProtocolHandler handler; String protocol = session.getAcceptedProtocol(); - if (protocol != null) { + if (!StringUtils.isEmpty(protocol)) { handler = this.protocolHandlers.get(protocol); Assert.state(handler != null, "No handler for sub-protocol '" + protocol + "', handlers=" + this.protocolHandlers); diff --git a/spring-messaging/src/test/java/org/springframework/messaging/handler/websocket/SubProtocolWebSocketHandlerTests.java b/spring-messaging/src/test/java/org/springframework/messaging/handler/websocket/SubProtocolWebSocketHandlerTests.java index 8654aeaddcc..aac62bf1352 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/handler/websocket/SubProtocolWebSocketHandlerTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/handler/websocket/SubProtocolWebSocketHandlerTests.java @@ -91,7 +91,18 @@ public class SubProtocolWebSocketHandlerTests { } @Test - public void noSubProtocol() throws Exception { + public void nullSubProtocol() throws Exception { + this.webSocketHandler.setDefaultProtocolHandler(defaultHandler); + this.webSocketHandler.afterConnectionEstablished(session); + + verify(this.defaultHandler).afterSessionStarted(session, this.channel); + verify(this.stompHandler, times(0)).afterSessionStarted(session, this.channel); + verify(this.mqttHandler, times(0)).afterSessionStarted(session, this.channel); + } + + @Test + public void emptySubProtocol() throws Exception { + this.session.setAcceptedProtocol(""); this.webSocketHandler.setDefaultProtocolHandler(defaultHandler); this.webSocketHandler.afterConnectionEstablished(session); diff --git a/spring-web/src/main/java/org/springframework/http/server/ServerHttpRequest.java b/spring-web/src/main/java/org/springframework/http/server/ServerHttpRequest.java index 7321d071048..39ac8d60317 100644 --- a/spring-web/src/main/java/org/springframework/http/server/ServerHttpRequest.java +++ b/spring-web/src/main/java/org/springframework/http/server/ServerHttpRequest.java @@ -16,6 +16,7 @@ package org.springframework.http.server; +import java.net.InetSocketAddress; import java.security.Principal; import java.util.Map; @@ -51,14 +52,14 @@ public interface ServerHttpRequest extends HttpRequest, HttpInputMessage { Principal getPrincipal(); /** - * Return the host name of the endpoint on the other end. + * Return the address on which the request was received. */ - String getRemoteHostName(); + InetSocketAddress getLocalAddress(); /** - * Return the IP address of the endpoint on the other end. + * Return the address of the remote client. */ - String getRemoteAddress(); + InetSocketAddress getRemoteAddress(); /** * Return a control that allows putting the request in asynchronous mode so the diff --git a/spring-web/src/main/java/org/springframework/http/server/ServletServerHttpRequest.java b/spring-web/src/main/java/org/springframework/http/server/ServletServerHttpRequest.java index 22a8720083e..0d5ae5bfab1 100644 --- a/spring-web/src/main/java/org/springframework/http/server/ServletServerHttpRequest.java +++ b/spring-web/src/main/java/org/springframework/http/server/ServletServerHttpRequest.java @@ -22,6 +22,7 @@ import java.io.IOException; import java.io.InputStream; import java.io.OutputStreamWriter; import java.io.Writer; +import java.net.InetSocketAddress; import java.net.URI; import java.net.URISyntaxException; import java.net.URLEncoder; @@ -147,13 +148,13 @@ public class ServletServerHttpRequest implements ServerHttpRequest { } @Override - public String getRemoteHostName() { - return this.servletRequest.getRemoteHost(); + public InetSocketAddress getLocalAddress() { + return new InetSocketAddress(this.servletRequest.getLocalName(), this.servletRequest.getLocalPort()); } @Override - public String getRemoteAddress() { - return this.servletRequest.getRemoteAddr(); + public InetSocketAddress getRemoteAddress() { + return new InetSocketAddress(this.servletRequest.getRemoteHost(), this.servletRequest.getRemotePort()); } @Override diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/WebSocketSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/WebSocketSession.java index 16b59ed57ec..bf14730e9cb 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/WebSocketSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/WebSocketSession.java @@ -17,9 +17,12 @@ package org.springframework.web.socket; import java.io.IOException; +import java.net.InetSocketAddress; import java.net.URI; import java.security.Principal; +import org.springframework.http.HttpHeaders; + /** * A WebSocket session abstraction. Allows sending messages over a WebSocket connection * and closing it. @@ -29,6 +32,7 @@ import java.security.Principal; */ public interface WebSocketSession { + /** * Return a unique session identifier. */ @@ -40,9 +44,9 @@ public interface WebSocketSession { URI getUri(); /** - * Return whether the underlying socket is using a secure transport. + * Return the headers used in the handshake request. */ - boolean isSecure(); + HttpHeaders getHandshakeHeaders(); /** * Return a {@link java.security.Principal} instance containing the name of the @@ -52,17 +56,18 @@ public interface WebSocketSession { Principal getPrincipal(); /** - * Return the host name of the endpoint on the other end. + * Return the address on which the request was received. */ - String getRemoteHostName(); + InetSocketAddress getLocalAddress(); /** - * Return the IP address of the endpoint on the other end. + * Return the address of the remote client. */ - String getRemoteAddress(); + InetSocketAddress getRemoteAddress(); /** - * Return the negotiated sub-protocol or {@code null} if none was specified. + * Return the negotiated sub-protocol or {@code null} if none was specified or + * negotiated successfully. */ String getAcceptedProtocol(); diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/AbstractWebSocketSesssionAdapter.java b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/AbstractWebSocketSesssion.java similarity index 77% rename from spring-websocket/src/main/java/org/springframework/web/socket/adapter/AbstractWebSocketSesssionAdapter.java rename to spring-websocket/src/main/java/org/springframework/web/socket/adapter/AbstractWebSocketSesssion.java index 7aff7a2347f..2faffbfb13c 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/AbstractWebSocketSesssionAdapter.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/AbstractWebSocketSesssion.java @@ -32,19 +32,41 @@ import org.springframework.web.socket.WebSocketSession; * @author Rossen Stoyanchev * @since 4.0 */ -public abstract class AbstractWebSocketSesssionAdapter implements ConfigurableWebSocketSession { +public abstract class AbstractWebSocketSesssion implements DelegatingWebSocketSession { protected final Log logger = LogFactory.getLog(getClass()); + private T delegateSession; - public abstract void initSession(T session); + + /** + * @return the WebSocket session to delegate to + */ + public T getDelegateSession() { + return this.delegateSession; + } + + + @Override + public void afterSessionInitialized(T session) { + Assert.notNull(session, "session must not be null"); + this.delegateSession = session; + } + + protected final void checkDelegateSessionInitialized() { + Assert.state(this.delegateSession != null, "WebSocket session is not yet initialized"); + } @Override public final void sendMessage(WebSocketMessage message) throws IOException { + + checkDelegateSessionInitialized(); + Assert.isTrue(isOpen(), "Cannot send message after connection closed."); + if (logger.isTraceEnabled()) { logger.trace("Sending " + message + ", " + this); } - Assert.isTrue(isOpen(), "Cannot send message after connection closed."); + if (message instanceof TextMessage) { sendTextMessage((TextMessage) message); } @@ -60,13 +82,15 @@ public abstract class AbstractWebSocketSesssionAdapter implements Configurabl protected abstract void sendBinaryMessage(BinaryMessage message) throws IOException ; + @Override - public void close() throws IOException { + public final void close() throws IOException { close(CloseStatus.NORMAL); } @Override public final void close(CloseStatus status) throws IOException { + checkDelegateSessionInitialized(); if (logger.isDebugEnabled()) { logger.debug("Closing " + this); } @@ -75,6 +99,7 @@ public abstract class AbstractWebSocketSesssionAdapter implements Configurabl protected abstract void closeInternal(CloseStatus status) throws IOException; + @Override public String toString() { return "WebSocket session id=" + getId(); diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/ConfigurableWebSocketSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/DelegatingWebSocketSession.java similarity index 51% rename from spring-websocket/src/main/java/org/springframework/web/socket/adapter/ConfigurableWebSocketSession.java rename to spring-websocket/src/main/java/org/springframework/web/socket/adapter/DelegatingWebSocketSession.java index 47cae03678d..73260ca668e 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/ConfigurableWebSocketSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/DelegatingWebSocketSession.java @@ -16,34 +16,25 @@ package org.springframework.web.socket.adapter; -import java.net.URI; -import java.security.Principal; - import org.springframework.web.socket.WebSocketSession; -import org.springframework.web.socket.server.DefaultHandshakeHandler; + /** - * A WebSocketSession with configurable properties. + * A contract for {@link WebSocketSession} implementations that delegate to another + * WebSocket session (e.g. a native session). + * + * @param T the type of the delegate WebSocket session * * @author Rossen Stoyanchev * @since 4.0 */ -public interface ConfigurableWebSocketSession extends WebSocketSession { +public interface DelegatingWebSocketSession extends WebSocketSession { - void setUri(URI uri); - - void setRemoteHostName(String name); - - void setRemoteAddress(String address); - - void setPrincipal(Principal principal); /** - * Set the protocol accepted as part of the WebSocket handshake. This property can be - * used when the WebSocket handshake is performed through - * {@link DefaultHandshakeHandler} rather than the underlying WebSocket runtime, or - * when there is no WebSocket handshake (e.g. SockJS HTTP fallback options) + * Invoked when the delegate WebSocket session has been initialized. */ - void setAcceptedProtocol(String protocol); + void afterSessionInitialized(T session); + } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/JettyWebSocketListenerAdapter.java b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/JettyWebSocketHandlerAdapter.java similarity index 88% rename from spring-websocket/src/main/java/org/springframework/web/socket/adapter/JettyWebSocketListenerAdapter.java rename to spring-websocket/src/main/java/org/springframework/web/socket/adapter/JettyWebSocketHandlerAdapter.java index df971ac4436..c1d80db84ee 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/JettyWebSocketListenerAdapter.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/JettyWebSocketHandlerAdapter.java @@ -28,21 +28,22 @@ import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.support.ExceptionWebSocketHandlerDecorator; /** - * Adapts {@link WebSocketHandler} to the Jetty 9 {@link WebSocketListener}. + * Adapts {@link WebSocketHandler} to the Jetty 9 WebSocket API. * * @author Phillip Webb + * @author Rossen Stoyanchev * @since 4.0 */ -public class JettyWebSocketListenerAdapter implements WebSocketListener { +public class JettyWebSocketHandlerAdapter implements WebSocketListener { - private static final Log logger = LogFactory.getLog(JettyWebSocketListenerAdapter.class); + private static final Log logger = LogFactory.getLog(JettyWebSocketHandlerAdapter.class); private final WebSocketHandler webSocketHandler; - private final JettyWebSocketSessionAdapter wsSession; + private final JettyWebSocketSession wsSession; - public JettyWebSocketListenerAdapter(WebSocketHandler webSocketHandler, JettyWebSocketSessionAdapter wsSession) { + public JettyWebSocketHandlerAdapter(WebSocketHandler webSocketHandler, JettyWebSocketSession wsSession) { Assert.notNull(webSocketHandler, "webSocketHandler must not be null"); Assert.notNull(wsSession, "wsSession must not be null"); this.webSocketHandler = webSocketHandler; @@ -52,8 +53,8 @@ public class JettyWebSocketListenerAdapter implements WebSocketListener { @Override public void onWebSocketConnect(Session session) { - this.wsSession.initSession(session); try { + this.wsSession.afterSessionInitialized(session); this.webSocketHandler.afterConnectionEstablished(this.wsSession); } catch (Throwable t) { diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/JettyWebSocketSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/JettyWebSocketSession.java new file mode 100644 index 00000000000..708ac66bc1b --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/JettyWebSocketSession.java @@ -0,0 +1,121 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.adapter; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.URI; +import java.security.Principal; + +import org.springframework.http.HttpHeaders; +import org.springframework.util.ObjectUtils; +import org.springframework.web.socket.BinaryMessage; +import org.springframework.web.socket.CloseStatus; +import org.springframework.web.socket.TextMessage; +import org.springframework.web.socket.WebSocketSession; + +/** + * A {@link WebSocketSession} for use with the Jetty 9 WebSocket API. + * + * @author Phillip Webb + * @author Rossen Stoyanchev + * @since 4.0 + */ +public class JettyWebSocketSession extends AbstractWebSocketSesssion { + + private HttpHeaders headers; + + private final Principal principal; + + + /** + * Class constructor. + * + * @param principal the user associated with the session, or {@code null} + */ + public JettyWebSocketSession(Principal principal) { + this.principal = principal; + } + + + @Override + public String getId() { + checkDelegateSessionInitialized(); + return ObjectUtils.getIdentityHexString(getDelegateSession()); + } + + @Override + public URI getUri() { + checkDelegateSessionInitialized(); + return getDelegateSession().getUpgradeRequest().getRequestURI(); + } + + @Override + public HttpHeaders getHandshakeHeaders() { + checkDelegateSessionInitialized(); + if (this.headers == null) { + this.headers = new HttpHeaders(); + this.headers.putAll(getDelegateSession().getUpgradeRequest().getHeaders()); + this.headers = HttpHeaders.readOnlyHttpHeaders(headers); + } + return this.headers; + } + + @Override + public Principal getPrincipal() { + return this.principal; + } + + @Override + public InetSocketAddress getLocalAddress() { + checkDelegateSessionInitialized(); + return getDelegateSession().getLocalAddress(); + } + + @Override + public InetSocketAddress getRemoteAddress() { + checkDelegateSessionInitialized(); + return getDelegateSession().getRemoteAddress(); + } + + @Override + public String getAcceptedProtocol() { + checkDelegateSessionInitialized(); + return getDelegateSession().getUpgradeResponse().getAcceptedSubProtocol(); + } + + @Override + public boolean isOpen() { + return ((getDelegateSession() != null) && getDelegateSession().isOpen()); + } + + @Override + protected void sendTextMessage(TextMessage message) throws IOException { + getDelegateSession().getRemote().sendString(message.getPayload()); + } + + @Override + protected void sendBinaryMessage(BinaryMessage message) throws IOException { + getDelegateSession().getRemote().sendBytes(message.getPayload()); + } + + @Override + protected void closeInternal(CloseStatus status) throws IOException { + getDelegateSession().close(status.getCode(), status.getReason()); + } + +} diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/JettyWebSocketSessionAdapter.java b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/JettyWebSocketSessionAdapter.java deleted file mode 100644 index 525f3de6fa9..00000000000 --- a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/JettyWebSocketSessionAdapter.java +++ /dev/null @@ -1,144 +0,0 @@ -/* - * Copyright 2002-2013 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.web.socket.adapter; - -import java.io.IOException; -import java.net.InetSocketAddress; -import java.net.URI; -import java.security.Principal; - -import org.eclipse.jetty.websocket.api.Session; -import org.eclipse.jetty.websocket.api.UpgradeResponse; -import org.springframework.util.Assert; -import org.springframework.util.ObjectUtils; -import org.springframework.web.socket.BinaryMessage; -import org.springframework.web.socket.CloseStatus; -import org.springframework.web.socket.TextMessage; -import org.springframework.web.socket.WebSocketSession; - -/** - * Adapts a Jetty {@link org.eclipse.jetty.websocket.api.Session} to - * {@link WebSocketSession}. - * - * @author Phillip Webb - * @author Rossen Stoyanchev - * @since 4.0 - */ -public class JettyWebSocketSessionAdapter - extends AbstractWebSocketSesssionAdapter { - - private Session session; - - private Principal principal; - - private String protocol; - - - @Override - public void initSession(Session session) { - Assert.notNull(session, "session must not be null"); - this.session = session; - - if (this.protocol == null) { - UpgradeResponse response = session.getUpgradeResponse(); - if ((response != null) && response.getAcceptedSubProtocol() != null) { - this.protocol = response.getAcceptedSubProtocol(); - } - } - } - - @Override - public String getId() { - return ObjectUtils.getIdentityHexString(this.session); - } - - @Override - public boolean isSecure() { - return this.session.isSecure(); - } - - @Override - public URI getUri() { - return this.session.getUpgradeRequest().getRequestURI(); - } - - @Override - public void setUri(URI uri) { - } - - @Override - public Principal getPrincipal() { - return this.principal; - } - - @Override - public void setPrincipal(Principal principal) { - this.principal = principal; - } - - @Override - public String getRemoteHostName() { - return this.session.getRemoteAddress().getHostName(); - } - - @Override - public void setRemoteHostName(String address) { - // ignore - } - - @Override - public String getRemoteAddress() { - InetSocketAddress address = this.session.getRemoteAddress(); - return address.isUnresolved() ? null : address.getAddress().getHostAddress(); - } - - @Override - public void setRemoteAddress(String address) { - // ignore - } - - @Override - public String getAcceptedProtocol() { - return this.protocol; - } - - @Override - public void setAcceptedProtocol(String protocol) { - this.protocol = protocol; - } - - @Override - public boolean isOpen() { - return this.session.isOpen(); - } - - @Override - protected void sendTextMessage(TextMessage message) throws IOException { - this.session.getRemote().sendString(message.getPayload()); - } - - @Override - protected void sendBinaryMessage(BinaryMessage message) throws IOException { - this.session.getRemote().sendBytes(message.getPayload()); - } - - @Override - protected void closeInternal(CloseStatus status) throws IOException { - this.session.close(status.getCode(), status.getReason()); - } - -} diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/StandardEndpointAdapter.java b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/StandardWebSocketHandlerAdapter.java similarity index 90% rename from spring-websocket/src/main/java/org/springframework/web/socket/adapter/StandardEndpointAdapter.java rename to spring-websocket/src/main/java/org/springframework/web/socket/adapter/StandardWebSocketHandlerAdapter.java index 4bfd55d265e..ed046335d78 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/StandardEndpointAdapter.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/StandardWebSocketHandlerAdapter.java @@ -33,21 +33,21 @@ import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.support.ExceptionWebSocketHandlerDecorator; /** - * Adapts a {@link WebSocketHandler} to a standard {@link Endpoint}. + * Adapts a {@link WebSocketHandler} to the standard WebSocket for Java API. * * @author Rossen Stoyanchev * @since 4.0 */ -public class StandardEndpointAdapter extends Endpoint { +public class StandardWebSocketHandlerAdapter extends Endpoint { - private static final Log logger = LogFactory.getLog(StandardEndpointAdapter.class); + private static final Log logger = LogFactory.getLog(StandardWebSocketHandlerAdapter.class); private final WebSocketHandler handler; - private final StandardWebSocketSessionAdapter wsSession; + private final StandardWebSocketSession wsSession; - public StandardEndpointAdapter(WebSocketHandler handler, StandardWebSocketSessionAdapter wsSession) { + public StandardWebSocketHandlerAdapter(WebSocketHandler handler, StandardWebSocketSession wsSession) { Assert.notNull(handler, "handler must not be null"); Assert.notNull(wsSession, "wsSession must not be null"); this.handler = handler; @@ -58,7 +58,7 @@ public class StandardEndpointAdapter extends Endpoint { @Override public void onOpen(final javax.websocket.Session session, EndpointConfig config) { - this.wsSession.initSession(session); + this.wsSession.afterSessionInitialized(session); if (this.handler.supportsPartialMessages()) { session.addMessageHandler(new MessageHandler.Partial() { diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/StandardWebSocketSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/StandardWebSocketSession.java new file mode 100644 index 00000000000..8dd2c1fd38e --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/StandardWebSocketSession.java @@ -0,0 +1,123 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.adapter; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.URI; +import java.security.Principal; + +import javax.websocket.CloseReason; +import javax.websocket.CloseReason.CloseCodes; + +import org.springframework.http.HttpHeaders; +import org.springframework.util.StringUtils; +import org.springframework.web.socket.BinaryMessage; +import org.springframework.web.socket.CloseStatus; +import org.springframework.web.socket.TextMessage; +import org.springframework.web.socket.WebSocketSession; + +/** + * A {@link WebSocketSession} for use with the standard WebSocket for Java API. + * + * @author Rossen Stoyanchev + * @since 4.0 + */ +public class StandardWebSocketSession extends AbstractWebSocketSesssion { + + private final HttpHeaders headers; + + private final InetSocketAddress localAddress; + + private final InetSocketAddress remoteAddress; + + + /** + * Class constructor. + * + * @param handshakeHeaders the headers of the handshake request + */ + public StandardWebSocketSession(HttpHeaders handshakeHeaders, InetSocketAddress localAddress, + InetSocketAddress remoteAddress) { + + handshakeHeaders = (handshakeHeaders != null) ? handshakeHeaders : new HttpHeaders(); + this.headers = HttpHeaders.readOnlyHttpHeaders(handshakeHeaders); + this.localAddress = localAddress; + this.remoteAddress = remoteAddress; + } + + @Override + public String getId() { + checkDelegateSessionInitialized(); + return getDelegateSession().getId(); + } + + @Override + public URI getUri() { + checkDelegateSessionInitialized(); + return getDelegateSession().getRequestURI(); + } + + @Override + public HttpHeaders getHandshakeHeaders() { + return this.headers; + } + + @Override + public Principal getPrincipal() { + checkDelegateSessionInitialized(); + return getDelegateSession().getUserPrincipal(); + } + + @Override + public InetSocketAddress getLocalAddress() { + return this.localAddress; + } + + @Override + public InetSocketAddress getRemoteAddress() { + return this.remoteAddress; + } + + @Override + public String getAcceptedProtocol() { + checkDelegateSessionInitialized(); + String protocol = getDelegateSession().getNegotiatedSubprotocol(); + return StringUtils.isEmpty(protocol)? null : protocol; + } + + @Override + public boolean isOpen() { + return ((getDelegateSession() != null) && getDelegateSession().isOpen()); + } + + @Override + protected void sendTextMessage(TextMessage message) throws IOException { + getDelegateSession().getBasicRemote().sendText(message.getPayload(), message.isLast()); + } + + @Override + protected void sendBinaryMessage(BinaryMessage message) throws IOException { + getDelegateSession().getBasicRemote().sendBinary(message.getPayload(), message.isLast()); + } + + @Override + protected void closeInternal(CloseStatus status) throws IOException { + getDelegateSession().close(new CloseReason(CloseCodes.getCloseCode(status.getCode()), status.getReason())); + } + +} diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/StandardWebSocketSessionAdapter.java b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/StandardWebSocketSessionAdapter.java deleted file mode 100644 index 131105c1953..00000000000 --- a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/StandardWebSocketSessionAdapter.java +++ /dev/null @@ -1,145 +0,0 @@ -/* - * Copyright 2002-2013 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.web.socket.adapter; - -import java.io.IOException; -import java.net.URI; -import java.security.Principal; - -import javax.websocket.CloseReason; -import javax.websocket.CloseReason.CloseCodes; - -import org.springframework.util.Assert; -import org.springframework.util.StringUtils; -import org.springframework.web.socket.BinaryMessage; -import org.springframework.web.socket.CloseStatus; -import org.springframework.web.socket.TextMessage; -import org.springframework.web.socket.WebSocketSession; - -/** - * Adapts a standard {@link javax.websocket.Session} to {@link WebSocketSession}. - * - * @author Rossen Stoyanchev - * @since 4.0 - */ -public class StandardWebSocketSessionAdapter extends AbstractWebSocketSesssionAdapter { - - private javax.websocket.Session session; - - private URI uri; - - private String remoteHostName; - - private String remoteAddress; - - private String protocol; - - - @Override - public void initSession(javax.websocket.Session session) { - Assert.notNull(session, "session must not be null"); - this.session = session; - - if (this.protocol == null) { - if (StringUtils.hasText(session.getNegotiatedSubprotocol())) { - this.protocol = session.getNegotiatedSubprotocol(); - } - } - } - - @Override - public String getId() { - return this.session.getId(); - } - - @Override - public URI getUri() { - return this.uri; - } - - @Override - public void setUri(URI uri) { - this.uri = uri; - } - - - @Override - public boolean isSecure() { - return this.session.isSecure(); - } - - @Override - public Principal getPrincipal() { - return this.session.getUserPrincipal(); - } - - @Override - public void setPrincipal(Principal principal) { - // ignore - } - - @Override - public String getRemoteHostName() { - return this.remoteHostName; - } - - @Override - public void setRemoteHostName(String name) { - this.remoteHostName = name; - } - - @Override - public String getRemoteAddress() { - return this.remoteAddress; - } - - @Override - public void setRemoteAddress(String address) { - this.remoteAddress = address; - } - - @Override - public String getAcceptedProtocol() { - return this.protocol; - } - - @Override - public void setAcceptedProtocol(String protocol) { - this.protocol = protocol; - } - - @Override - public boolean isOpen() { - return this.session.isOpen(); - } - - @Override - protected void sendTextMessage(TextMessage message) throws IOException { - this.session.getBasicRemote().sendText(message.getPayload(), message.isLast()); - } - - @Override - protected void sendBinaryMessage(BinaryMessage message) throws IOException { - this.session.getBasicRemote().sendBinary(message.getPayload(), message.isLast()); - } - - @Override - protected void closeInternal(CloseStatus status) throws IOException { - this.session.close(new CloseReason(CloseCodes.getCloseCode(status.getCode()), status.getReason())); - } - -} diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/client/AbstractWebSocketClient.java b/spring-websocket/src/main/java/org/springframework/web/socket/client/AbstractWebSocketClient.java index e06a05990e3..1bf9afae4a0 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/client/AbstractWebSocketClient.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/client/AbstractWebSocketClient.java @@ -73,6 +73,9 @@ public abstract class AbstractWebSocketClient implements WebSocketClient { Assert.notNull(webSocketHandler, "webSocketHandler must not be null"); Assert.notNull(uri, "uri must not be null"); + String scheme = uri.getScheme(); + Assert.isTrue(((scheme != null) && ("ws".equals(scheme) || "wss".equals(scheme))), "Invalid scheme: " + scheme); + if (logger.isDebugEnabled()) { logger.debug("Connecting to " + uri); } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/client/WebSocketConnectionManager.java b/spring-websocket/src/main/java/org/springframework/web/socket/client/WebSocketConnectionManager.java index 1fcf7ef6aaf..6dff6e53fff 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/client/WebSocketConnectionManager.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/client/WebSocketConnectionManager.java @@ -16,12 +16,10 @@ package org.springframework.web.socket.client; -import java.util.ArrayList; import java.util.List; import org.springframework.context.SmartLifecycle; import org.springframework.http.HttpHeaders; -import org.springframework.util.CollectionUtils; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.support.LoggingWebSocketHandlerDecorator; @@ -43,9 +41,7 @@ public class WebSocketConnectionManager extends ConnectionManagerSupport { private WebSocketSession webSocketSession; - private final List protocols = new ArrayList(); - - private HttpHeaders headers; + private HttpHeaders headers = new HttpHeaders(); private final boolean syncClientLifecycle; @@ -76,24 +72,36 @@ public class WebSocketConnectionManager extends ConnectionManagerSupport { * any. */ public void setSubProtocols(List protocols) { - this.protocols.clear(); - if (!CollectionUtils.isEmpty(protocols)) { - this.protocols.addAll(protocols); - } + this.headers.setSecWebSocketProtocol(protocols); } /** * Return the configured sub-protocols to use. */ public List getSubProtocols() { - return this.protocols; + return this.headers.getSecWebSocketProtocol(); + } + + /** + * Set the origin to use. + */ + public void setOrigin(String origin) { + this.headers.setOrigin(origin); + } + + /** + * @return the configured origin. + */ + public String getOrigin() { + return this.headers.getOrigin(); } /** * Provide default headers to add to the WebSocket handshake request. */ public void setHeaders(HttpHeaders headers) { - this.headers = headers; + this.headers.clear(); + this.headers.putAll(headers); } /** @@ -122,14 +130,7 @@ public class WebSocketConnectionManager extends ConnectionManagerSupport { @Override protected void openConnection() throws Exception { - - HttpHeaders headers = new HttpHeaders(); - if (this.headers != null) { - headers.putAll(this.headers); - } - headers.setSecWebSocketProtocol(this.protocols); - - this.webSocketSession = this.client.doHandshake(this.webSocketHandler, headers, getUri()); + this.webSocketSession = this.client.doHandshake(this.webSocketHandler, this.headers, getUri()); } @Override diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/client/endpoint/StandardWebSocketClient.java b/spring-websocket/src/main/java/org/springframework/web/socket/client/endpoint/StandardWebSocketClient.java index f6b3aa2d998..931262d15f7 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/client/endpoint/StandardWebSocketClient.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/client/endpoint/StandardWebSocketClient.java @@ -16,8 +16,12 @@ package org.springframework.web.socket.client.endpoint; +import java.net.InetAddress; +import java.net.InetSocketAddress; import java.net.URI; +import java.net.UnknownHostException; import java.util.List; +import java.util.Locale; import java.util.Map; import javax.websocket.ClientEndpointConfig; @@ -31,8 +35,8 @@ import org.springframework.http.HttpHeaders; import org.springframework.util.Assert; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.WebSocketSession; -import org.springframework.web.socket.adapter.StandardEndpointAdapter; -import org.springframework.web.socket.adapter.StandardWebSocketSessionAdapter; +import org.springframework.web.socket.adapter.StandardWebSocketHandlerAdapter; +import org.springframework.web.socket.adapter.StandardWebSocketSession; import org.springframework.web.socket.client.AbstractWebSocketClient; import org.springframework.web.socket.client.WebSocketConnectFailureException; @@ -60,19 +64,21 @@ public class StandardWebSocketClient extends AbstractWebSocketClient { @Override protected WebSocketSession doHandshakeInternal(WebSocketHandler webSocketHandler, - HttpHeaders httpHeaders, URI uri, List protocols) throws WebSocketConnectFailureException { + HttpHeaders headers, URI uri, List protocols) throws WebSocketConnectFailureException { - StandardWebSocketSessionAdapter session = new StandardWebSocketSessionAdapter(); - session.setUri(uri); - session.setRemoteHostName(uri.getHost()); + int port = getPort(uri); + InetSocketAddress localAddress = new InetSocketAddress(getLocalHost(), port); + InetSocketAddress remoteAddress = new InetSocketAddress(uri.getHost(), port); + + StandardWebSocketSession session = new StandardWebSocketSession(headers, localAddress, remoteAddress); ClientEndpointConfig.Builder configBuidler = ClientEndpointConfig.Builder.create(); - configBuidler.configurator(new StandardWebSocketClientConfigurator(httpHeaders)); + configBuidler.configurator(new StandardWebSocketClientConfigurator(headers)); configBuidler.preferredSubprotocols(protocols); try { // TODO: do not block - Endpoint endpoint = new StandardEndpointAdapter(webSocketHandler, session); + Endpoint endpoint = new StandardWebSocketHandlerAdapter(webSocketHandler, session); this.webSocketContainer.connectToServer(endpoint, configBuidler.build(), uri); return session; @@ -82,21 +88,38 @@ public class StandardWebSocketClient extends AbstractWebSocketClient { } } + private InetAddress getLocalHost() { + try { + return InetAddress.getLocalHost(); + } + catch (UnknownHostException e) { + return InetAddress.getLoopbackAddress(); + } + } + + private int getPort(URI uri) { + if (uri.getPort() == -1) { + String scheme = uri.getScheme().toLowerCase(Locale.ENGLISH); + return "wss".equals(scheme) ? 443 : 80; + } + return uri.getPort(); + } + private class StandardWebSocketClientConfigurator extends Configurator { - private final HttpHeaders httpHeaders; + private final HttpHeaders headers; - public StandardWebSocketClientConfigurator(HttpHeaders httpHeaders) { - this.httpHeaders = httpHeaders; + public StandardWebSocketClientConfigurator(HttpHeaders headers) { + this.headers = headers; } @Override - public void beforeRequest(Map> headers) { - headers.putAll(this.httpHeaders); + public void beforeRequest(Map> requestHeaders) { + requestHeaders.putAll(this.headers); if (logger.isDebugEnabled()) { - logger.debug("Handshake request headers: " + headers); + logger.debug("Handshake request headers: " + requestHeaders); } } @Override diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/client/jetty/JettyWebSocketClient.java b/spring-websocket/src/main/java/org/springframework/web/socket/client/jetty/JettyWebSocketClient.java index 20cad1e13f3..302b66ab334 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/client/jetty/JettyWebSocketClient.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/client/jetty/JettyWebSocketClient.java @@ -24,8 +24,8 @@ import org.springframework.context.SmartLifecycle; import org.springframework.http.HttpHeaders; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.WebSocketSession; -import org.springframework.web.socket.adapter.JettyWebSocketListenerAdapter; -import org.springframework.web.socket.adapter.JettyWebSocketSessionAdapter; +import org.springframework.web.socket.adapter.JettyWebSocketHandlerAdapter; +import org.springframework.web.socket.adapter.JettyWebSocketSession; import org.springframework.web.socket.client.AbstractWebSocketClient; import org.springframework.web.socket.client.WebSocketConnectFailureException; import org.springframework.web.util.UriComponents; @@ -130,7 +130,7 @@ public class JettyWebSocketClient extends AbstractWebSocketClient implements Sma } @Override - public WebSocketSession doHandshakeInternal(WebSocketHandler webSocketHandler, HttpHeaders headers, + public WebSocketSession doHandshakeInternal(WebSocketHandler wsHandler, HttpHeaders headers, URI uri, List protocols) throws WebSocketConnectFailureException { ClientUpgradeRequest request = new ClientUpgradeRequest(); @@ -140,16 +140,13 @@ public class JettyWebSocketClient extends AbstractWebSocketClient implements Sma request.setHeader(header, headers.get(header)); } - JettyWebSocketSessionAdapter session = new JettyWebSocketSessionAdapter(); - session.setUri(uri); - session.setRemoteHostName(uri.getHost()); - - JettyWebSocketListenerAdapter listener = new JettyWebSocketListenerAdapter(webSocketHandler, session); + JettyWebSocketSession wsSession = new JettyWebSocketSession(null); + JettyWebSocketHandlerAdapter listener = new JettyWebSocketHandlerAdapter(wsHandler, wsSession); try { // TODO: do not block this.client.connect(listener, uri, request).get(); - return session; + return wsSession; } catch (Exception e) { throw new WebSocketConnectFailureException("Failed to connect to " + uri, e); diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/DefaultHandshakeHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/DefaultHandshakeHandler.java index e57c0ec4983..87587e9dbd6 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/DefaultHandshakeHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/DefaultHandshakeHandler.java @@ -201,10 +201,14 @@ public class DefaultHandshakeHandler implements HandshakeHandler { protected String selectProtocol(List requestedProtocols) { if (requestedProtocols != null) { + if (logger.isDebugEnabled()) { + logger.debug("Requested sub-protocol(s): " + requestedProtocols + + ", supported sub-protocol(s): " + this.supportedProtocols); + } for (String protocol : requestedProtocols) { if (this.supportedProtocols.contains(protocol.toLowerCase())) { if (logger.isDebugEnabled()) { - logger.debug("Selected sub-protocol '" + protocol + "'"); + logger.debug("Selected sub-protocol: '" + protocol + "'"); } return protocol; } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractStandardUpgradeStrategy.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractStandardUpgradeStrategy.java index 14d17643598..0a02e477618 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractStandardUpgradeStrategy.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractStandardUpgradeStrategy.java @@ -17,16 +17,18 @@ package org.springframework.web.socket.server.support; import java.io.IOException; +import java.net.InetSocketAddress; import javax.websocket.Endpoint; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.springframework.http.HttpHeaders; import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; import org.springframework.web.socket.WebSocketHandler; -import org.springframework.web.socket.adapter.StandardEndpointAdapter; -import org.springframework.web.socket.adapter.StandardWebSocketSessionAdapter; +import org.springframework.web.socket.adapter.StandardWebSocketHandlerAdapter; +import org.springframework.web.socket.adapter.StandardWebSocketSession; import org.springframework.web.socket.server.HandshakeFailureException; import org.springframework.web.socket.server.RequestUpgradeStrategy; @@ -40,17 +42,19 @@ public abstract class AbstractStandardUpgradeStrategy implements RequestUpgradeS protected final Log logger = LogFactory.getLog(getClass()); - private final ServerWebSocketSessionInitializer wsSessionInitializer = new ServerWebSocketSessionInitializer(); - @Override public void upgrade(ServerHttpRequest request, ServerHttpResponse response, - String protocol, WebSocketHandler handler) throws IOException, HandshakeFailureException { + String acceptedProtocol, WebSocketHandler wsHandler) throws IOException, HandshakeFailureException { - StandardWebSocketSessionAdapter session = new StandardWebSocketSessionAdapter(); - this.wsSessionInitializer.initialize(request, response, protocol, session); - StandardEndpointAdapter endpoint = new StandardEndpointAdapter(handler, session); - upgradeInternal(request, response, protocol, endpoint); + HttpHeaders headers = request.getHeaders(); + InetSocketAddress localAddress = request.getLocalAddress(); + InetSocketAddress remoteAddress = request.getRemoteAddress(); + + StandardWebSocketSession wsSession = new StandardWebSocketSession(headers, localAddress, remoteAddress); + StandardWebSocketHandlerAdapter endpoint = new StandardWebSocketHandlerAdapter(wsHandler, wsSession); + + upgradeInternal(request, response, acceptedProtocol, endpoint); } protected abstract void upgradeInternal(ServerHttpRequest request, ServerHttpResponse response, diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/JettyRequestUpgradeStrategy.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/JettyRequestUpgradeStrategy.java index 8690fe6f33a..e6b30aa0d3c 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/JettyRequestUpgradeStrategy.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/JettyRequestUpgradeStrategy.java @@ -33,8 +33,8 @@ import org.springframework.http.server.ServletServerHttpRequest; import org.springframework.http.server.ServletServerHttpResponse; import org.springframework.util.Assert; import org.springframework.web.socket.WebSocketHandler; -import org.springframework.web.socket.adapter.JettyWebSocketListenerAdapter; -import org.springframework.web.socket.adapter.JettyWebSocketSessionAdapter; +import org.springframework.web.socket.adapter.JettyWebSocketHandlerAdapter; +import org.springframework.web.socket.adapter.JettyWebSocketSession; import org.springframework.web.socket.server.HandshakeFailureException; import org.springframework.web.socket.server.RequestUpgradeStrategy; @@ -59,8 +59,6 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy { private WebSocketServerFactory factory; - private final ServerWebSocketSessionInitializer wsSessionInitializer = new ServerWebSocketSessionInitializer(); - public JettyRequestUpgradeStrategy() { this.factory = new WebSocketServerFactory(); @@ -87,7 +85,7 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy { @Override public void upgrade(ServerHttpRequest request, ServerHttpResponse response, - String protocol, WebSocketHandler webSocketHandler) throws IOException { + String protocol, WebSocketHandler wsHandler) throws IOException { Assert.isInstanceOf(ServletServerHttpRequest.class, request); HttpServletRequest servletRequest = ((ServletServerHttpRequest) request).getServletRequest(); @@ -100,14 +98,13 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy { throw new HandshakeFailureException("Not a WebSocket request"); } - JettyWebSocketSessionAdapter session = new JettyWebSocketSessionAdapter(); - this.wsSessionInitializer.initialize(request, response, protocol, session); - JettyWebSocketListenerAdapter listener = new JettyWebSocketListenerAdapter(webSocketHandler, session); + JettyWebSocketSession wsSession = new JettyWebSocketSession(request.getPrincipal()); + JettyWebSocketHandlerAdapter wsListener = new JettyWebSocketHandlerAdapter(wsHandler, wsSession); - servletRequest.setAttribute(WEBSOCKET_LISTENER_ATTR_NAME, listener); + servletRequest.setAttribute(WEBSOCKET_LISTENER_ATTR_NAME, wsListener); if (!this.factory.acceptWebSocket(servletRequest, servletResponse)) { - // should never happen + // should not happen throw new HandshakeFailureException("WebSocket request not accepted by Jetty"); } } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/ServerWebSocketSessionInitializer.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/ServerWebSocketSessionInitializer.java deleted file mode 100644 index 50fef1e7441..00000000000 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/ServerWebSocketSessionInitializer.java +++ /dev/null @@ -1,43 +0,0 @@ -/* - * Copyright 2002-2013 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.web.socket.server.support; - -import org.springframework.http.server.ServerHttpRequest; -import org.springframework.http.server.ServerHttpResponse; -import org.springframework.web.socket.WebSocketSession; -import org.springframework.web.socket.adapter.ConfigurableWebSocketSession; - -/** - * Copies information from the handshake HTTP request and response to a given - * {@link WebSocketSession}. - * - * @author Rossen Stoyanchev - * @since 4.0 - */ -public class ServerWebSocketSessionInitializer { - - public void initialize(ServerHttpRequest request, ServerHttpResponse response, - String protocol, ConfigurableWebSocketSession session) { - - session.setUri(request.getURI()); - session.setRemoteHostName(request.getRemoteHostName()); - session.setRemoteAddress(request.getRemoteAddress()); - session.setPrincipal(request.getPrincipal()); - session.setAcceptedProtocol(protocol); - } - -} diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/TomcatRequestUpgradeStrategy.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/TomcatRequestUpgradeStrategy.java index b03c8d49855..e689d882ddc 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/TomcatRequestUpgradeStrategy.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/TomcatRequestUpgradeStrategy.java @@ -52,7 +52,7 @@ public class TomcatRequestUpgradeStrategy extends AbstractStandardUpgradeStrateg @Override public void upgradeInternal(ServerHttpRequest request, ServerHttpResponse response, - String selectedProtocol, Endpoint endpoint) throws IOException { + String acceptedProtocol, Endpoint endpoint) throws IOException { Assert.isTrue(request instanceof ServletServerHttpRequest); HttpServletRequest servletRequest = ((ServletServerHttpRequest) request).getServletRequest(); @@ -82,7 +82,7 @@ public class TomcatRequestUpgradeStrategy extends AbstractStandardUpgradeStrateg ServerEndpointConfig endpointConfig = new ServerEndpointRegistration("/shouldntmatter", endpoint); upgradeHandler.preInit(endpoint, endpointConfig, serverContainer, webSocketRequest, - selectedProtocol, Collections. emptyMap(), servletRequest.isSecure()); + acceptedProtocol, Collections. emptyMap(), servletRequest.isSecure()); } } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/AbstractHttpReceivingTransportHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/AbstractHttpReceivingTransportHandler.java index 04c68e37645..79de7ab895b 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/AbstractHttpReceivingTransportHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/AbstractHttpReceivingTransportHandler.java @@ -47,9 +47,6 @@ public abstract class AbstractHttpReceivingTransportHandler public final void handleRequest(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, WebSocketSession wsSession) throws SockJsException { - // TODO: check "Sec-WebSocket-Protocol" header - // https://github.com/sockjs/sockjs-client/issues/130 - Assert.notNull(wsSession, "No session"); AbstractHttpSockJsSession sockJsSession = (AbstractHttpSockJsSession) wsSession; diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/AbstractHttpSendingTransportHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/AbstractHttpSendingTransportHandler.java index 52e0b81a5c7..e6fb18650ed 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/AbstractHttpSendingTransportHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/AbstractHttpSendingTransportHandler.java @@ -43,10 +43,14 @@ public abstract class AbstractHttpSendingTransportHandler extends TransportHandl public final void handleRequest(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, WebSocketSession wsSession) throws SockJsException { + AbstractHttpSockJsSession sockJsSession = (AbstractHttpSockJsSession) wsSession; + + String protocol = null; // TODO: https://github.com/sockjs/sockjs-client/issues/130 + sockJsSession.setAcceptedProtocol(protocol); + // Set content type before writing response.getHeaders().setContentType(getContentType()); - AbstractHttpSockJsSession sockJsSession = (AbstractHttpSockJsSession) wsSession; handleRequestInternal(request, response, sockJsSession); } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/DefaultSockJsService.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/DefaultSockJsService.java index 04166fa0162..a1f708c7874 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/DefaultSockJsService.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/DefaultSockJsService.java @@ -42,7 +42,6 @@ import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.server.DefaultHandshakeHandler; import org.springframework.web.socket.server.HandshakeHandler; -import org.springframework.web.socket.server.support.ServerWebSocketSessionInitializer; import org.springframework.web.socket.sockjs.SockJsException; import org.springframework.web.socket.sockjs.SockJsService; import org.springframework.web.socket.sockjs.support.AbstractSockJsService; @@ -77,8 +76,6 @@ public class DefaultSockJsService extends AbstractSockJsService { private final Map sessions = new ConcurrentHashMap(); - private final ServerWebSocketSessionInitializer sessionInitializer = new ServerWebSocketSessionInitializer(); - private ScheduledFuture sessionCleanupTask; @@ -279,8 +276,6 @@ public class DefaultSockJsService extends AbstractSockJsService { } logger.debug("Creating new session with session id \"" + sessionId + "\""); session = sessionFactory.createSession(sessionId, handler); - String protocol = null; // TODO: https://github.com/sockjs/sockjs-client/issues/130 - this.sessionInitializer.initialize(request, response, protocol, session); this.sessions.put(sessionId, session); return session; } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/SockJsWebSocketHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/SockJsWebSocketHandler.java index 59b6f74ee49..3b88cfccae8 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/SockJsWebSocketHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/SockJsWebSocketHandler.java @@ -69,7 +69,7 @@ public class SockJsWebSocketHandler extends TextWebSocketHandlerAdapter { @Override public void afterConnectionEstablished(WebSocketSession wsSession) throws Exception { Assert.isTrue(this.sessionCount.compareAndSet(0, 1), "Unexpected connection"); - this.sockJsSession.initWebSocketSession(wsSession); + this.sockJsSession.afterSessionInitialized(wsSession); } @Override diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractHttpSockJsSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractHttpSockJsSession.java index c9066187c77..8f85b4ca2ed 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractHttpSockJsSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractHttpSockJsSession.java @@ -17,9 +17,12 @@ package org.springframework.web.socket.sockjs.transport.session; import java.io.IOException; +import java.net.InetSocketAddress; +import java.security.Principal; import java.util.concurrent.ArrayBlockingQueue; import java.util.concurrent.BlockingQueue; +import org.springframework.http.HttpHeaders; import org.springframework.http.server.ServerHttpAsyncRequestControl; import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; @@ -51,12 +54,56 @@ public abstract class AbstractHttpSockJsSession extends AbstractSockJsSession { private String protocol; + private HttpHeaders handshakeHeaders; - public AbstractHttpSockJsSession(String sessionId, SockJsServiceConfig config, WebSocketHandler handler) { - super(sessionId, config, handler); + private Principal principal; + + private InetSocketAddress localAddress; + + private InetSocketAddress remoteAddress; + + + public AbstractHttpSockJsSession(String id, SockJsServiceConfig config, WebSocketHandler wsHandler) { + super(id, config, wsHandler); } + @Override + public HttpHeaders getHandshakeHeaders() { + return this.handshakeHeaders; + } + + protected void setHandshakeHeaders(HttpHeaders handshakeHeaders) { + this.handshakeHeaders = handshakeHeaders; + } + + @Override + public Principal getPrincipal() { + return this.principal; + } + + protected void setPrincipal(Principal principal) { + this.principal = principal; + } + + @Override + public InetSocketAddress getLocalAddress() { + return this.localAddress; + } + + protected void setLocalAddress(InetSocketAddress localAddress) { + this.localAddress = localAddress; + } + + @Override + public InetSocketAddress getRemoteAddress() { + return this.remoteAddress; + } + + protected void setRemoteAddress(InetSocketAddress remoteAddress) { + this.remoteAddress = remoteAddress; + } + /** * Unlike WebSocket where sub-protocol negotiation is part of the * initial handshake, in HTTP transports the same negotiation must @@ -87,6 +134,12 @@ public abstract class AbstractHttpSockJsSession extends AbstractSockJsSession { tryCloseWithSockJsTransportError(t, CloseStatus.SERVER_ERROR); throw new SockJsTransportFailureException("Failed to send \"open\" frame", getId(), t); } + + this.handshakeHeaders = request.getHeaders(); + this.principal = request.getPrincipal(); + this.localAddress = request.getLocalAddress(); + this.remoteAddress = request.getRemoteAddress(); + try { delegateConnectionEstablished(); } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractSockJsSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractSockJsSession.java index 6de66c5a9b0..0a94cce68ca 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractSockJsSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractSockJsSession.java @@ -35,7 +35,6 @@ import org.springframework.web.socket.TextMessage; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.WebSocketMessage; import org.springframework.web.socket.WebSocketSession; -import org.springframework.web.socket.adapter.ConfigurableWebSocketSession; import org.springframework.web.socket.sockjs.SockJsMessageDeliveryException; import org.springframework.web.socket.sockjs.SockJsTransportFailureException; import org.springframework.web.socket.sockjs.support.frame.SockJsFrame; @@ -46,7 +45,7 @@ import org.springframework.web.socket.sockjs.support.frame.SockJsFrame; * @author Rossen Stoyanchev * @since 4.0 */ -public abstract class AbstractSockJsSession implements ConfigurableWebSocketSession { +public abstract class AbstractSockJsSession implements WebSocketSession { protected final Log logger = LogFactory.getLog(getClass()); @@ -97,46 +96,6 @@ public abstract class AbstractSockJsSession implements ConfigurableWebSocketSess return this.uri; } - @Override - public void setUri(URI uri) { - this.uri = uri; - } - - @Override - public boolean isSecure() { - return "wss".equals(this.uri.getSchemeSpecificPart()); - } - - @Override - public String getRemoteHostName() { - return this.remoteHostName; - } - - @Override - public void setRemoteHostName(String remoteHostName) { - this.remoteHostName = remoteHostName; - } - - @Override - public String getRemoteAddress() { - return this.remoteAddress; - } - - @Override - public void setRemoteAddress(String remoteAddress) { - this.remoteAddress = remoteAddress; - } - - @Override - public Principal getPrincipal() { - return this.principal; - } - - @Override - public void setPrincipal(Principal principal) { - this.principal = principal; - } - public SockJsServiceConfig getSockJsServiceConfig() { return this.sockJsServiceConfig; } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/WebSocketServerSockJsSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/WebSocketServerSockJsSession.java index bb2f48b80c7..3f93b23cbcd 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/WebSocketServerSockJsSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/WebSocketServerSockJsSession.java @@ -17,12 +17,17 @@ package org.springframework.web.socket.sockjs.transport.session; import java.io.IOException; +import java.net.InetSocketAddress; +import java.security.Principal; +import org.springframework.http.HttpHeaders; +import org.springframework.util.Assert; import org.springframework.util.StringUtils; import org.springframework.web.socket.CloseStatus; import org.springframework.web.socket.TextMessage; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.WebSocketSession; +import org.springframework.web.socket.adapter.DelegatingWebSocketSession; import org.springframework.web.socket.sockjs.SockJsTransportFailureException; import org.springframework.web.socket.sockjs.support.frame.SockJsFrame; import org.springframework.web.socket.sockjs.support.frame.SockJsMessageCodec; @@ -33,47 +38,69 @@ import org.springframework.web.socket.sockjs.support.frame.SockJsMessageCodec; * @author Rossen Stoyanchev * @since 4.0 */ -public class WebSocketServerSockJsSession extends AbstractSockJsSession { +public class WebSocketServerSockJsSession extends AbstractSockJsSession + implements DelegatingWebSocketSession { - private WebSocketSession webSocketSession; + private WebSocketSession wsSession; - public WebSocketServerSockJsSession(String sessionId, SockJsServiceConfig config, WebSocketHandler handler) { - super(sessionId, config, handler); + public WebSocketServerSockJsSession(String id, SockJsServiceConfig config, WebSocketHandler wsHandler) { + super(id, config, wsHandler); } + @Override + public HttpHeaders getHandshakeHeaders() { + checkDelegateSessionInitialized(); + return this.wsSession.getHandshakeHeaders(); + } + + @Override + public Principal getPrincipal() { + checkDelegateSessionInitialized(); + return this.wsSession.getPrincipal(); + } + + @Override + public InetSocketAddress getLocalAddress() { + checkDelegateSessionInitialized(); + return this.wsSession.getLocalAddress(); + } + + @Override + public InetSocketAddress getRemoteAddress() { + checkDelegateSessionInitialized(); + return this.wsSession.getRemoteAddress(); + } @Override public String getAcceptedProtocol() { - if (this.webSocketSession == null) { - logger.warn("getAcceptedProtocol() invoked before WebSocketSession has been initialized."); - return null; - } - return this.webSocketSession.getAcceptedProtocol(); + checkDelegateSessionInitialized(); + return this.wsSession.getAcceptedProtocol(); } + private void checkDelegateSessionInitialized() { + Assert.state(this.wsSession != null, "WebSocketSession not yet initialized"); + } + + @Override - public void setAcceptedProtocol(String protocol) { - // ignore, webSocketSession should have it - } - - public void initWebSocketSession(WebSocketSession session) throws Exception { - this.webSocketSession = session; + public void afterSessionInitialized(WebSocketSession session) { + this.wsSession = session; try { TextMessage message = new TextMessage(SockJsFrame.openFrame().getContent()); - this.webSocketSession.sendMessage(message); + this.wsSession.sendMessage(message); + scheduleHeartbeat(); + delegateConnectionEstablished(); } - catch (IOException ex) { + catch (Exception ex) { tryCloseWithSockJsTransportError(ex, CloseStatus.SERVER_ERROR); return; } - scheduleHeartbeat(); - delegateConnectionEstablished(); } @Override public boolean isActive() { - return ((this.webSocketSession != null) && this.webSocketSession.isOpen()); + return ((this.wsSession != null) && this.wsSession.isOpen()); } public void handleMessage(TextMessage message, WebSocketSession wsSession) throws Exception { @@ -109,13 +136,13 @@ public class WebSocketServerSockJsSession extends AbstractSockJsSession { logger.trace("Write " + frame); } TextMessage message = new TextMessage(frame.getContent()); - this.webSocketSession.sendMessage(message); + this.wsSession.sendMessage(message); } @Override protected void disconnect(CloseStatus status) throws IOException { - if (this.webSocketSession != null) { - this.webSocketSession.close(status); + if (this.wsSession != null) { + this.wsSession.close(status); } } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/adapter/JettyWebSocketListenerAdapterTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/adapter/JettyWebSocketHandlerAdapterTests.java similarity index 82% rename from spring-websocket/src/test/java/org/springframework/web/socket/adapter/JettyWebSocketListenerAdapterTests.java rename to spring-websocket/src/test/java/org/springframework/web/socket/adapter/JettyWebSocketHandlerAdapterTests.java index e55ec2a10f0..a964543b15d 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/adapter/JettyWebSocketListenerAdapterTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/adapter/JettyWebSocketHandlerAdapterTests.java @@ -25,17 +25,17 @@ import org.springframework.web.socket.WebSocketHandler; import static org.mockito.Mockito.*; /** - * Test fixture for {@link JettyWebSocketListenerAdapter}. + * Test fixture for {@link JettyWebSocketHandlerAdapter}. * * @author Rossen Stoyanchev */ -public class JettyWebSocketListenerAdapterTests { +public class JettyWebSocketHandlerAdapterTests { - private JettyWebSocketListenerAdapter adapter; + private JettyWebSocketHandlerAdapter adapter; private WebSocketHandler webSocketHandler; - private JettyWebSocketSessionAdapter webSocketSession; + private JettyWebSocketSession webSocketSession; private Session session; @@ -44,8 +44,8 @@ public class JettyWebSocketListenerAdapterTests { public void setup() { this.session = mock(Session.class); this.webSocketHandler = mock(WebSocketHandler.class); - this.webSocketSession = new JettyWebSocketSessionAdapter(); - this.adapter = new JettyWebSocketListenerAdapter(this.webSocketHandler, this.webSocketSession); + this.webSocketSession = new JettyWebSocketSession(null); + this.adapter = new JettyWebSocketHandlerAdapter(this.webSocketHandler, this.webSocketSession); } @Test diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/adapter/StandardEndpointAdapterTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/adapter/StandardWebSocketHandlerAdapterTests.java similarity index 84% rename from spring-websocket/src/test/java/org/springframework/web/socket/adapter/StandardEndpointAdapterTests.java rename to spring-websocket/src/test/java/org/springframework/web/socket/adapter/StandardWebSocketHandlerAdapterTests.java index b38a39625cf..ede2a5e9514 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/adapter/StandardEndpointAdapterTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/adapter/StandardWebSocketHandlerAdapterTests.java @@ -31,17 +31,17 @@ import static org.mockito.Matchers.*; import static org.mockito.Mockito.*; /** - * Test fixture for {@link StandardEndpointAdapter}. + * Test fixture for {@link StandardWebSocketHandlerAdapter}. * * @author Rossen Stoyanchev */ -public class StandardEndpointAdapterTests { +public class StandardWebSocketHandlerAdapterTests { - private StandardEndpointAdapter adapter; + private StandardWebSocketHandlerAdapter adapter; private WebSocketHandler webSocketHandler; - private StandardWebSocketSessionAdapter webSocketSession; + private StandardWebSocketSession webSocketSession; private Session session; @@ -50,8 +50,8 @@ public class StandardEndpointAdapterTests { public void setup() { this.session = mock(Session.class); this.webSocketHandler = mock(WebSocketHandler.class); - this.webSocketSession = new StandardWebSocketSessionAdapter(); - this.adapter = new StandardEndpointAdapter(this.webSocketHandler, this.webSocketSession); + this.webSocketSession = new StandardWebSocketSession(null, null, null); + this.adapter = new StandardWebSocketHandlerAdapter(this.webSocketHandler, this.webSocketSession); } @Test diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/client/endpoint/StandardWebSocketClientTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/client/endpoint/StandardWebSocketClientTests.java index 4e65f7663ce..05bdd9d3861 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/client/endpoint/StandardWebSocketClientTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/client/endpoint/StandardWebSocketClientTests.java @@ -27,12 +27,12 @@ import javax.websocket.ClientEndpointConfig; import javax.websocket.Endpoint; import javax.websocket.WebSocketContainer; +import org.junit.Before; import org.junit.Test; import org.mockito.ArgumentCaptor; import org.springframework.http.HttpHeaders; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.WebSocketSession; -import org.springframework.web.socket.adapter.StandardEndpointAdapter; import org.springframework.web.socket.adapter.WebSocketHandlerAdapter; import static org.junit.Assert.*; @@ -45,40 +45,92 @@ import static org.mockito.Mockito.*; */ public class StandardWebSocketClientTests { + private StandardWebSocketClient wsClient; + + private WebSocketContainer wsContainer; + + private WebSocketHandler wsHandler; + + private HttpHeaders headers; + + + @Before + public void setup() { + this.headers = new HttpHeaders(); + this.wsHandler = new WebSocketHandlerAdapter(); + this.wsContainer = mock(WebSocketContainer.class); + this.wsClient = new StandardWebSocketClient(this.wsContainer); + } + + @Test - public void doHandshake() throws Exception { + public void localAddress() throws Exception { + URI uri = new URI("ws://example.com/abc"); + WebSocketSession session = this.wsClient.doHandshake(this.wsHandler, this.headers, uri); + + assertNotNull(session.getLocalAddress()); + assertEquals(80, session.getLocalAddress().getPort()); + } + + @Test + public void localAddressWss() throws Exception { + URI uri = new URI("wss://example.com/abc"); + WebSocketSession session = this.wsClient.doHandshake(this.wsHandler, this.headers, uri); + + assertNotNull(session.getLocalAddress()); + assertEquals(443, session.getLocalAddress().getPort()); + } + + @Test(expected=IllegalArgumentException.class) + public void localAddressNoScheme() throws Exception { + URI uri = new URI("example.com/abc"); + this.wsClient.doHandshake(this.wsHandler, this.headers, uri); + } + + @Test + public void remoteAddress() throws Exception { + URI uri = new URI("wss://example.com/abc"); + WebSocketSession session = this.wsClient.doHandshake(this.wsHandler, this.headers, uri); + + assertNotNull(session.getRemoteAddress()); + assertEquals("example.com", session.getRemoteAddress().getHostName()); + assertEquals(443, session.getLocalAddress().getPort()); + } + + @Test + public void headersWebSocketSession() throws Exception { URI uri = new URI("ws://example.com/abc"); - List subprotocols = Arrays.asList("abc"); + List protocols = Arrays.asList("abc"); + this.headers.setSecWebSocketProtocol(protocols); + this.headers.add("foo", "bar"); - HttpHeaders headers = new HttpHeaders(); - headers.setSecWebSocketProtocol(subprotocols); - headers.add("foo", "bar"); + WebSocketSession session = this.wsClient.doHandshake(this.wsHandler, this.headers, uri); - WebSocketHandler handler = new WebSocketHandlerAdapter(); - WebSocketContainer webSocketContainer = mock(WebSocketContainer.class); - StandardWebSocketClient client = new StandardWebSocketClient(webSocketContainer); - WebSocketSession session = client.doHandshake(handler, headers, uri); + assertEquals(Collections.singletonMap("foo", Arrays.asList("bar")), session.getHandshakeHeaders()); + } - ArgumentCaptor endpointArg = ArgumentCaptor.forClass(Endpoint.class); - ArgumentCaptor configArg = ArgumentCaptor.forClass(ClientEndpointConfig.class); - ArgumentCaptor uriArg = ArgumentCaptor.forClass(URI.class); + @Test + public void headersClientEndpointConfigurator() throws Exception { - verify(webSocketContainer).connectToServer(endpointArg.capture(), configArg.capture(), uriArg.capture()); + URI uri = new URI("ws://example.com/abc"); + List protocols = Arrays.asList("abc"); + this.headers.setSecWebSocketProtocol(protocols); + this.headers.add("foo", "bar"); - assertNotNull(endpointArg.getValue()); - assertEquals(StandardEndpointAdapter.class, endpointArg.getValue().getClass()); + this.wsClient.doHandshake(this.wsHandler, this.headers, uri); - ClientEndpointConfig config = configArg.getValue(); - assertEquals(subprotocols, config.getPreferredSubprotocols()); + ArgumentCaptor arg1 = ArgumentCaptor.forClass(Endpoint.class); + ArgumentCaptor arg2 = ArgumentCaptor.forClass(ClientEndpointConfig.class); + ArgumentCaptor arg3 = ArgumentCaptor.forClass(URI.class); + verify(this.wsContainer).connectToServer(arg1.capture(), arg2.capture(), arg3.capture()); + + ClientEndpointConfig endpointConfig = arg2.getValue(); + assertEquals(protocols, endpointConfig.getPreferredSubprotocols()); Map> map = new HashMap<>(); - config.getConfigurator().beforeRequest(map); + endpointConfig.getConfigurator().beforeRequest(map); assertEquals(Collections.singletonMap("foo", Arrays.asList("bar")), map); - - assertEquals(uri, uriArg.getValue()); - assertEquals(uri, session.getUri()); - assertEquals("example.com", session.getRemoteHostName()); } } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/client/jetty/JettyWebSocketClientTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/client/jetty/JettyWebSocketClientTests.java index a28a92532fa..2eaa55054ea 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/client/jetty/JettyWebSocketClientTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/client/jetty/JettyWebSocketClientTests.java @@ -33,8 +33,8 @@ import org.springframework.util.CollectionUtils; import org.springframework.util.SocketUtils; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.WebSocketSession; -import org.springframework.web.socket.adapter.JettyWebSocketListenerAdapter; -import org.springframework.web.socket.adapter.JettyWebSocketSessionAdapter; +import org.springframework.web.socket.adapter.JettyWebSocketHandlerAdapter; +import org.springframework.web.socket.adapter.JettyWebSocketSession; import org.springframework.web.socket.adapter.TextWebSocketHandlerAdapter; import static org.junit.Assert.*; @@ -113,8 +113,8 @@ public class JettyWebSocketClientTests { resp.setAcceptedSubProtocol(req.getSubProtocols().get(0)); } - JettyWebSocketSessionAdapter session = new JettyWebSocketSessionAdapter(); - return new JettyWebSocketListenerAdapter(webSocketHandler, session); + JettyWebSocketSession session = new JettyWebSocketSession(null); + return new JettyWebSocketHandlerAdapter(webSocketHandler, session); } }); } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/TestSockJsSession.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/TestSockJsSession.java index faf70ede1c3..4aad2c63363 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/TestSockJsSession.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/TestSockJsSession.java @@ -17,9 +17,12 @@ package org.springframework.web.socket.sockjs.transport.session; import java.io.IOException; +import java.net.InetSocketAddress; +import java.security.Principal; import java.util.ArrayList; import java.util.List; +import org.springframework.http.HttpHeaders; import org.springframework.web.socket.CloseStatus; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.sockjs.support.frame.SockJsFrame; @@ -29,6 +32,14 @@ import org.springframework.web.socket.sockjs.support.frame.SockJsFrame; */ public class TestSockJsSession extends AbstractSockJsSession { + private HttpHeaders headers; + + private Principal principal; + + private InetSocketAddress localAddress; + + private InetSocketAddress remoteAddress; + private boolean active; private final List sockJsFrames = new ArrayList<>(); @@ -48,12 +59,76 @@ public class TestSockJsSession extends AbstractSockJsSession { super(sessionId, config, handler); } + + @Override + public HttpHeaders getHandshakeHeaders() { + return this.headers; + } + + /** + * @return the headers + */ + public HttpHeaders getHeaders() { + return this.headers; + } + + /** + * @param headers the headers to set + */ + public void setHeaders(HttpHeaders headers) { + this.headers = headers; + } + + /** + * @return the principal + */ + @Override + public Principal getPrincipal() { + return this.principal; + } + + /** + * @param principal the principal to set + */ + public void setPrincipal(Principal principal) { + this.principal = principal; + } + + /** + * @return the localAddress + */ + @Override + public InetSocketAddress getLocalAddress() { + return this.localAddress; + } + + /** + * @param remoteAddress the remoteAddress to set + */ + public void setLocalAddress(InetSocketAddress localAddress) { + this.localAddress = localAddress; + } + + /** + * @return the remoteAddress + */ + @Override + public InetSocketAddress getRemoteAddress() { + return this.remoteAddress; + } + + /** + * @param remoteAddress the remoteAddress to set + */ + public void setRemoteAddress(InetSocketAddress remoteAddress) { + this.remoteAddress = remoteAddress; + } + @Override public String getAcceptedProtocol() { return this.subProtocol; } - @Override public void setAcceptedProtocol(String protocol) { this.subProtocol = protocol; } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/WebSocketServerSockJsSessionTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/WebSocketServerSockJsSessionTests.java index fcd68074aec..003f2503c0b 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/WebSocketServerSockJsSessionTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/WebSocketServerSockJsSessionTests.java @@ -27,7 +27,6 @@ import org.junit.Test; import org.springframework.web.socket.CloseStatus; import org.springframework.web.socket.TextMessage; import org.springframework.web.socket.WebSocketHandler; -import org.springframework.web.socket.sockjs.transport.session.WebSocketServerSockJsSession; import org.springframework.web.socket.sockjs.transport.session.WebSocketServerSockJsSessionTests.TestWebSocketServerSockJsSession; import org.springframework.web.socket.support.TestWebSocketSession; @@ -61,7 +60,7 @@ public class WebSocketServerSockJsSessionTests extends BaseAbstractSockJsSession public void isActive() throws Exception { assertFalse(this.session.isActive()); - this.session.initWebSocketSession(this.webSocketSession); + this.session.afterSessionInitialized(this.webSocketSession); assertTrue(this.session.isActive()); this.webSocketSession.setOpen(false); @@ -69,9 +68,9 @@ public class WebSocketServerSockJsSessionTests extends BaseAbstractSockJsSession } @Test - public void initWebSocketSession() throws Exception { + public void afterSessionInitialized() throws Exception { - this.session.initWebSocketSession(this.webSocketSession); + this.session.afterSessionInitialized(this.webSocketSession); assertEquals("Open frame not sent", Collections.singletonList(new TextMessage("o")), this.webSocketSession.getSentMessages()); @@ -110,7 +109,7 @@ public class WebSocketServerSockJsSessionTests extends BaseAbstractSockJsSession @Test public void sendMessageInternal() throws Exception { - this.session.initWebSocketSession(this.webSocketSession); + this.session.afterSessionInitialized(this.webSocketSession); this.session.sendMessageInternal("x"); assertEquals(Arrays.asList(new TextMessage("o"), new TextMessage("a[\"x\"]")), @@ -122,7 +121,7 @@ public class WebSocketServerSockJsSessionTests extends BaseAbstractSockJsSession @Test public void disconnect() throws Exception { - this.session.initWebSocketSession(this.webSocketSession); + this.session.afterSessionInitialized(this.webSocketSession); this.session.close(CloseStatus.NOT_ACCEPTABLE); assertEquals(CloseStatus.NOT_ACCEPTABLE, this.webSocketSession.getCloseStatus()); diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/support/TestWebSocketSession.java b/spring-websocket/src/test/java/org/springframework/web/socket/support/TestWebSocketSession.java index 5b6482a7818..53d78aa1ea8 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/support/TestWebSocketSession.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/support/TestWebSocketSession.java @@ -17,11 +17,13 @@ package org.springframework.web.socket.support; import java.io.IOException; +import java.net.InetSocketAddress; import java.net.URI; import java.security.Principal; import java.util.ArrayList; import java.util.List; +import org.springframework.http.HttpHeaders; import org.springframework.web.socket.CloseStatus; import org.springframework.web.socket.WebSocketMessage; import org.springframework.web.socket.WebSocketSession; @@ -37,13 +39,11 @@ public class TestWebSocketSession implements WebSocketSession { private URI uri; - private boolean secure; - private Principal principal; - private String remoteHostName; + private InetSocketAddress localAddress; - private String remoteAddress; + private InetSocketAddress remoteAddress; private String protocol; @@ -53,6 +53,8 @@ public class TestWebSocketSession implements WebSocketSession { private CloseStatus status; + private HttpHeaders headers; + /** * @return the id @@ -84,19 +86,24 @@ public class TestWebSocketSession implements WebSocketSession { this.uri = uri; } - /** - * @return the secure - */ + @Override - public boolean isSecure() { - return this.secure; + public HttpHeaders getHandshakeHeaders() { + return this.headers; } /** - * @param secure the secure to set + * @return the headers */ - public void setSecure(boolean secure) { - this.secure = secure; + public HttpHeaders getHeaders() { + return this.headers; + } + + /** + * @param headers the headers to set + */ + public void setHeaders(HttpHeaders headers) { + this.headers = headers; } /** @@ -115,32 +122,32 @@ public class TestWebSocketSession implements WebSocketSession { } /** - * @return the remoteHostName + * @return the localAddress */ @Override - public String getRemoteHostName() { - return this.remoteHostName; + public InetSocketAddress getLocalAddress() { + return this.localAddress; } /** - * @param remoteHostName the remoteHostName to set + * @param remoteAddress the remoteAddress to set */ - public void setRemoteHostName(String remoteHostName) { - this.remoteHostName = remoteHostName; + public void setLocalAddress(InetSocketAddress localAddress) { + this.localAddress = localAddress; } /** * @return the remoteAddress */ @Override - public String getRemoteAddress() { + public InetSocketAddress getRemoteAddress() { return this.remoteAddress; } /** * @param remoteAddress the remoteAddress to set */ - public void setRemoteAddress(String remoteAddress) { + public void setRemoteAddress(InetSocketAddress remoteAddress) { this.remoteAddress = remoteAddress; }