Add CloseStatus to indicate unreliable session

When a send timeout is detected, the WebSocket session is now closed
with a custom close status that indicates so. This allows skipping
parts of the close logic that may cause further hanging.

Issue: SPR-11450
This commit is contained in:
Rossen Stoyanchev 2014-03-23 00:51:33 -04:00
parent 4028a3b0bc
commit cbd5af3a03
6 changed files with 64 additions and 22 deletions

View File

@ -134,6 +134,17 @@ public final class CloseStatus {
public static final CloseStatus TLS_HANDSHAKE_FAILURE = new CloseStatus(1015); public static final CloseStatus TLS_HANDSHAKE_FAILURE = new CloseStatus(1015);
/**
* Indicates that a session has become unreliable (e.g. timed out while sending
* a message) and extra care should be exercised while closing the session in
* order to avoid locking additional threads.
*
* <p><strong>NOTE:</strong> Spring Framework specific status code.
*/
public static final CloseStatus SESSION_NOT_RELIABLE = new CloseStatus(4500);
private final int code; private final int code;
private final String reason; private final String reason;

View File

@ -18,6 +18,7 @@ package org.springframework.web.socket.handler;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.WebSocketMessage; import org.springframework.web.socket.WebSocketMessage;
import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.WebSocketSession;
@ -58,7 +59,7 @@ public class ConcurrentWebSocketSessionDecorator extends WebSocketSessionDecorat
private final int sendTimeLimit; private final int sendTimeLimit;
private volatile boolean sessionLimitExceeded; private volatile boolean limitExceeded;
private final Lock flushLock = new ReentrantLock(); private final Lock flushLock = new ReentrantLock();
@ -85,7 +86,7 @@ public class ConcurrentWebSocketSessionDecorator extends WebSocketSessionDecorat
public void sendMessage(WebSocketMessage<?> message) throws IOException { public void sendMessage(WebSocketMessage<?> message) throws IOException {
if (this.sessionLimitExceeded) { if (this.limitExceeded) {
return; return;
} }
@ -94,8 +95,8 @@ public class ConcurrentWebSocketSessionDecorator extends WebSocketSessionDecorat
do { do {
if (!tryFlushMessageBuffer()) { if (!tryFlushMessageBuffer()) {
if (logger.isDebugEnabled()) { if (logger.isTraceEnabled()) {
logger.debug("Another send already in progress, session id '" + logger.trace("Another send already in progress, session id '" +
getId() + "'" + ", in-progress send time " + getTimeSinceSendStarted() + getId() + "'" + ", in-progress send time " + getTimeSinceSendStarted() +
" (ms)" + ", buffer size " + this.bufferSize + " bytes"); " (ms)" + ", buffer size " + this.bufferSize + " bytes");
} }
@ -107,7 +108,7 @@ public class ConcurrentWebSocketSessionDecorator extends WebSocketSessionDecorat
} }
private boolean tryFlushMessageBuffer() throws IOException { private boolean tryFlushMessageBuffer() throws IOException {
if (this.flushLock.tryLock() && !this.sessionLimitExceeded) { if (this.flushLock.tryLock() && !this.limitExceeded) {
try { try {
while (true) { while (true) {
WebSocketMessage<?> messageToSend = this.buffer.poll(); WebSocketMessage<?> messageToSend = this.buffer.poll();
@ -130,17 +131,22 @@ public class ConcurrentWebSocketSessionDecorator extends WebSocketSessionDecorat
} }
private void checkSessionLimits() throws IOException { private void checkSessionLimits() throws IOException {
if (this.closeLock.tryLock() && !this.sessionLimitExceeded) { if (this.closeLock.tryLock() && !this.limitExceeded) {
try { try {
if (getTimeSinceSendStarted() > this.sendTimeLimit) { if (getTimeSinceSendStarted() > this.sendTimeLimit) {
sessionLimitReached(
"Message send time " + getTimeSinceSendStarted() + String errorMessage = "Message send time " + getTimeSinceSendStarted() +
" (ms) exceeded the allowed limit " + this.sendTimeLimit); " (ms) exceeded the allowed limit " + this.sendTimeLimit;
sessionLimitReached(errorMessage, CloseStatus.SESSION_NOT_RELIABLE);
} }
else if (this.bufferSize.get() > this.bufferSizeLimit) { else if (this.bufferSize.get() > this.bufferSizeLimit) {
sessionLimitReached(
"The send buffer size " + this.bufferSize.get() + " bytes for " + String errorMessage = "The send buffer size " + this.bufferSize.get() + " bytes for " +
"session '" + getId() + " exceeded the allowed limit " + this.bufferSizeLimit); "session '" + getId() + " exceeded the allowed limit " + this.bufferSizeLimit;
sessionLimitReached(errorMessage,
(getTimeSinceSendStarted() >= 10000 ? CloseStatus.SESSION_NOT_RELIABLE : null));
} }
} }
finally { finally {
@ -149,9 +155,9 @@ public class ConcurrentWebSocketSessionDecorator extends WebSocketSessionDecorat
} }
} }
private void sessionLimitReached(String reason) { private void sessionLimitReached(String reason, CloseStatus status) {
this.sessionLimitExceeded = true; this.limitExceeded = true;
throw new SessionLimitExceededException(reason); throw new SessionLimitExceededException(reason, status);
} }
} }

View File

@ -16,6 +16,8 @@
package org.springframework.web.socket.handler; package org.springframework.web.socket.handler;
import org.springframework.web.socket.CloseStatus;
/** /**
* Raised when a WebSocket session has exceeded limits it has been configured * Raised when a WebSocket session has exceeded limits it has been configured
* for, e.g. timeout, buffer size, etc. * for, e.g. timeout, buffer size, etc.
@ -26,9 +28,17 @@ package org.springframework.web.socket.handler;
@SuppressWarnings("serial") @SuppressWarnings("serial")
public class SessionLimitExceededException extends RuntimeException { public class SessionLimitExceededException extends RuntimeException {
private final CloseStatus status;
public SessionLimitExceededException(String message) {
public SessionLimitExceededException(String message, CloseStatus status) {
super(message); super(message);
this.status = (status != null) ? status : CloseStatus.NO_STATUS_CODE;
}
public CloseStatus getStatus() {
return this.status;
} }
} }

View File

@ -235,8 +235,16 @@ public class SubProtocolWebSocketHandler
} }
protected final SubProtocolHandler findProtocolHandler(WebSocketSession session) { protected final SubProtocolHandler findProtocolHandler(WebSocketSession session) {
String protocol = null;
try {
protocol = session.getAcceptedProtocol();
}
catch (Exception ex) {
logger.warn("Ignoring protocol in WebSocket session after failure to obtain it: " + ex.toString());
}
SubProtocolHandler handler; SubProtocolHandler handler;
String protocol = session.getAcceptedProtocol();
if (!StringUtils.isEmpty(protocol)) { if (!StringUtils.isEmpty(protocol)) {
handler = this.protocolHandlers.get(protocol); handler = this.protocolHandlers.get(protocol);
Assert.state(handler != null, Assert.state(handler != null,
@ -283,13 +291,13 @@ public class SubProtocolWebSocketHandler
try { try {
findProtocolHandler(session).handleMessageToClient(session, message); findProtocolHandler(session).handleMessageToClient(session, message);
} }
catch (SessionLimitExceededException e) { catch (SessionLimitExceededException ex) {
try { try {
logger.error("Terminating session id '" + sessionId + "'", e); logger.error("Terminating session id '" + sessionId + "'", ex);
// Session may be unresponsive so clear first // Session may be unresponsive so clear first
clearSession(session, CloseStatus.NO_STATUS_CODE); clearSession(session, ex.getStatus());
session.close(); session.close(ex.getStatus());
} }
catch (Exception secondException) { catch (Exception secondException) {
logger.error("Exception terminating session id '" + sessionId + "'", secondException); logger.error("Exception terminating session id '" + sessionId + "'", secondException);

View File

@ -96,6 +96,7 @@ public abstract class AbstractSockJsSession implements SockJsSession {
private final Map<String, Object> attributes; private final Map<String, Object> attributes;
private volatile State state = State.NEW; private volatile State state = State.NEW;
private final long timeCreated = System.currentTimeMillis(); private final long timeCreated = System.currentTimeMillis();
@ -259,7 +260,7 @@ public abstract class AbstractSockJsSession implements SockJsSession {
logger.debug("Closing " + this + ", " + status); logger.debug("Closing " + this + ", " + status);
} }
try { try {
if (isActive()) { if (isActive() && !CloseStatus.SESSION_NOT_RELIABLE.equals(status)) {
try { try {
// bypass writeFrame // bypass writeFrame
writeFrameInternal(SockJsFrame.closeFrame(status.getCode(), status.getReason())); writeFrameInternal(SockJsFrame.closeFrame(status.getCode(), status.getReason()));
@ -284,6 +285,10 @@ public abstract class AbstractSockJsSession implements SockJsSession {
} }
} }
/**
* Actually close the underlying WebSocket session or in the case of HTTP
* transports complete the underlying request.
*/
protected abstract void disconnect(CloseStatus status) throws IOException; protected abstract void disconnect(CloseStatus status) throws IOException;
/** /**

View File

@ -17,6 +17,7 @@
package org.springframework.web.socket.handler; package org.springframework.web.socket.handler;
import org.junit.Test; import org.junit.Test;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage; import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketMessage; import org.springframework.web.socket.WebSocketMessage;
@ -136,6 +137,7 @@ public class ConcurrentWebSocketSessionDecoratorTests {
fail("Expected exception"); fail("Expected exception");
} }
catch (SessionLimitExceededException ex) { catch (SessionLimitExceededException ex) {
assertEquals(CloseStatus.SESSION_NOT_RELIABLE, ex.getStatus());
} }
} }