Add concurrent WebSocket session decorator (temp commit)

Issue: SPR-11586
This commit is contained in:
Rossen Stoyanchev 2014-03-20 16:37:59 -04:00
parent ac968e94ed
commit b7a974116e
6 changed files with 542 additions and 6 deletions

View File

@ -0,0 +1,137 @@
/*
* Copyright 2002-2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket.handler;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.web.socket.WebSocketMessage;
import org.springframework.web.socket.WebSocketSession;
import java.io.IOException;
import java.util.Queue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
/**
* Wraps a {@link org.springframework.web.socket.WebSocketSession} and guarantees
* only one thread can send messages at a time.
*
* <p>If a send is slow, subsequent attempts to send more messages from a different
* thread will fail to acquire the lock and the messages will be buffered instead --
* at that time the specified buffer size limit and send time limit will be checked
* and the session closed if the limits are exceeded.
*
* @author Rossen Stoyanchev
* @since 4.0.3
*/
public class ConcurrentWebSocketSessionDecorator extends WebSocketSessionDecorator {
private static Log logger = LogFactory.getLog(ConcurrentWebSocketSessionDecorator.class);
private final int sendTimeLimit;
private final int bufferSizeLimit;
private final Queue<WebSocketMessage<?>> buffer = new LinkedBlockingQueue<WebSocketMessage<?>>();
private final AtomicInteger bufferSize = new AtomicInteger();
private volatile long sendStartTime;
private final Lock lock = new ReentrantLock();
public ConcurrentWebSocketSessionDecorator(
WebSocketSession delegateSession, int sendTimeLimit, int bufferSizeLimit) {
super(delegateSession);
this.sendTimeLimit = sendTimeLimit;
this.bufferSizeLimit = bufferSizeLimit;
}
public int getBufferSize() {
return this.bufferSize.get();
}
public long getInProgressSendTime() {
long start = this.sendStartTime;
return (start > 0 ? (System.currentTimeMillis() - start) : 0);
}
public void sendMessage(WebSocketMessage<?> message) throws IOException {
this.buffer.add(message);
this.bufferSize.addAndGet(message.getPayloadLength());
do {
if (!tryFlushMessageBuffer()) {
checkSessionLimits();
break;
}
}
while (!this.buffer.isEmpty());
}
private boolean tryFlushMessageBuffer() throws IOException {
if (this.lock.tryLock()) {
try {
while (true) {
WebSocketMessage<?> messageToSend = this.buffer.poll();
if (messageToSend == null) {
break;
}
this.bufferSize.addAndGet(messageToSend.getPayloadLength() * -1);
this.sendStartTime = System.currentTimeMillis();
getDelegate().sendMessage(messageToSend);
this.sendStartTime = 0;
}
}
finally {
this.sendStartTime = 0;
lock.unlock();
}
return true;
}
return false;
}
private void checkSessionLimits() throws IOException {
if (getInProgressSendTime() > this.sendTimeLimit) {
logError("A message could not be sent due to a timeout");
getDelegate().close();
}
else if (this.bufferSize.get() > this.bufferSizeLimit) {
logError("The total send buffer byte count '" + this.bufferSize.get() +
"' for session '" + getId() + "' exceeds the allowed limit '" + this.bufferSizeLimit + "'");
getDelegate().close();
}
}
private void logError(String reason) {
logger.error(reason + ", number of buffered messages is '" + this.buffer.size() +
"', time since the last send started is '" + getInProgressSendTime() + "' (ms)");
}
}

View File

@ -23,6 +23,13 @@ import org.springframework.web.socket.WebSocketMessage;
import org.springframework.web.socket.WebSocketSession;
/**
* Wraps another {@link org.springframework.web.socket.WebSocketHandler}
* instance and delegates to it.
*
* <p>Also provides a {@link #getDelegate()} method to return the decorated
* handler as well as a {@link #getLastHandler()} method to go through all nested
* delegates and return the "last" handler.
*
* @author Rossen Stoyanchev
* @since 4.0
*/

View File

