diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/CloseStatus.java b/spring-websocket/src/main/java/org/springframework/web/socket/CloseStatus.java index 4f54b8dc43..364a9597b7 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/CloseStatus.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/CloseStatus.java @@ -134,6 +134,17 @@ public final class CloseStatus { public static final CloseStatus TLS_HANDSHAKE_FAILURE = new CloseStatus(1015); + /** + * Indicates that a session has become unreliable (e.g. timed out while sending + * a message) and extra care should be exercised while closing the session in + * order to avoid locking additional threads. + * + *

NOTE: Spring Framework specific status code. + */ + public static final CloseStatus SESSION_NOT_RELIABLE = new CloseStatus(4500); + + + private final int code; private final String reason; diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/handler/ConcurrentWebSocketSessionDecorator.java b/spring-websocket/src/main/java/org/springframework/web/socket/handler/ConcurrentWebSocketSessionDecorator.java index a519f4581e..33947496c2 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/handler/ConcurrentWebSocketSessionDecorator.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/handler/ConcurrentWebSocketSessionDecorator.java @@ -18,6 +18,7 @@ package org.springframework.web.socket.handler; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.springframework.web.socket.CloseStatus; import org.springframework.web.socket.WebSocketMessage; import org.springframework.web.socket.WebSocketSession; @@ -58,7 +59,7 @@ public class ConcurrentWebSocketSessionDecorator extends WebSocketSessionDecorat private final int sendTimeLimit; - private volatile boolean sessionLimitExceeded; + private volatile boolean limitExceeded; private final Lock flushLock = new ReentrantLock(); @@ -85,7 +86,7 @@ public class ConcurrentWebSocketSessionDecorator extends WebSocketSessionDecorat public void sendMessage(WebSocketMessage message) throws IOException { - if (this.sessionLimitExceeded) { + if (this.limitExceeded) { return; } @@ -94,8 +95,8 @@ public class ConcurrentWebSocketSessionDecorator extends WebSocketSessionDecorat do { if (!tryFlushMessageBuffer()) { - if (logger.isDebugEnabled()) { - logger.debug("Another send already in progress, session id '" + + if (logger.isTraceEnabled()) { + logger.trace("Another send already in progress, session id '" + getId() + "'" + ", in-progress send time " + getTimeSinceSendStarted() + " (ms)" + ", buffer size " + this.bufferSize + " bytes"); } @@ -107,7 +108,7 @@ public class ConcurrentWebSocketSessionDecorator extends WebSocketSessionDecorat } private boolean tryFlushMessageBuffer() throws IOException { - if (this.flushLock.tryLock() && !this.sessionLimitExceeded) { + if (this.flushLock.tryLock() && !this.limitExceeded) { try { while (true) { WebSocketMessage messageToSend = this.buffer.poll(); @@ -130,17 +131,22 @@ public class ConcurrentWebSocketSessionDecorator extends WebSocketSessionDecorat } private void checkSessionLimits() throws IOException { - if (this.closeLock.tryLock() && !this.sessionLimitExceeded) { + if (this.closeLock.tryLock() && !this.limitExceeded) { try { if (getTimeSinceSendStarted() > this.sendTimeLimit) { - sessionLimitReached( - "Message send time " + getTimeSinceSendStarted() + - " (ms) exceeded the allowed limit " + this.sendTimeLimit); + + String errorMessage = "Message send time " + getTimeSinceSendStarted() + + " (ms) exceeded the allowed limit " + this.sendTimeLimit; + + sessionLimitReached(errorMessage, CloseStatus.SESSION_NOT_RELIABLE); } else if (this.bufferSize.get() > this.bufferSizeLimit) { - sessionLimitReached( - "The send buffer size " + this.bufferSize.get() + " bytes for " + - "session '" + getId() + " exceeded the allowed limit " + this.bufferSizeLimit); + + String errorMessage = "The send buffer size " + this.bufferSize.get() + " bytes for " + + "session '" + getId() + " exceeded the allowed limit " + this.bufferSizeLimit; + + sessionLimitReached(errorMessage, + (getTimeSinceSendStarted() >= 10000 ? CloseStatus.SESSION_NOT_RELIABLE : null)); } } finally { @@ -149,9 +155,9 @@ public class ConcurrentWebSocketSessionDecorator extends WebSocketSessionDecorat } } - private void sessionLimitReached(String reason) { - this.sessionLimitExceeded = true; - throw new SessionLimitExceededException(reason); + private void sessionLimitReached(String reason, CloseStatus status) { + this.limitExceeded = true; + throw new SessionLimitExceededException(reason, status); } } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/handler/SessionLimitExceededException.java b/spring-websocket/src/main/java/org/springframework/web/socket/handler/SessionLimitExceededException.java index 5f5bad5666..96c3b3a6ae 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/handler/SessionLimitExceededException.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/handler/SessionLimitExceededException.java @@ -16,6 +16,8 @@ package org.springframework.web.socket.handler; +import org.springframework.web.socket.CloseStatus; + /** * Raised when a WebSocket session has exceeded limits it has been configured * for, e.g. timeout, buffer size, etc. @@ -26,9 +28,17 @@ package org.springframework.web.socket.handler; @SuppressWarnings("serial") public class SessionLimitExceededException extends RuntimeException { + private final CloseStatus status; - public SessionLimitExceededException(String message) { + + public SessionLimitExceededException(String message, CloseStatus status) { super(message); + this.status = (status != null) ? status : CloseStatus.NO_STATUS_CODE; + } + + + public CloseStatus getStatus() { + return this.status; } } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandler.java index 21b10891a1..20968d6db5 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandler.java @@ -235,8 +235,16 @@ public class SubProtocolWebSocketHandler } protected final SubProtocolHandler findProtocolHandler(WebSocketSession session) { + + String protocol = null; + try { + protocol = session.getAcceptedProtocol(); + } + catch (Exception ex) { + logger.warn("Ignoring protocol in WebSocket session after failure to obtain it: " + ex.toString()); + } + SubProtocolHandler handler; - String protocol = session.getAcceptedProtocol(); if (!StringUtils.isEmpty(protocol)) { handler = this.protocolHandlers.get(protocol); Assert.state(handler != null, @@ -283,13 +291,13 @@ public class SubProtocolWebSocketHandler try { findProtocolHandler(session).handleMessageToClient(session, message); } - catch (SessionLimitExceededException e) { + catch (SessionLimitExceededException ex) { try { - logger.error("Terminating session id '" + sessionId + "'", e); + logger.error("Terminating session id '" + sessionId + "'", ex); // Session may be unresponsive so clear first - clearSession(session, CloseStatus.NO_STATUS_CODE); - session.close(); + clearSession(session, ex.getStatus()); + session.close(ex.getStatus()); } catch (Exception secondException) { logger.error("Exception terminating session id '" + sessionId + "'", secondException); diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractSockJsSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractSockJsSession.java index cbc6b1d467..afd228d6f3 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractSockJsSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractSockJsSession.java @@ -96,6 +96,7 @@ public abstract class AbstractSockJsSession implements SockJsSession { private final Map attributes; + private volatile State state = State.NEW; private final long timeCreated = System.currentTimeMillis(); @@ -259,7 +260,7 @@ public abstract class AbstractSockJsSession implements SockJsSession { logger.debug("Closing " + this + ", " + status); } try { - if (isActive()) { + if (isActive() && !CloseStatus.SESSION_NOT_RELIABLE.equals(status)) { try { // bypass writeFrame writeFrameInternal(SockJsFrame.closeFrame(status.getCode(), status.getReason())); @@ -284,6 +285,10 @@ public abstract class AbstractSockJsSession implements SockJsSession { } } + /** + * Actually close the underlying WebSocket session or in the case of HTTP + * transports complete the underlying request. + */ protected abstract void disconnect(CloseStatus status) throws IOException; /** diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/handler/ConcurrentWebSocketSessionDecoratorTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/handler/ConcurrentWebSocketSessionDecoratorTests.java index 6bfad1f910..38b725fe74 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/handler/ConcurrentWebSocketSessionDecoratorTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/handler/ConcurrentWebSocketSessionDecoratorTests.java @@ -17,6 +17,7 @@ package org.springframework.web.socket.handler; import org.junit.Test; +import org.springframework.web.socket.CloseStatus; import org.springframework.web.socket.TextMessage; import org.springframework.web.socket.WebSocketMessage; @@ -136,6 +137,7 @@ public class ConcurrentWebSocketSessionDecoratorTests { fail("Expected exception"); } catch (SessionLimitExceededException ex) { + assertEquals(CloseStatus.SESSION_NOT_RELIABLE, ex.getStatus()); } }