OrderedMessageChannelDecorator doesn't preclude send limits

Closes gh-25581
This commit is contained in:
Rossen Stoyanchev 2020-08-28 20:40:15 +01:00
parent c4f4fbc003
commit bb941b6180
8 changed files with 417 additions and 84 deletions

View File

@ -142,7 +142,7 @@ public abstract class AbstractBrokerMessageHandler
* @since 5.1
*/
public void setPreservePublishOrder(boolean preservePublishOrder) {
OrderedMessageSender.configureOutboundChannel(this.clientOutboundChannel, preservePublishOrder);
OrderedMessageChannelDecorator.configureInterceptor(this.clientOutboundChannel, preservePublishOrder);
this.preservePublishOrder = preservePublishOrder;
}
@ -298,7 +298,7 @@ public abstract class AbstractBrokerMessageHandler
*/
protected MessageChannel getClientOutboundChannelForSession(String sessionId) {
return this.preservePublishOrder ?
new OrderedMessageSender(getClientOutboundChannel(), logger) : getClientOutboundChannel();
new OrderedMessageChannelDecorator(getClientOutboundChannel(), logger) : getClientOutboundChannel();
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -33,15 +33,17 @@ import org.springframework.messaging.support.MessageHeaderAccessor;
import org.springframework.util.Assert;
/**
* Submit messages to an {@link ExecutorSubscribableChannel}, one at a time.
* The channel must have been configured with {@link #configureOutboundChannel}.
* Decorator for an {@link ExecutorSubscribableChannel} that ensures messages
* are processed in the order they were published to the channel. Messages are
* sent one at a time with the next one released when the prevoius has been
* processed. This decorator is intended to be applied per session.
*
* @author Rossen Stoyanchev
* @since 5.1
*/
class OrderedMessageSender implements MessageChannel {
public class OrderedMessageChannelDecorator implements MessageChannel {
static final String COMPLETION_TASK_HEADER = "simpSendCompletionTask";
private static final String NEXT_MESSAGE_TASK_HEADER = "simpNextMessageTask";
private final MessageChannel channel;
@ -53,7 +55,7 @@ class OrderedMessageSender implements MessageChannel {
private final AtomicBoolean sendInProgress = new AtomicBoolean(false);
public OrderedMessageSender(MessageChannel channel, Log logger) {
public OrderedMessageChannelDecorator(MessageChannel channel, Log logger) {
this.channel = channel;
this.logger = logger;
}
@ -84,10 +86,14 @@ class OrderedMessageSender implements MessageChannel {
private void sendNextMessage() {
for (;;) {
Message<?> message = this.messages.poll();
Message<?> message = this.messages.peek();
if (message != null) {
try {
addCompletionCallback(message);
addNextMessageTaskHeader(message, () -> {
if (removeMessage(message)) {
sendNextMessage();
}
});
if (this.channel.send(message)) {
return;
}
@ -97,9 +103,9 @@ class OrderedMessageSender implements MessageChannel {
logger.error("Failed to send " + message, ex);
}
}
removeMessage(message);
}
else {
// We ran out of messages..
this.sendInProgress.set(false);
trySend();
break;
@ -107,22 +113,40 @@ class OrderedMessageSender implements MessageChannel {
}
}
private void addCompletionCallback(Message<?> msg) {
SimpMessageHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(msg, SimpMessageHeaderAccessor.class);
Assert.isTrue(accessor != null && accessor.isMutable(), "Expected mutable SimpMessageHeaderAccessor");
accessor.setHeader(COMPLETION_TASK_HEADER, (Runnable) this::sendNextMessage);
private boolean removeMessage(Message<?> message) {
Message<?> next = this.messages.peek();
if (next == message) {
this.messages.remove();
return true;
}
else {
return false;
}
}
private static void addNextMessageTaskHeader(Message<?> message, Runnable task) {
SimpMessageHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, SimpMessageHeaderAccessor.class);
Assert.isTrue(accessor != null && accessor.isMutable(), "Expected mutable SimpMessageHeaderAccessor");
accessor.setHeader(NEXT_MESSAGE_TASK_HEADER, task);
}
/**
* Obtain the task to release the next message, if found.
*/
@Nullable
public static Runnable getNextMessageTask(Message<?> message) {
return (Runnable) message.getHeaders().get(OrderedMessageChannelDecorator.NEXT_MESSAGE_TASK_HEADER);
}
/**
* Install or remove an {@link ExecutorChannelInterceptor} that invokes a
* completion task once the message is handled.
* completion task, if found in the headers of the message.
* @param channel the channel to configure
* @param preservePublishOrder whether preserve order is on or off based on
* which an interceptor is either added or removed.
* @param preserveOrder whether preserve the order or publication; when
* "true" an interceptor is inserted, when "false" it removed.
*/
static void configureOutboundChannel(MessageChannel channel, boolean preservePublishOrder) {
if (preservePublishOrder) {
public static void configureInterceptor(MessageChannel channel, boolean preserveOrder) {
if (preserveOrder) {
Assert.isInstanceOf(ExecutorSubscribableChannel.class, channel,
"An ExecutorSubscribableChannel is required for `preservePublishOrder`");
ExecutorSubscribableChannel execChannel = (ExecutorSubscribableChannel) channel;
@ -133,8 +157,7 @@ class OrderedMessageSender implements MessageChannel {
else if (channel instanceof ExecutorSubscribableChannel) {
ExecutorSubscribableChannel execChannel = (ExecutorSubscribableChannel) channel;
execChannel.getInterceptors().stream().filter(i -> i instanceof CallbackInterceptor)
.findFirst()
.map(execChannel::removeInterceptor);
.findFirst().map(execChannel::removeInterceptor);
}
}
@ -144,9 +167,9 @@ class OrderedMessageSender implements MessageChannel {
@Override
public void afterMessageHandled(
Message<?> msg, MessageChannel ch, MessageHandler handler, @Nullable Exception ex) {
Message<?> message, MessageChannel ch, MessageHandler handler, @Nullable Exception ex) {
Runnable task = (Runnable) msg.getHeaders().get(OrderedMessageSender.COMPLETION_TASK_HEADER);
Runnable task = getNextMessageTask(message);
if (task != null) {
task.run();
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -36,15 +36,16 @@ import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
import static org.assertj.core.api.Assertions.assertThat;
/**
* Unit tests for {@link OrderedMessageSender}.
* Unit tests for {@link OrderedMessageChannelDecorator}.
* @author Rossen Stoyanchev
* @see org.springframework.web.socket.messaging.OrderedMessageSendingIntegrationTests
*/
public class OrderedMessageSenderTests {
public class OrderedMessageChannelDecoratorTests {
private static final Log logger = LogFactory.getLog(OrderedMessageSenderTests.class);
private static final Log logger = LogFactory.getLog(OrderedMessageChannelDecoratorTests.class);
private OrderedMessageSender sender;
private OrderedMessageChannelDecorator sender;
ExecutorSubscribableChannel channel = new ExecutorSubscribableChannel(this.executor);
@ -59,9 +60,9 @@ public class OrderedMessageSenderTests {
this.executor.afterPropertiesSet();
this.channel = new ExecutorSubscribableChannel(this.executor);
OrderedMessageSender.configureOutboundChannel(this.channel, true);
OrderedMessageChannelDecorator.configureInterceptor(this.channel, true);
this.sender = new OrderedMessageSender(this.channel, logger);
this.sender = new OrderedMessageChannelDecorator(this.channel, logger);
}
@ -89,9 +90,10 @@ public class OrderedMessageSenderTests {
latch.countDown();
return;
}
if (actual == 100 || actual == 200) {
// Force messages to queue up periodically
if (actual % 101 == 0) {
try {
Thread.sleep(200);
Thread.sleep(50);
}
catch (InterruptedException ex) {
result.set(ex.toString());

View File

@ -22,10 +22,12 @@ import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.Consumer;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.lang.Nullable;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.WebSocketMessage;
import org.springframework.web.socket.WebSocketSession;
@ -54,6 +56,10 @@ public class ConcurrentWebSocketSessionDecorator extends WebSocketSessionDecorat
private final OverflowStrategy overflowStrategy;
@Nullable
private Consumer<WebSocketMessage<?>> preSendCallback;
private final Queue<WebSocketMessage<?>> buffer = new LinkedBlockingQueue<>();
private final AtomicInteger bufferSize = new AtomicInteger();
@ -130,6 +136,15 @@ public class ConcurrentWebSocketSessionDecorator extends WebSocketSessionDecorat
return (start > 0 ? (System.currentTimeMillis() - start) : 0);
}
/**
* Set a callback invoked after a message is added to the send buffer.
* @param callback the callback to invoke
* @since 5.3
*/
public void setMessageCallback(Consumer<WebSocketMessage<?>> callback) {
this.preSendCallback = callback;
}
@Override
public void sendMessage(WebSocketMessage<?> message) throws IOException {
@ -140,6 +155,10 @@ public class ConcurrentWebSocketSessionDecorator extends WebSocketSessionDecorat
this.buffer.add(message);
this.bufferSize.addAndGet(message.getPayloadLength());
if (this.preSendCallback != null) {
this.preSendCallback.accept(message);
}
do {
if (!tryFlushMessageBuffer()) {
if (logger.isTraceEnabled()) {

View File

@ -39,6 +39,7 @@ import org.springframework.messaging.simp.SimpAttributes;
import org.springframework.messaging.simp.SimpAttributesContextHolder;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.messaging.simp.broker.OrderedMessageChannelDecorator;
import org.springframework.messaging.simp.stomp.BufferingStompDecoder;
import org.springframework.messaging.simp.stomp.StompCommand;
import org.springframework.messaging.simp.stomp.StompDecoder;
@ -57,6 +58,7 @@ import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.ConcurrentWebSocketSessionDecorator;
import org.springframework.web.socket.handler.SessionLimitExceededException;
import org.springframework.web.socket.handler.WebSocketSessionDecorator;
import org.springframework.web.socket.sockjs.transport.SockJsSession;
@ -461,6 +463,13 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
payload = errorMessage.getPayload();
}
}
Runnable task = OrderedMessageChannelDecorator.getNextMessageTask(message);
if (task != null) {
Assert.isInstanceOf(ConcurrentWebSocketSessionDecorator.class, session);
((ConcurrentWebSocketSessionDecorator) session).setMessageCallback(m -> task.run());
}
sendToClient(session, accessor, payload);
}

View File

@ -0,0 +1,61 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket.handler;
import java.io.IOException;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicReference;
import org.springframework.web.socket.WebSocketMessage;
/**
* Blocks indefinitely on sending a message but provides a latch to notify when
* the message has been "sent" (i.e. session is blocked).
*
* @author Rossen Stoyanchev
*/
public class BlockingWebSocketSession extends TestWebSocketSession {
private final AtomicReference<CountDownLatch> sendLatch = new AtomicReference<>();
private final AtomicReference<CountDownLatch> releaseLatch = new AtomicReference<>();
public CountDownLatch initSendLatch() {
this.sendLatch.set(new CountDownLatch(1));
return this.sendLatch.get();
}
@Override
public void sendMessage(WebSocketMessage<?> message) throws IOException {
super.sendMessage(message);
if (this.sendLatch.get() != null) {
this.sendLatch.get().countDown();
}
block();
}
private void block() {
try {
this.releaseLatch.set(new CountDownLatch(1));
this.releaseLatch.get().await();
}
catch (InterruptedException ex) {
ex.printStackTrace();
}
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -20,13 +20,11 @@ import java.io.IOException;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import org.junit.jupiter.api.Test;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.ConcurrentWebSocketSessionDecorator.OverflowStrategy;
@ -63,7 +61,7 @@ public class ConcurrentWebSocketSessionDecoratorTests {
@Test
public void sendAfterBlockedSend() throws IOException, InterruptedException {
BlockingSession session = new BlockingSession();
BlockingWebSocketSession session = new BlockingWebSocketSession();
session.setOpen(true);
ConcurrentWebSocketSessionDecorator decorator =
@ -85,9 +83,9 @@ public class ConcurrentWebSocketSessionDecoratorTests {
}
@Test
public void sendTimeLimitExceeded() throws IOException, InterruptedException {
public void sendTimeLimitExceeded() throws InterruptedException {
BlockingSession session = new BlockingSession();
BlockingWebSocketSession session = new BlockingWebSocketSession();
session.setId("123");
session.setOpen(true);
@ -109,7 +107,7 @@ public class ConcurrentWebSocketSessionDecoratorTests {
@Test
public void sendBufferSizeExceeded() throws IOException, InterruptedException {
BlockingSession session = new BlockingSession();
BlockingWebSocketSession session = new BlockingWebSocketSession();
session.setId("123");
session.setOpen(true);
@ -134,7 +132,7 @@ public class ConcurrentWebSocketSessionDecoratorTests {
@Test // SPR-17140
public void overflowStrategyDrop() throws IOException, InterruptedException {
BlockingSession session = new BlockingSession();
BlockingWebSocketSession session = new BlockingWebSocketSession();
session.setId("123");
session.setOpen(true);
@ -157,7 +155,7 @@ public class ConcurrentWebSocketSessionDecoratorTests {
@Test
public void closeStatusNormal() throws Exception {
BlockingSession session = new BlockingSession();
BlockingWebSocketSession session = new BlockingWebSocketSession();
session.setOpen(true);
WebSocketSession decorator = new ConcurrentWebSocketSessionDecorator(session, 10 * 1000, 1024);
@ -171,10 +169,10 @@ public class ConcurrentWebSocketSessionDecoratorTests {
@Test
public void closeStatusChangesToSessionNotReliable() throws Exception {
BlockingSession session = new BlockingSession();
BlockingWebSocketSession session = new BlockingWebSocketSession();
session.setId("123");
session.setOpen(true);
CountDownLatch sentMessageLatch = session.getSentMessageLatch();
CountDownLatch sentMessageLatch = session.initSendLatch();
int sendTimeLimit = 100;
int bufferSizeLimit = 1024;
@ -182,7 +180,7 @@ public class ConcurrentWebSocketSessionDecoratorTests {
ConcurrentWebSocketSessionDecorator decorator =
new ConcurrentWebSocketSessionDecorator(session, sendTimeLimit, bufferSizeLimit);
Executors.newSingleThreadExecutor().submit((Runnable) () -> {
Executors.newSingleThreadExecutor().submit(() -> {
TextMessage message = new TextMessage("slow message");
try {
decorator.sendMessage(message);
@ -199,12 +197,13 @@ public class ConcurrentWebSocketSessionDecoratorTests {
decorator.close(CloseStatus.PROTOCOL_ERROR);
assertThat(session.getCloseStatus()).as("CloseStatus should have changed to SESSION_NOT_RELIABLE").isEqualTo(CloseStatus.SESSION_NOT_RELIABLE);
assertThat(session.getCloseStatus())
.as("CloseStatus should have changed to SESSION_NOT_RELIABLE")
.isEqualTo(CloseStatus.SESSION_NOT_RELIABLE);
}
private void sendBlockingMessage(ConcurrentWebSocketSessionDecorator session) throws InterruptedException {
BlockingSession delegate = (BlockingSession) session.getDelegate();
CountDownLatch sentMessageLatch = delegate.getSentMessageLatch();
CountDownLatch latch = ((BlockingWebSocketSession) session.getDelegate()).initSendLatch();
Executors.newSingleThreadExecutor().submit(() -> {
TextMessage message = new TextMessage("slow message");
try {
@ -214,42 +213,7 @@ public class ConcurrentWebSocketSessionDecoratorTests {
e.printStackTrace();
}
});
assertThat(sentMessageLatch.await(5, TimeUnit.SECONDS)).isTrue();
}
private static class BlockingSession extends TestWebSocketSession {
private final AtomicReference<CountDownLatch> nextMessageLatch = new AtomicReference<>();
private final AtomicReference<CountDownLatch> releaseLatch = new AtomicReference<>();
public CountDownLatch getSentMessageLatch() {
this.nextMessageLatch.set(new CountDownLatch(1));
return this.nextMessageLatch.get();
}
@Override
public void sendMessage(WebSocketMessage<?> message) throws IOException {
super.sendMessage(message);
if (this.nextMessageLatch != null) {
this.nextMessageLatch.get().countDown();
}
block();
}
private void block() {
try {
this.releaseLatch.set(new CountDownLatch(1));
this.releaseLatch.get().await();
}
catch (InterruptedException e) {
e.printStackTrace();
}
}
assertThat(latch.await(5, TimeUnit.SECONDS)).isTrue();
}
}

View File

@ -0,0 +1,255 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket.messaging;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Queue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.springframework.lang.Nullable;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageHandler;
import org.springframework.messaging.MessagingException;
import org.springframework.messaging.simp.broker.OrderedMessageChannelDecorator;
import org.springframework.messaging.simp.stomp.StompCommand;
import org.springframework.messaging.simp.stomp.StompEncoder;
import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
import org.springframework.messaging.support.ExecutorSubscribableChannel;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.BlockingWebSocketSession;
import org.springframework.web.socket.handler.ConcurrentWebSocketSessionDecorator;
import static org.assertj.core.api.Assertions.assertThat;
/**
* Tests to publish messages to an Executor backed channel wrapped with
* {@link OrderedMessageChannelDecorator} and handled by
* {@link StompSubProtocolHandler} delegating to a
* {@link ConcurrentWebSocketSessionDecorator} wrapped session.
*
* <p>The tests verify that:
* <ul>
* <li>messages are executed in the same order as they are published.
* <li>send buffer size and send time limits at the
* {@link ConcurrentWebSocketSessionDecorator} level are enforced.
* </ul>
*
* <p>The key is for {@link OrderedMessageChannelDecorator} to release the next
* message when after the current one is queued for sending, and not after it is
* sent, which may block and cause messages to accumulate in the
* {@link OrderedMessageChannelDecorator} instead of in
* {@link ConcurrentWebSocketSessionDecorator} where send limits are enforced.
*
* @author Rossen Stoyanchev
*/
public class OrderedMessageSendingIntegrationTests {
private static final Log logger = LogFactory.getLog(OrderedMessageSendingIntegrationTests.class);
private static final int MESSAGE_SIZE = new StompEncoder().encode(createMessage(0)).length;
private BlockingWebSocketSession blockingSession;
private ExecutorSubscribableChannel subscribableChannel;
private OrderedMessageChannelDecorator orderedMessageChannel;
private ThreadPoolTaskExecutor executor;
@BeforeEach
public void setup() {
this.blockingSession = new BlockingWebSocketSession();
this.blockingSession.setId("1");
this.blockingSession.setOpen(true);
this.executor = new ThreadPoolTaskExecutor();
this.executor.setCorePoolSize(Runtime.getRuntime().availableProcessors() * 2);
this.executor.setAllowCoreThreadTimeOut(true);
this.executor.afterPropertiesSet();
this.subscribableChannel = new ExecutorSubscribableChannel(this.executor);
OrderedMessageChannelDecorator.configureInterceptor(this.subscribableChannel, true);
this.orderedMessageChannel = new OrderedMessageChannelDecorator(this.subscribableChannel, logger);
}
@AfterEach
public void tearDown() {
this.executor.shutdown();
}
@Test
void sendAfterBlockedSend() throws InterruptedException {
int messageCount = 1000;
ConcurrentWebSocketSessionDecorator concurrentSessionDecorator =
new ConcurrentWebSocketSessionDecorator(
this.blockingSession, 60 * 1000, messageCount * MESSAGE_SIZE);
TestMessageHandler handler = new TestMessageHandler(concurrentSessionDecorator);
subscribableChannel.subscribe(handler);
List<Message<?>> expectedMessages = new ArrayList<>(messageCount);
// Send one to block
Message<byte[]> message = createMessage(0);
expectedMessages.add(message);
this.orderedMessageChannel.send(message);
CountDownLatch latch = new CountDownLatch(messageCount);
handler.setMessageLatch(latch);
for (int i = 1; i <= messageCount; i++) {
message = createMessage(i);
expectedMessages.add(message);
this.orderedMessageChannel.send(message);
}
latch.await(5, TimeUnit.SECONDS);
assertThat(concurrentSessionDecorator.getTimeSinceSendStarted() > 0).isTrue();
assertThat(concurrentSessionDecorator.getBufferSize()).isEqualTo((messageCount * MESSAGE_SIZE));
assertThat(handler.getSavedMessages()).containsExactlyElementsOf(expectedMessages);
assertThat(blockingSession.isOpen()).isTrue();
}
@Test
void exceedTimeLimit() throws InterruptedException {
ConcurrentWebSocketSessionDecorator concurrentSessionDecorator =
new ConcurrentWebSocketSessionDecorator(this.blockingSession, 100, 1024);
TestMessageHandler messageHandler = new TestMessageHandler(concurrentSessionDecorator);
subscribableChannel.subscribe(messageHandler);
// Send one to block
this.orderedMessageChannel.send(createMessage(0));
// Exceed send time..
Thread.sleep(200);
CountDownLatch messageLatch = new CountDownLatch(1);
messageHandler.setMessageLatch(messageLatch);
// Send one more
this.orderedMessageChannel.send(createMessage(1));
messageLatch.await(5, TimeUnit.SECONDS);
assertThat(messageHandler.getSavedException()).hasMessageMatching(
"Send time [\\d]+ \\(ms\\) for session '1' exceeded the allowed limit 100");
}
@Test
void exceedBufferSizeLimit() throws InterruptedException {
ConcurrentWebSocketSessionDecorator concurrentSessionDecorator =
new ConcurrentWebSocketSessionDecorator(this.blockingSession, 60 * 1000, 2 * MESSAGE_SIZE);
TestMessageHandler messageHandler = new TestMessageHandler(concurrentSessionDecorator);
subscribableChannel.subscribe(messageHandler);
// Send one to block
this.orderedMessageChannel.send(createMessage(0));
int messageCount = 3;
CountDownLatch messageLatch = new CountDownLatch(messageCount);
messageHandler.setMessageLatch(messageLatch);
for (int i = 1; i <= messageCount; i++) {
this.orderedMessageChannel.send(createMessage(i));
}
messageLatch.await(5, TimeUnit.SECONDS);
assertThat(messageHandler.getSavedException()).hasMessage(
"Buffer size " + 3 * MESSAGE_SIZE + " bytes for session '1' exceeds the allowed limit " + 2 * MESSAGE_SIZE);
}
private static Message<byte[]> createMessage(int index) {
StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.MESSAGE);
accessor.setHeader("index", index);
accessor.setSubscriptionId("1");
accessor.setLeaveMutable(true);
byte[] bytes = "payload".getBytes(StandardCharsets.UTF_8);
return MessageBuilder.createMessage(bytes, accessor.getMessageHeaders());
}
private static class TestMessageHandler implements MessageHandler {
private final StompSubProtocolHandler subProtocolHandler = new StompSubProtocolHandler();
private final WebSocketSession session;
@Nullable
private CountDownLatch messageLatch;
private Queue<Message<?>> messages = new LinkedBlockingQueue<>();
private AtomicReference<Exception> exception = new AtomicReference<>();
public TestMessageHandler(WebSocketSession session) {
this.session = session;
}
public void setMessageLatch(CountDownLatch latch) {
this.messageLatch = latch;
}
public Collection<Message<?>> getSavedMessages() {
return this.messages;
}
public Exception getSavedException() {
return this.exception.get();
}
@Override
public void handleMessage(Message<?> message) throws MessagingException {
this.messages.add(message);
try {
this.subProtocolHandler.handleMessageToClient(this.session, message);
}
catch (Exception ex) {
this.exception.set(ex);
}
if (this.messageLatch != null) {
this.messageLatch.countDown();
}
}
}
}