@ -0,0 +1,137 @@
/*
* Copyright 2002-2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket.handler;
import org.springframework.http.HttpHeaders;
import org.springframework.util.Assert;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.WebSocketExtension;
import org.springframework.web.socket.WebSocketMessage;
import org.springframework.web.socket.WebSocketSession;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.URI;
import java.security.Principal;
import java.util.List;
import java.util.Map;
/**
* Wraps another {@link org.springframework.web.socket.WebSocketSession} instance
* and delegates to it.
*
* <p>Also provides a {@link #getDelegate()} method to return the decorated session
* as well as a {@link #getLastSession()} method to go through all nested delegates
* and return the "last" session.
*
* @author Rossen Stoyanchev
* @since 4.0.3
*/
public class WebSocketSessionDecorator implements WebSocketSession {
private final WebSocketSession delegate;
public WebSocketSessionDecorator(WebSocketSession session) {
Assert.notNull(session, "Delegate WebSocketSessionSession is required");
this.delegate = session;
}
@Override
public String getId() {
return this.delegate.getId();
}
@Override
public URI getUri() {
return this.delegate.getUri();
}
@Override
public HttpHeaders getHandshakeHeaders() {
return this.delegate.getHandshakeHeaders();
}
@Override
public Map<String, Object> getAttributes() {
return this.delegate.getAttributes();
}
@Override
public Principal getPrincipal() {
return this.delegate.getPrincipal();
}
@Override
public InetSocketAddress getLocalAddress() {
return this.delegate.getLocalAddress();
}
@Override
public InetSocketAddress getRemoteAddress() {
return this.delegate.getRemoteAddress();
}
@Override
public String getAcceptedProtocol() {
return this.delegate.getAcceptedProtocol();
}
@Override
public List<WebSocketExtension> getExtensions() {
return this.delegate.getExtensions();
}
@Override
public boolean isOpen() {
return this.delegate.isOpen();
}
@Override
public void sendMessage(WebSocketMessage<?> message) throws IOException {
}
@Override
public void close() throws IOException {
this.delegate.close();
}
@Override
public void close(CloseStatus status) throws IOException {
this.delegate.close(status);
}
public WebSocketSession getDelegate() {
return this.delegate;
}
public WebSocketSession getLastSession() {
WebSocketSession result = this.delegate;
while (result instanceof WebSocketSessionDecorator) {
result = ((WebSocketSessionDecorator) result).getDelegate();
}
return result;
}
@Override
public String toString() {
return getClass().getSimpleName() + " [delegate=" + this.delegate + "]";
}
}

View File

