Add check for unused WebSocket sessions
Sessions connected to a STOMP endpoint are expected to receive some client messages. Having received none after successfully connecting could be an indication of proxy or network issue. This change adds periodic checks to see if we have not received any messages on a session which is an indication the session isn't going anywhere most likely due to a proxy issue (or unreliable network) and close those sessions. Issue: SPR-11884
This commit is contained in:
parent
98d6f7b443
commit
a3fa9c9797
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.springframework.web.socket.messaging;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashSet;
|
||||
|
@ -24,6 +25,7 @@ import java.util.Map;
|
|||
import java.util.Set;
|
||||
import java.util.TreeMap;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.concurrent.locks.ReentrantLock;
|
||||
|
||||
import org.apache.commons.logging.Log;
|
||||
import org.apache.commons.logging.LogFactory;
|
||||
|
@ -64,8 +66,18 @@ import org.springframework.web.socket.handler.SessionLimitExceededException;
|
|||
public class SubProtocolWebSocketHandler implements WebSocketHandler,
|
||||
SubProtocolCapable, MessageHandler, SmartLifecycle {
|
||||
|
||||
/**
|
||||
* Sessions connected to this handler use a sub-protocol. Hence we expect to
|
||||
* receive some client messages. If we don't receive any within a minute, the
|
||||
* connection isn't doing well (proxy issue, slow network?) and can be closed.
|
||||
* @see #checkSessions()
|
||||
*/
|
||||
private final int TIME_TO_FIRST_MESSAGE = 60 * 1000;
|
||||
|
||||
|
||||
private final Log logger = LogFactory.getLog(SubProtocolWebSocketHandler.class);
|
||||
|
||||
|
||||
private final MessageChannel clientInboundChannel;
|
||||
|
||||
private final SubscribableChannel clientOutboundChannel;
|
||||
|
@ -75,12 +87,16 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler,
|
|||
|
||||
private SubProtocolHandler defaultProtocolHandler;
|
||||
|
||||
private final Map<String, WebSocketSession> sessions = new ConcurrentHashMap<String, WebSocketSession>();
|
||||
private final Map<String, WebSocketSessionHolder> sessions = new ConcurrentHashMap<String, WebSocketSessionHolder>();
|
||||
|
||||
private int sendTimeLimit = 10 * 1000;
|
||||
|
||||
private int sendBufferSizeLimit = 512 * 1024;
|
||||
|
||||
private volatile long lastSessionCheckTime = System.currentTimeMillis();
|
||||
|
||||
private final ReentrantLock sessionCheckLock = new ReentrantLock();
|
||||
|
||||
private final Object lifecycleMonitor = new Object();
|
||||
|
||||
private volatile boolean running = false;
|
||||
|
@ -214,12 +230,12 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler,
|
|||
this.clientOutboundChannel.unsubscribe(this);
|
||||
|
||||
// Notify sessions to stop flushing messages
|
||||
for (WebSocketSession session : this.sessions.values()) {
|
||||
for (WebSocketSessionHolder holder : this.sessions.values()) {
|
||||
try {
|
||||
session.close(CloseStatus.GOING_AWAY);
|
||||
holder.getSession().close(CloseStatus.GOING_AWAY);
|
||||
}
|
||||
catch (Throwable t) {
|
||||
logger.error("Failed to close session id '" + session.getId() + "': " + t.getMessage());
|
||||
logger.error("Failed to close '" + holder.getSession() + "': " + t.getMessage());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -235,15 +251,11 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler,
|
|||
|
||||
@Override
|
||||
public void afterConnectionEstablished(WebSocketSession session) throws Exception {
|
||||
|
||||
session = new ConcurrentWebSocketSessionDecorator(session, getSendTimeLimit(), getSendBufferSizeLimit());
|
||||
|
||||
this.sessions.put(session.getId(), session);
|
||||
this.sessions.put(session.getId(), new WebSocketSessionHolder(session));
|
||||
if (logger.isDebugEnabled()) {
|
||||
logger.debug("Started WebSocket session=" + session.getId() +
|
||||
", number of sessions=" + this.sessions.size());
|
||||
logger.debug("Started session " + session.getId() + ", number of sessions=" + this.sessions.size());
|
||||
}
|
||||
|
||||
findProtocolHandler(session).afterSessionStarted(session, this.clientInboundChannel);
|
||||
}
|
||||
|
||||
|
@ -283,41 +295,49 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler,
|
|||
|
||||
@Override
|
||||
public void handleMessage(WebSocketSession session, WebSocketMessage<?> message) throws Exception {
|
||||
findProtocolHandler(session).handleMessageFromClient(session, message, this.clientInboundChannel);
|
||||
SubProtocolHandler protocolHandler = findProtocolHandler(session);
|
||||
protocolHandler.handleMessageFromClient(session, message, this.clientInboundChannel);
|
||||
WebSocketSessionHolder holder = this.sessions.get(session.getId());
|
||||
if (holder != null) {
|
||||
holder.setHasHandledMessages();
|
||||
}
|
||||
else {
|
||||
// Should never happen
|
||||
throw new IllegalStateException("Session not found: " + session);
|
||||
}
|
||||
checkSessions();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void handleMessage(Message<?> message) throws MessagingException {
|
||||
|
||||
String sessionId = resolveSessionId(message);
|
||||
if (sessionId == null) {
|
||||
logger.error("sessionId not found in message " + message);
|
||||
return;
|
||||
}
|
||||
|
||||
WebSocketSession session = this.sessions.get(sessionId);
|
||||
if (session == null) {
|
||||
WebSocketSessionHolder holder = this.sessions.get(sessionId);
|
||||
if (holder == null) {
|
||||
logger.error("Session not found for session with id '" + sessionId + "', ignoring message " + message);
|
||||
return;
|
||||
}
|
||||
|
||||
WebSocketSession session = holder.getSession();
|
||||
try {
|
||||
findProtocolHandler(session).handleMessageToClient(session, message);
|
||||
}
|
||||
catch (SessionLimitExceededException ex) {
|
||||
try {
|
||||
logger.error("Terminating session id '" + sessionId + "'", ex);
|
||||
logger.error("Terminating '" + session + "'", ex);
|
||||
|
||||
// Session may be unresponsive so clear first
|
||||
clearSession(session, ex.getStatus());
|
||||
session.close(ex.getStatus());
|
||||
}
|
||||
catch (Exception secondException) {
|
||||
logger.error("Exception terminating session id '" + sessionId + "'", secondException);
|
||||
logger.error("Exception terminating '" + sessionId + "'", secondException);
|
||||
}
|
||||
}
|
||||
catch (Exception e) {
|
||||
logger.error("Failed to send message to client " + message, e);
|
||||
logger.error("Failed to send message to client " + message + " in " + session, e);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -337,6 +357,43 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler,
|
|||
return null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Periodically check sessions to ensure they have received at least one
|
||||
* message or otherwise close them.
|
||||
*/
|
||||
private void checkSessions() throws IOException {
|
||||
long currentTime = System.currentTimeMillis();
|
||||
if (!isRunning() && currentTime - this.lastSessionCheckTime < TIME_TO_FIRST_MESSAGE) {
|
||||
return;
|
||||
}
|
||||
try {
|
||||
if (this.sessionCheckLock.tryLock()) {
|
||||
for (WebSocketSessionHolder holder : this.sessions.values()) {
|
||||
if (holder.hasHandledMessages()) {
|
||||
continue;
|
||||
}
|
||||
long timeSinceCreated = currentTime - holder.getCreateTime();
|
||||
if (holder.hasHandledMessages() || timeSinceCreated < TIME_TO_FIRST_MESSAGE) {
|
||||
continue;
|
||||
}
|
||||
WebSocketSession session = holder.getSession();
|
||||
if (logger.isErrorEnabled()) {
|
||||
logger.error("No messages received after " + timeSinceCreated + " ms. Closing " + holder);
|
||||
}
|
||||
try {
|
||||
session.close(CloseStatus.PROTOCOL_ERROR);
|
||||
}
|
||||
catch (Throwable t) {
|
||||
logger.error("Failed to close " + session, t);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
finally {
|
||||
this.sessionCheckLock.unlock();
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception {
|
||||
}
|
||||
|
@ -356,4 +413,45 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler,
|
|||
return false;
|
||||
}
|
||||
|
||||
|
||||
private static class WebSocketSessionHolder {
|
||||
|
||||
private final WebSocketSession session;
|
||||
|
||||
private final long createTime = System.currentTimeMillis();
|
||||
|
||||
private volatile boolean handledMessages;
|
||||
|
||||
|
||||
private WebSocketSessionHolder(WebSocketSession session) {
|
||||
this.session = session;
|
||||
}
|
||||
|
||||
public WebSocketSession getSession() {
|
||||
return this.session;
|
||||
}
|
||||
|
||||
public long getCreateTime() {
|
||||
return this.createTime;
|
||||
}
|
||||
|
||||
public void setHasHandledMessages() {
|
||||
this.handledMessages = true;
|
||||
}
|
||||
|
||||
public boolean hasHandledMessages() {
|
||||
return this.handledMessages;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
if (this.session instanceof ConcurrentWebSocketSessionDecorator) {
|
||||
return ((ConcurrentWebSocketSessionDecorator) this.session).getLastSession().toString();
|
||||
}
|
||||
else {
|
||||
return this.session.toString();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -17,16 +17,24 @@
|
|||
package org.springframework.web.socket.messaging;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.Map;
|
||||
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
import org.mockito.Mock;
|
||||
import org.mockito.MockitoAnnotations;
|
||||
import org.springframework.beans.DirectFieldAccessor;
|
||||
import org.springframework.messaging.MessageChannel;
|
||||
import org.springframework.messaging.SubscribableChannel;
|
||||
import org.springframework.web.socket.CloseStatus;
|
||||
import org.springframework.web.socket.TextMessage;
|
||||
import org.springframework.web.socket.handler.ConcurrentWebSocketSessionDecorator;
|
||||
import org.springframework.web.socket.handler.TestWebSocketSession;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertFalse;
|
||||
import static org.junit.Assert.assertNull;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
import static org.mockito.Mockito.*;
|
||||
|
||||
/**
|
||||
|
@ -56,11 +64,9 @@ public class SubProtocolWebSocketHandlerTests {
|
|||
@Before
|
||||
public void setup() {
|
||||
MockitoAnnotations.initMocks(this);
|
||||
|
||||
this.webSocketHandler = new SubProtocolWebSocketHandler(this.inClientChannel, this.outClientChannel);
|
||||
when(stompHandler.getSupportedProtocols()).thenReturn(Arrays.asList("v10.stomp", "v11.stomp", "v12.stomp"));
|
||||
when(mqttHandler.getSupportedProtocols()).thenReturn(Arrays.asList("MQTT"));
|
||||
|
||||
this.session = new TestWebSocketSession();
|
||||
this.session.setId("1");
|
||||
}
|
||||
|
@ -140,4 +146,32 @@ public class SubProtocolWebSocketHandlerTests {
|
|||
this.webSocketHandler.afterConnectionEstablished(session);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void checkSession() throws Exception {
|
||||
TestWebSocketSession session1 = new TestWebSocketSession("id1");
|
||||
TestWebSocketSession session2 = new TestWebSocketSession("id2");
|
||||
session1.setAcceptedProtocol("v12.stomp");
|
||||
session2.setAcceptedProtocol("v12.stomp");
|
||||
|
||||
this.webSocketHandler.setProtocolHandlers(Arrays.asList(this.stompHandler));
|
||||
this.webSocketHandler.afterConnectionEstablished(session1);
|
||||
this.webSocketHandler.afterConnectionEstablished(session2);
|
||||
session1.setOpen(true);
|
||||
session2.setOpen(true);
|
||||
|
||||
long sixtyOneSecondsAgo = System.currentTimeMillis() - 61 * 1000;
|
||||
new DirectFieldAccessor(this.webSocketHandler).setPropertyValue("lastSessionCheckTime", sixtyOneSecondsAgo);
|
||||
Map<String, ?> sessions = (Map<String, ?>) new DirectFieldAccessor(this.webSocketHandler).getPropertyValue("sessions");
|
||||
new DirectFieldAccessor(sessions.get("id1")).setPropertyValue("createTime", sixtyOneSecondsAgo);
|
||||
new DirectFieldAccessor(sessions.get("id2")).setPropertyValue("createTime", sixtyOneSecondsAgo);
|
||||
|
||||
this.webSocketHandler.handleMessage(session1, new TextMessage("foo"));
|
||||
|
||||
assertTrue(session1.isOpen());
|
||||
assertFalse(session2.isOpen());
|
||||
assertNull(session1.getCloseStatus());
|
||||
assertEquals(CloseStatus.PROTOCOL_ERROR, session2.getCloseStatus());
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue