From c28ce0e2bd96fa93c20282403400e39e160601c7 Mon Sep 17 00:00:00 2001 From: Rossen Stoyanchev Date: Tue, 23 Apr 2013 17:44:53 -0400 Subject: [PATCH] Ensure WebSocketHandlerEndpoint can connect only once WebSocketHandlerEndpoint and SockJsWebSocketHandler are stateful wrappers that are not intended to be used with one client connection. --- .../sockjs/AbstractSockJsSession.java | 13 +-- .../transport/SockJsWebSocketHandler.java | 10 ++- .../endpoint/WebSocketHandlerEndpoint.java | 79 ++++++++++--------- 3 files changed, 59 insertions(+), 43 deletions(-) diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/AbstractSockJsSession.java b/spring-websocket/src/main/java/org/springframework/sockjs/AbstractSockJsSession.java index 29791860a8b..d8565455c17 100644 --- a/spring-websocket/src/main/java/org/springframework/sockjs/AbstractSockJsSession.java +++ b/spring-websocket/src/main/java/org/springframework/sockjs/AbstractSockJsSession.java @@ -43,7 +43,7 @@ public abstract class AbstractSockJsSession implements WebSocketSession { private final HandlerProvider handlerProvider; - private final TextMessageHandler handler; + private TextMessageHandler handler; private State state = State.NEW; @@ -61,10 +61,6 @@ public abstract class AbstractSockJsSession implements WebSocketSession { Assert.notNull(sessionId, "sessionId is required"); Assert.notNull(handlerProvider, "handlerProvider is required"); this.sessionId = sessionId; - - WebSocketHandler webSocketHandler = handlerProvider.getHandler(); - Assert.isInstanceOf(TextMessageHandler.class, webSocketHandler, "Expected a TextMessageHandler"); - this.handler = (TextMessageHandler) webSocketHandler; this.handlerProvider = handlerProvider; } @@ -127,9 +123,16 @@ public abstract class AbstractSockJsSession implements WebSocketSession { public void delegateConnectionEstablished() throws Exception { this.state = State.OPEN; + initHandler(); this.handler.afterConnectionEstablished(this); } + private void initHandler() { + WebSocketHandler webSocketHandler = handlerProvider.getHandler(); + Assert.isInstanceOf(TextMessageHandler.class, webSocketHandler, "Expected a TextMessageHandler"); + this.handler = (TextMessageHandler) webSocketHandler; + } + public void delegateMessages(String[] messages) throws Exception { for (String message : messages) { this.handler.handleTextMessage(new TextMessage(message), this); diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/SockJsWebSocketHandler.java b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/SockJsWebSocketHandler.java index 776df0e30b3..eef056aab12 100644 --- a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/SockJsWebSocketHandler.java +++ b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/SockJsWebSocketHandler.java @@ -17,6 +17,7 @@ package org.springframework.sockjs.server.transport; import java.io.IOException; +import java.util.concurrent.atomic.AtomicInteger; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -53,6 +54,8 @@ public class SockJsWebSocketHandler implements TextMessageHandler { private AbstractSockJsSession session; + private final AtomicInteger sessionCount = new AtomicInteger(0); + // TODO: JSON library used must be configurable private final ObjectMapper objectMapper = new ObjectMapper(); @@ -70,6 +73,7 @@ public class SockJsWebSocketHandler implements TextMessageHandler { @Override public void afterConnectionEstablished(WebSocketSession wsSession) throws Exception { + Assert.isTrue(this.sessionCount.compareAndSet(0, 1), "Unexpected connection"); this.session = new WebSocketServerSockJsSession(wsSession, getSockJsConfig()); } @@ -80,14 +84,16 @@ public class SockJsWebSocketHandler implements TextMessageHandler { logger.trace("Ignoring empty message"); return; } + String[] messages; try { - String[] messages = this.objectMapper.readValue(payload, String[].class); - this.session.delegateMessages(messages); + messages = this.objectMapper.readValue(payload, String[].class); } catch (IOException e) { logger.error("Broken data received. Terminating WebSocket connection abruptly", e); wsSession.close(); + return; } + this.session.delegateMessages(messages); } @Override diff --git a/spring-websocket/src/main/java/org/springframework/websocket/endpoint/WebSocketHandlerEndpoint.java b/spring-websocket/src/main/java/org/springframework/websocket/endpoint/WebSocketHandlerEndpoint.java index be728fdcc1f..7f7bd50a45d 100644 --- a/spring-websocket/src/main/java/org/springframework/websocket/endpoint/WebSocketHandlerEndpoint.java +++ b/spring-websocket/src/main/java/org/springframework/websocket/endpoint/WebSocketHandlerEndpoint.java @@ -16,6 +16,8 @@ package org.springframework.websocket.endpoint; +import java.util.concurrent.atomic.AtomicInteger; + import javax.websocket.CloseReason; import javax.websocket.Endpoint; import javax.websocket.EndpointConfig; @@ -51,6 +53,8 @@ public class WebSocketHandlerEndpoint extends Endpoint { private WebSocketSession webSocketSession; + private final AtomicInteger sessionCount = new AtomicInteger(0); + public WebSocketHandlerEndpoint(HandlerProvider handlerProvider) { Assert.notNull(handlerProvider, "handlerProvider is required"); @@ -59,48 +63,54 @@ public class WebSocketHandlerEndpoint extends Endpoint { @Override public void onOpen(final javax.websocket.Session session, EndpointConfig config) { - if (logger.isDebugEnabled()) { - logger.debug("Client connected, WebSocket session id=" + session.getId() + ", uri=" + session.getRequestURI()); - } - try { - this.handler = handlerProvider.getHandler(); - this.webSocketSession = new StandardWebSocketSession(session); - if (this.handler instanceof TextMessageHandler) { - session.addMessageHandler(new MessageHandler.Whole() { + Assert.isTrue(this.sessionCount.compareAndSet(0, 1), "Unexpected connection"); + + if (logger.isDebugEnabled()) { + logger.debug("Client connected, javax.websocket.Session id=" + + session.getId() + ", uri=" + session.getRequestURI()); + } + + this.webSocketSession = new StandardWebSocketSession(session); + this.handler = handlerProvider.getHandler(); + + if (this.handler instanceof TextMessageHandler) { + session.addMessageHandler(new MessageHandler.Whole() { + @Override + public void onMessage(String message) { + handleTextMessage(session, message); + } + }); + } + else if (this.handler instanceof BinaryMessageHandler) { + if (this.handler instanceof PartialMessageHandler) { + session.addMessageHandler(new MessageHandler.Partial() { @Override - public void onMessage(String message) { - handleTextMessage(session, message); + public void onMessage(byte[] messagePart, boolean isLast) { + handleBinaryMessage(session, messagePart, isLast); } }); } - else if (this.handler instanceof BinaryMessageHandler) { - if (this.handler instanceof PartialMessageHandler) { - session.addMessageHandler(new MessageHandler.Partial() { - @Override - public void onMessage(byte[] messagePart, boolean isLast) { - handleBinaryMessage(session, messagePart, isLast); - } - }); - } - else { - session.addMessageHandler(new MessageHandler.Whole() { - @Override - public void onMessage(byte[] message) { - handleBinaryMessage(session, message, true); - } - }); - } - } else { + session.addMessageHandler(new MessageHandler.Whole() { + @Override + public void onMessage(byte[] message) { + handleBinaryMessage(session, message, true); + } + }); + } + } + else { + if (logger.isWarnEnabled()) { logger.warn("WebSocketHandler handles neither text nor binary messages: " + this.handler); } + } + try { this.handler.afterConnectionEstablished(this.webSocketSession); } catch (Throwable ex) { - // TODO - logger.error("Error while processing new session", ex); + this.handler.handleError(ex, this.webSocketSession); } } @@ -113,8 +123,7 @@ public class WebSocketHandlerEndpoint extends Endpoint { ((TextMessageHandler) handler).handleTextMessage(textMessage, this.webSocketSession); } catch (Throwable ex) { - // TODO - logger.error("Error while processing message", ex); + this.handler.handleError(ex, this.webSocketSession); } } @@ -127,8 +136,7 @@ public class WebSocketHandlerEndpoint extends Endpoint { ((BinaryMessageHandler) handler).handleBinaryMessage(binaryMessage, this.webSocketSession); } catch (Throwable ex) { - // TODO - logger.error("Error while processing message", ex); + this.handler.handleError(ex, this.webSocketSession); } } @@ -142,7 +150,6 @@ public class WebSocketHandlerEndpoint extends Endpoint { this.handler.afterConnectionClosed(closeStatus, this.webSocketSession); } catch (Throwable ex) { - // TODO logger.error("Error while processing session closing", ex); } finally { @@ -157,7 +164,7 @@ public class WebSocketHandlerEndpoint extends Endpoint { this.handler.handleError(exception, this.webSocketSession); } catch (Throwable ex) { - // TODO + // TODO: close the session? logger.error("Failed to handle error", ex); } }