@ -42,6 +42,7 @@ import org.springframework.web.socket.SubProtocolCapable;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.ConcurrentWebSocketSessionDecorator;
/**
* An implementation of {@link WebSocketHandler} that delegates incoming WebSocket
@ -74,6 +75,10 @@ public class SubProtocolWebSocketHandler
private final Map<String, WebSocketSession> sessions = new ConcurrentHashMap<String, WebSocketSession>();
private int sendTimeLimit = 20 * 1000;
private int sendBufferSizeLimit = 1024 * 1024;
private Object lifecycleMonitor = new Object();
private volatile boolean running = false;
@ -155,6 +160,24 @@ public class SubProtocolWebSocketHandler
return new ArrayList<String>(this.protocolHandlers.keySet());
}
public void setSendTimeLimit(int sendTimeLimit) {
this.sendTimeLimit = sendTimeLimit;
}
public int getSendTimeLimit() {
return this.sendTimeLimit;
}
public void setSendBufferSizeLimit(int sendBufferSizeLimit) {
this.sendBufferSizeLimit = sendBufferSizeLimit;
}
public int getSendBufferSizeLimit() {
return sendBufferSizeLimit;
}
@Override
public boolean isAutoStartup() {
return true;
@ -198,11 +221,15 @@ public class SubProtocolWebSocketHandler
@Override
public void afterConnectionEstablished(WebSocketSession session) throws Exception {
session = new ConcurrentWebSocketSessionDecorator(session, getSendTimeLimit(), getSendBufferSizeLimit());
this.sessions.put(session.getId(), session);
if (logger.isDebugEnabled()) {
logger.debug("Started WebSocket session=" + session.getId() +
", number of sessions=" + this.sessions.size());
}
findProtocolHandler(session).afterSessionStarted(session, this.clientInboundChannel);
}

View File

@ -0,0 +1,220 @@
/*
* Copyright 2002-2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket.handler;
import org.junit.Test;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketMessage;
import java.io.IOException;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
/**
* Unit tests for
* {@link org.springframework.web.socket.handler.ConcurrentWebSocketSessionDecorator}.
*
* @author Rossen Stoyanchev
*/
public class ConcurrentWebSocketSessionDecoratorTests {
@Test
public void send() throws IOException {
TestWebSocketSession session = new TestWebSocketSession();
session.setOpen(true);
ConcurrentWebSocketSessionDecorator concurrentSession =
new ConcurrentWebSocketSessionDecorator(session, 1000, 1024);
TextMessage textMessage = new TextMessage("payload");
concurrentSession.sendMessage(textMessage);
assertEquals(1, session.getSentMessages().size());
assertEquals(textMessage, session.getSentMessages().get(0));
assertEquals(0, concurrentSession.getBufferSize());
assertEquals(0, concurrentSession.getInProgressSendTime());
assertTrue(session.isOpen());
}
@Test
public void sendAfterBlockedSend() throws IOException, InterruptedException {
BlockingSession blockingSession = new BlockingSession();
blockingSession.setOpen(true);
CountDownLatch sentMessageLatch = blockingSession.getSentMessageLatch();
final ConcurrentWebSocketSessionDecorator concurrentSession =
new ConcurrentWebSocketSessionDecorator(blockingSession, 10 * 1000, 1024);
Executors.newSingleThreadExecutor().submit(new Runnable() {
@Override
public void run() {
TextMessage textMessage = new TextMessage("slow message");
try {
concurrentSession.sendMessage(textMessage);
}
catch (IOException e) {
e.printStackTrace();
}
}
});
assertTrue(sentMessageLatch.await(5, TimeUnit.SECONDS));
// ensure some send time elapses
Thread.sleep(100);
assertTrue(concurrentSession.getInProgressSendTime() > 0);
TextMessage payload = new TextMessage("payload");
for (int i=0; i < 5; i++) {
concurrentSession.sendMessage(payload);
}
assertTrue(concurrentSession.getInProgressSendTime() > 0);
assertEquals(5 * payload.getPayloadLength(), concurrentSession.getBufferSize());
assertTrue(blockingSession.isOpen());
}
@Test
public void sendTimeLimitExceeded() throws IOException, InterruptedException {
BlockingSession blockingSession = new BlockingSession();
blockingSession.setOpen(true);
CountDownLatch sentMessageLatch = blockingSession.getSentMessageLatch();
int sendTimeLimit = 100;
int bufferSizeLimit = 1024;
final ConcurrentWebSocketSessionDecorator concurrentSession =
new ConcurrentWebSocketSessionDecorator(blockingSession, sendTimeLimit, bufferSizeLimit);
Executors.newSingleThreadExecutor().submit(new Runnable() {
@Override
public void run() {
TextMessage textMessage = new TextMessage("slow message");
try {
concurrentSession.sendMessage(textMessage);
}
catch (IOException e) {
e.printStackTrace();
}
}
});
assertTrue(sentMessageLatch.await(5, TimeUnit.SECONDS));
// ensure some send time elapses
Thread.sleep(sendTimeLimit + 100);
TextMessage payload = new TextMessage("payload");
concurrentSession.sendMessage(payload);
assertFalse(blockingSession.isOpen());
}
@Test
public void sendBufferSizeExceeded() throws IOException, InterruptedException {
BlockingSession blockingSession = new BlockingSession();
blockingSession.setOpen(true);
CountDownLatch sentMessageLatch = blockingSession.getSentMessageLatch();
int sendTimeLimit = 10 * 1000;
int bufferSizeLimit = 1024;
final ConcurrentWebSocketSessionDecorator concurrentSession =
new ConcurrentWebSocketSessionDecorator(blockingSession, sendTimeLimit, bufferSizeLimit);
Executors.newSingleThreadExecutor().submit(new Runnable() {
@Override
public void run() {
TextMessage textMessage = new TextMessage("slow message");
try {
concurrentSession.sendMessage(textMessage);
}
catch (IOException e) {
e.printStackTrace();
}
}
});
assertTrue(sentMessageLatch.await(5, TimeUnit.SECONDS));
StringBuilder sb = new StringBuilder();
for (int i=0 ; i < 1023; i++) {
sb.append("a");
}
TextMessage message = new TextMessage(sb.toString());
concurrentSession.sendMessage(message);
assertEquals(1023, concurrentSession.getBufferSize());
assertTrue(blockingSession.isOpen());
concurrentSession.sendMessage(message);
assertFalse(blockingSession.isOpen());
}
private static class BlockingSession extends TestWebSocketSession {
private AtomicReference<CountDownLatch> nextMessageLatch = new AtomicReference<>();
private AtomicReference<CountDownLatch> releaseLatch = new AtomicReference<>();
public CountDownLatch getSentMessageLatch() {
this.nextMessageLatch.set(new CountDownLatch(1));
return this.nextMessageLatch.get();
}
@Override
public void sendMessage(WebSocketMessage<?> message) throws IOException {
super.sendMessage(message);
if (this.nextMessageLatch != null) {
this.nextMessageLatch.get().countDown();
}
block();
}
private void block() {
try {
this.releaseLatch.set(new CountDownLatch(1));
this.releaseLatch.get().await();
}
catch (InterruptedException e) {
e.printStackTrace();
}
}
public void release() {
if (this.releaseLatch.get() != null) {
this.releaseLatch.get().countDown();
}
}
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2013 the original author or authors.
* Copyright 2002-2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -21,9 +21,12 @@ import java.util.Arrays;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.SubscribableChannel;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.ConcurrentWebSocketSessionDecorator;
import org.springframework.web.socket.handler.TestWebSocketSession;
import static org.mockito.Mockito.*;
@ -71,7 +74,8 @@ public class SubProtocolWebSocketHandlerTests {
this.session.setAcceptedProtocol("v12.sToMp");
this.webSocketHandler.afterConnectionEstablished(session);
verify(this.stompHandler).afterSessionStarted(session, this.inClientChannel);
verify(this.stompHandler).afterSessionStarted(
isA(ConcurrentWebSocketSessionDecorator.class), eq(this.inClientChannel));
verify(this.mqttHandler, times(0)).afterSessionStarted(session, this.inClientChannel);
}
@ -81,7 +85,8 @@ public class SubProtocolWebSocketHandlerTests {
this.session.setAcceptedProtocol("v12.sToMp");
this.webSocketHandler.afterConnectionEstablished(session);
verify(this.stompHandler).afterSessionStarted(session, this.inClientChannel);
verify(this.stompHandler).afterSessionStarted(
isA(ConcurrentWebSocketSessionDecorator.class), eq(this.inClientChannel));
}
@Test(expected=IllegalStateException.class)
@ -98,7 +103,8 @@ public class SubProtocolWebSocketHandlerTests {
this.webSocketHandler.setDefaultProtocolHandler(defaultHandler);
this.webSocketHandler.afterConnectionEstablished(session);
verify(this.defaultHandler).afterSessionStarted(session, this.inClientChannel);
verify(this.defaultHandler).afterSessionStarted(
isA(ConcurrentWebSocketSessionDecorator.class), eq(this.inClientChannel));
verify(this.stompHandler, times(0)).afterSessionStarted(session, this.inClientChannel);
verify(this.mqttHandler, times(0)).afterSessionStarted(session, this.inClientChannel);
}
@ -109,7 +115,8 @@ public class SubProtocolWebSocketHandlerTests {
this.webSocketHandler.setDefaultProtocolHandler(defaultHandler);
this.webSocketHandler.afterConnectionEstablished(session);
verify(this.defaultHandler).afterSessionStarted(session, this.inClientChannel);
verify(this.defaultHandler).afterSessionStarted(
isA(ConcurrentWebSocketSessionDecorator.class), eq(this.inClientChannel));
verify(this.stompHandler, times(0)).afterSessionStarted(session, this.inClientChannel);
verify(this.mqttHandler, times(0)).afterSessionStarted(session, this.inClientChannel);
}
@ -119,7 +126,8 @@ public class SubProtocolWebSocketHandlerTests {
this.webSocketHandler.setProtocolHandlers(Arrays.asList(stompHandler));
this.webSocketHandler.afterConnectionEstablished(session);
verify(this.stompHandler).afterSessionStarted(session, this.inClientChannel);
verify(this.stompHandler).afterSessionStarted(
isA(ConcurrentWebSocketSessionDecorator.class), eq(this.inClientChannel));
}
@Test(expected=IllegalStateException.class)