OrderedMessageChannelDecorator doesn't preclude send limits
Closes gh-25581
This commit is contained in:
parent
c4f4fbc003
commit
bb941b6180
|
|
@ -142,7 +142,7 @@ public abstract class AbstractBrokerMessageHandler
|
|||
* @since 5.1
|
||||
*/
|
||||
public void setPreservePublishOrder(boolean preservePublishOrder) {
|
||||
OrderedMessageSender.configureOutboundChannel(this.clientOutboundChannel, preservePublishOrder);
|
||||
OrderedMessageChannelDecorator.configureInterceptor(this.clientOutboundChannel, preservePublishOrder);
|
||||
this.preservePublishOrder = preservePublishOrder;
|
||||
}
|
||||
|
||||
|
|
@ -298,7 +298,7 @@ public abstract class AbstractBrokerMessageHandler
|
|||
*/
|
||||
protected MessageChannel getClientOutboundChannelForSession(String sessionId) {
|
||||
return this.preservePublishOrder ?
|
||||
new OrderedMessageSender(getClientOutboundChannel(), logger) : getClientOutboundChannel();
|
||||
new OrderedMessageChannelDecorator(getClientOutboundChannel(), logger) : getClientOutboundChannel();
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright 2002-2018 the original author or authors.
|
||||
* Copyright 2002-2020 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
@ -33,15 +33,17 @@ import org.springframework.messaging.support.MessageHeaderAccessor;
|
|||
import org.springframework.util.Assert;
|
||||
|
||||
/**
|
||||
* Submit messages to an {@link ExecutorSubscribableChannel}, one at a time.
|
||||
* The channel must have been configured with {@link #configureOutboundChannel}.
|
||||
* Decorator for an {@link ExecutorSubscribableChannel} that ensures messages
|
||||
* are processed in the order they were published to the channel. Messages are
|
||||
* sent one at a time with the next one released when the prevoius has been
|
||||
* processed. This decorator is intended to be applied per session.
|
||||
*
|
||||
* @author Rossen Stoyanchev
|
||||
* @since 5.1
|
||||
*/
|
||||
class OrderedMessageSender implements MessageChannel {
|
||||
public class OrderedMessageChannelDecorator implements MessageChannel {
|
||||
|
||||
static final String COMPLETION_TASK_HEADER = "simpSendCompletionTask";
|
||||
private static final String NEXT_MESSAGE_TASK_HEADER = "simpNextMessageTask";
|
||||
|
||||
|
||||
private final MessageChannel channel;
|
||||
|
|
@ -53,7 +55,7 @@ class OrderedMessageSender implements MessageChannel {
|
|||
private final AtomicBoolean sendInProgress = new AtomicBoolean(false);
|
||||
|
||||
|
||||
public OrderedMessageSender(MessageChannel channel, Log logger) {
|
||||
public OrderedMessageChannelDecorator(MessageChannel channel, Log logger) {
|
||||
this.channel = channel;
|
||||
this.logger = logger;
|
||||
}
|
||||
|
|
@ -84,10 +86,14 @@ class OrderedMessageSender implements MessageChannel {
|
|||
|
||||
private void sendNextMessage() {
|
||||
for (;;) {
|
||||
Message<?> message = this.messages.poll();
|
||||
Message<?> message = this.messages.peek();
|
||||
if (message != null) {
|
||||
try {
|
||||
addCompletionCallback(message);
|
||||
addNextMessageTaskHeader(message, () -> {
|
||||
if (removeMessage(message)) {
|
||||
sendNextMessage();
|
||||
}
|
||||
});
|
||||
if (this.channel.send(message)) {
|
||||
return;
|
||||
}
|
||||
|
|
@ -97,9 +103,9 @@ class OrderedMessageSender implements MessageChannel {
|
|||
logger.error("Failed to send " + message, ex);
|
||||
}
|
||||
}
|
||||
removeMessage(message);
|
||||
}
|
||||
else {
|
||||
// We ran out of messages..
|
||||
this.sendInProgress.set(false);
|
||||
trySend();
|
||||
break;
|
||||
|
|
@ -107,22 +113,40 @@ class OrderedMessageSender implements MessageChannel {
|
|||
}
|
||||
}
|
||||
|
||||
private void addCompletionCallback(Message<?> msg) {
|
||||
SimpMessageHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(msg, SimpMessageHeaderAccessor.class);
|
||||
Assert.isTrue(accessor != null && accessor.isMutable(), "Expected mutable SimpMessageHeaderAccessor");
|
||||
accessor.setHeader(COMPLETION_TASK_HEADER, (Runnable) this::sendNextMessage);
|
||||
private boolean removeMessage(Message<?> message) {
|
||||
Message<?> next = this.messages.peek();
|
||||
if (next == message) {
|
||||
this.messages.remove();
|
||||
return true;
|
||||
}
|
||||
else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
private static void addNextMessageTaskHeader(Message<?> message, Runnable task) {
|
||||
SimpMessageHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, SimpMessageHeaderAccessor.class);
|
||||
Assert.isTrue(accessor != null && accessor.isMutable(), "Expected mutable SimpMessageHeaderAccessor");
|
||||
accessor.setHeader(NEXT_MESSAGE_TASK_HEADER, task);
|
||||
}
|
||||
|
||||
/**
|
||||
* Obtain the task to release the next message, if found.
|
||||
*/
|
||||
@Nullable
|
||||
public static Runnable getNextMessageTask(Message<?> message) {
|
||||
return (Runnable) message.getHeaders().get(OrderedMessageChannelDecorator.NEXT_MESSAGE_TASK_HEADER);
|
||||
}
|
||||
|
||||
/**
|
||||
* Install or remove an {@link ExecutorChannelInterceptor} that invokes a
|
||||
* completion task once the message is handled.
|
||||
* completion task, if found in the headers of the message.
|
||||
* @param channel the channel to configure
|
||||
* @param preservePublishOrder whether preserve order is on or off based on
|
||||
* which an interceptor is either added or removed.
|
||||
* @param preserveOrder whether preserve the order or publication; when
|
||||
* "true" an interceptor is inserted, when "false" it removed.
|
||||
*/
|
||||
static void configureOutboundChannel(MessageChannel channel, boolean preservePublishOrder) {
|
||||
if (preservePublishOrder) {
|
||||
public static void configureInterceptor(MessageChannel channel, boolean preserveOrder) {
|
||||
if (preserveOrder) {
|
||||
Assert.isInstanceOf(ExecutorSubscribableChannel.class, channel,
|
||||
"An ExecutorSubscribableChannel is required for `preservePublishOrder`");
|
||||
ExecutorSubscribableChannel execChannel = (ExecutorSubscribableChannel) channel;
|
||||
|
|
@ -133,8 +157,7 @@ class OrderedMessageSender implements MessageChannel {
|
|||
else if (channel instanceof ExecutorSubscribableChannel) {
|
||||
ExecutorSubscribableChannel execChannel = (ExecutorSubscribableChannel) channel;
|
||||
execChannel.getInterceptors().stream().filter(i -> i instanceof CallbackInterceptor)
|
||||
.findFirst()
|
||||
.map(execChannel::removeInterceptor);
|
||||
.findFirst().map(execChannel::removeInterceptor);
|
||||
|
||||
}
|
||||
}
|
||||
|
|
@ -144,9 +167,9 @@ class OrderedMessageSender implements MessageChannel {
|
|||
|
||||
@Override
|
||||
public void afterMessageHandled(
|
||||
Message<?> msg, MessageChannel ch, MessageHandler handler, @Nullable Exception ex) {
|
||||
Message<?> message, MessageChannel ch, MessageHandler handler, @Nullable Exception ex) {
|
||||
|
||||
Runnable task = (Runnable) msg.getHeaders().get(OrderedMessageSender.COMPLETION_TASK_HEADER);
|
||||
Runnable task = getNextMessageTask(message);
|
||||
if (task != null) {
|
||||
task.run();
|
||||
}
|
||||
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright 2002-2019 the original author or authors.
|
||||
* Copyright 2002-2020 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
@ -36,15 +36,16 @@ import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
|
|||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
/**
|
||||
* Unit tests for {@link OrderedMessageSender}.
|
||||
* Unit tests for {@link OrderedMessageChannelDecorator}.
|
||||
* @author Rossen Stoyanchev
|
||||
* @see org.springframework.web.socket.messaging.OrderedMessageSendingIntegrationTests
|
||||
*/
|
||||
public class OrderedMessageSenderTests {
|
||||
public class OrderedMessageChannelDecoratorTests {
|
||||
|
||||
private static final Log logger = LogFactory.getLog(OrderedMessageSenderTests.class);
|
||||
private static final Log logger = LogFactory.getLog(OrderedMessageChannelDecoratorTests.class);
|
||||
|
||||
|
||||
private OrderedMessageSender sender;
|
||||
private OrderedMessageChannelDecorator sender;
|
||||
|
||||
ExecutorSubscribableChannel channel = new ExecutorSubscribableChannel(this.executor);
|
||||
|
||||
|
|
@ -59,9 +60,9 @@ public class OrderedMessageSenderTests {
|
|||
this.executor.afterPropertiesSet();
|
||||
|
||||
this.channel = new ExecutorSubscribableChannel(this.executor);
|
||||
OrderedMessageSender.configureOutboundChannel(this.channel, true);
|
||||
OrderedMessageChannelDecorator.configureInterceptor(this.channel, true);
|
||||
|
||||
this.sender = new OrderedMessageSender(this.channel, logger);
|
||||
this.sender = new OrderedMessageChannelDecorator(this.channel, logger);
|
||||
|
||||
}
|
||||
|
||||
|
|
@ -89,9 +90,10 @@ public class OrderedMessageSenderTests {
|
|||
latch.countDown();
|
||||
return;
|
||||
}
|
||||
if (actual == 100 || actual == 200) {
|
||||
// Force messages to queue up periodically
|
||||
if (actual % 101 == 0) {
|
||||
try {
|
||||
Thread.sleep(200);
|
||||
Thread.sleep(50);
|
||||
}
|
||||
catch (InterruptedException ex) {
|
||||
result.set(ex.toString());
|
||||
|
|
@ -22,10 +22,12 @@ import java.util.concurrent.LinkedBlockingQueue;
|
|||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
import java.util.concurrent.locks.Lock;
|
||||
import java.util.concurrent.locks.ReentrantLock;
|
||||
import java.util.function.Consumer;
|
||||
|
||||
import org.apache.commons.logging.Log;
|
||||
import org.apache.commons.logging.LogFactory;
|
||||
|
||||
import org.springframework.lang.Nullable;
|
||||
import org.springframework.web.socket.CloseStatus;
|
||||
import org.springframework.web.socket.WebSocketMessage;
|
||||
import org.springframework.web.socket.WebSocketSession;
|
||||
|
|
@ -54,6 +56,10 @@ public class ConcurrentWebSocketSessionDecorator extends WebSocketSessionDecorat
|
|||
|
||||
private final OverflowStrategy overflowStrategy;
|
||||
|
||||
@Nullable
|
||||
private Consumer<WebSocketMessage<?>> preSendCallback;
|
||||
|
||||
|
||||
private final Queue<WebSocketMessage<?>> buffer = new LinkedBlockingQueue<>();
|
||||
|
||||
private final AtomicInteger bufferSize = new AtomicInteger();
|
||||
|
|
@ -130,6 +136,15 @@ public class ConcurrentWebSocketSessionDecorator extends WebSocketSessionDecorat
|
|||
return (start > 0 ? (System.currentTimeMillis() - start) : 0);
|
||||
}
|
||||
|
||||
/**
|
||||
* Set a callback invoked after a message is added to the send buffer.
|
||||
* @param callback the callback to invoke
|
||||
* @since 5.3
|
||||
*/
|
||||
public void setMessageCallback(Consumer<WebSocketMessage<?>> callback) {
|
||||
this.preSendCallback = callback;
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public void sendMessage(WebSocketMessage<?> message) throws IOException {
|
||||
|
|
@ -140,6 +155,10 @@ public class ConcurrentWebSocketSessionDecorator extends WebSocketSessionDecorat
|
|||
this.buffer.add(message);
|
||||
this.bufferSize.addAndGet(message.getPayloadLength());
|
||||
|
||||
if (this.preSendCallback != null) {
|
||||
this.preSendCallback.accept(message);
|
||||
}
|
||||
|
||||
do {
|
||||
if (!tryFlushMessageBuffer()) {
|
||||
if (logger.isTraceEnabled()) {
|
||||
|
|
|
|||
|
|
@ -39,6 +39,7 @@ import org.springframework.messaging.simp.SimpAttributes;
|
|||
import org.springframework.messaging.simp.SimpAttributesContextHolder;
|
||||
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
|
||||
import org.springframework.messaging.simp.SimpMessageType;
|
||||
import org.springframework.messaging.simp.broker.OrderedMessageChannelDecorator;
|
||||
import org.springframework.messaging.simp.stomp.BufferingStompDecoder;
|
||||
import org.springframework.messaging.simp.stomp.StompCommand;
|
||||
import org.springframework.messaging.simp.stomp.StompDecoder;
|
||||
|
|
@ -57,6 +58,7 @@ import org.springframework.web.socket.CloseStatus;
|
|||
import org.springframework.web.socket.TextMessage;
|
||||
import org.springframework.web.socket.WebSocketMessage;
|
||||
import org.springframework.web.socket.WebSocketSession;
|
||||
import org.springframework.web.socket.handler.ConcurrentWebSocketSessionDecorator;
|
||||
import org.springframework.web.socket.handler.SessionLimitExceededException;
|
||||
import org.springframework.web.socket.handler.WebSocketSessionDecorator;
|
||||
import org.springframework.web.socket.sockjs.transport.SockJsSession;
|
||||
|
|
@ -461,6 +463,13 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
|
|||
payload = errorMessage.getPayload();
|
||||
}
|
||||
}
|
||||
|
||||
Runnable task = OrderedMessageChannelDecorator.getNextMessageTask(message);
|
||||
if (task != null) {
|
||||
Assert.isInstanceOf(ConcurrentWebSocketSessionDecorator.class, session);
|
||||
((ConcurrentWebSocketSessionDecorator) session).setMessageCallback(m -> task.run());
|
||||
}
|
||||
|
||||
sendToClient(session, accessor, payload);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,61 @@
|
|||
/*
|
||||
* Copyright 2002-2020 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.springframework.web.socket.handler;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.concurrent.CountDownLatch;
|
||||
import java.util.concurrent.atomic.AtomicReference;
|
||||
|
||||
import org.springframework.web.socket.WebSocketMessage;
|
||||
|
||||
/**
|
||||
* Blocks indefinitely on sending a message but provides a latch to notify when
|
||||
* the message has been "sent" (i.e. session is blocked).
|
||||
*
|
||||
* @author Rossen Stoyanchev
|
||||
*/
|
||||
public class BlockingWebSocketSession extends TestWebSocketSession {
|
||||
|
||||
private final AtomicReference<CountDownLatch> sendLatch = new AtomicReference<>();
|
||||
|
||||
private final AtomicReference<CountDownLatch> releaseLatch = new AtomicReference<>();
|
||||
|
||||
|
||||
public CountDownLatch initSendLatch() {
|
||||
this.sendLatch.set(new CountDownLatch(1));
|
||||
return this.sendLatch.get();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void sendMessage(WebSocketMessage<?> message) throws IOException {
|
||||
super.sendMessage(message);
|
||||
if (this.sendLatch.get() != null) {
|
||||
this.sendLatch.get().countDown();
|
||||
}
|
||||
block();
|
||||
}
|
||||
|
||||
private void block() {
|
||||
try {
|
||||
this.releaseLatch.set(new CountDownLatch(1));
|
||||
this.releaseLatch.get().await();
|
||||
}
|
||||
catch (InterruptedException ex) {
|
||||
ex.printStackTrace();
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright 2002-2019 the original author or authors.
|
||||
* Copyright 2002-2020 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
@ -20,13 +20,11 @@ import java.io.IOException;
|
|||
import java.util.concurrent.CountDownLatch;
|
||||
import java.util.concurrent.Executors;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.concurrent.atomic.AtomicReference;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import org.springframework.web.socket.CloseStatus;
|
||||
import org.springframework.web.socket.TextMessage;
|
||||
import org.springframework.web.socket.WebSocketMessage;
|
||||
import org.springframework.web.socket.WebSocketSession;
|
||||
import org.springframework.web.socket.handler.ConcurrentWebSocketSessionDecorator.OverflowStrategy;
|
||||
|
||||
|
|
@ -63,7 +61,7 @@ public class ConcurrentWebSocketSessionDecoratorTests {
|
|||
@Test
|
||||
public void sendAfterBlockedSend() throws IOException, InterruptedException {
|
||||
|
||||
BlockingSession session = new BlockingSession();
|
||||
BlockingWebSocketSession session = new BlockingWebSocketSession();
|
||||
session.setOpen(true);
|
||||
|
||||
ConcurrentWebSocketSessionDecorator decorator =
|
||||
|
|
@ -85,9 +83,9 @@ public class ConcurrentWebSocketSessionDecoratorTests {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void sendTimeLimitExceeded() throws IOException, InterruptedException {
|
||||
public void sendTimeLimitExceeded() throws InterruptedException {
|
||||
|
||||
BlockingSession session = new BlockingSession();
|
||||
BlockingWebSocketSession session = new BlockingWebSocketSession();
|
||||
session.setId("123");
|
||||
session.setOpen(true);
|
||||
|
||||
|
|
@ -109,7 +107,7 @@ public class ConcurrentWebSocketSessionDecoratorTests {
|
|||
@Test
|
||||
public void sendBufferSizeExceeded() throws IOException, InterruptedException {
|
||||
|
||||
BlockingSession session = new BlockingSession();
|
||||
BlockingWebSocketSession session = new BlockingWebSocketSession();
|
||||
session.setId("123");
|
||||
session.setOpen(true);
|
||||
|
||||
|
|
@ -134,7 +132,7 @@ public class ConcurrentWebSocketSessionDecoratorTests {
|
|||
@Test // SPR-17140
|
||||
public void overflowStrategyDrop() throws IOException, InterruptedException {
|
||||
|
||||
BlockingSession session = new BlockingSession();
|
||||
BlockingWebSocketSession session = new BlockingWebSocketSession();
|
||||
session.setId("123");
|
||||
session.setOpen(true);
|
||||
|
||||
|
|
@ -157,7 +155,7 @@ public class ConcurrentWebSocketSessionDecoratorTests {
|
|||
@Test
|
||||
public void closeStatusNormal() throws Exception {
|
||||
|
||||
BlockingSession session = new BlockingSession();
|
||||
BlockingWebSocketSession session = new BlockingWebSocketSession();
|
||||
session.setOpen(true);
|
||||
WebSocketSession decorator = new ConcurrentWebSocketSessionDecorator(session, 10 * 1000, 1024);
|
||||
|
||||
|
|
@ -171,10 +169,10 @@ public class ConcurrentWebSocketSessionDecoratorTests {
|
|||
@Test
|
||||
public void closeStatusChangesToSessionNotReliable() throws Exception {
|
||||
|
||||
BlockingSession session = new BlockingSession();
|
||||
BlockingWebSocketSession session = new BlockingWebSocketSession();
|
||||
session.setId("123");
|
||||
session.setOpen(true);
|
||||
CountDownLatch sentMessageLatch = session.getSentMessageLatch();
|
||||
CountDownLatch sentMessageLatch = session.initSendLatch();
|
||||
|
||||
int sendTimeLimit = 100;
|
||||
int bufferSizeLimit = 1024;
|
||||
|
|
@ -182,7 +180,7 @@ public class ConcurrentWebSocketSessionDecoratorTests {
|
|||
ConcurrentWebSocketSessionDecorator decorator =
|
||||
new ConcurrentWebSocketSessionDecorator(session, sendTimeLimit, bufferSizeLimit);
|
||||
|
||||
Executors.newSingleThreadExecutor().submit((Runnable) () -> {
|
||||
Executors.newSingleThreadExecutor().submit(() -> {
|
||||
TextMessage message = new TextMessage("slow message");
|
||||
try {
|
||||
decorator.sendMessage(message);
|
||||
|
|
@ -199,12 +197,13 @@ public class ConcurrentWebSocketSessionDecoratorTests {
|
|||
|
||||
decorator.close(CloseStatus.PROTOCOL_ERROR);
|
||||
|
||||
assertThat(session.getCloseStatus()).as("CloseStatus should have changed to SESSION_NOT_RELIABLE").isEqualTo(CloseStatus.SESSION_NOT_RELIABLE);
|
||||
assertThat(session.getCloseStatus())
|
||||
.as("CloseStatus should have changed to SESSION_NOT_RELIABLE")
|
||||
.isEqualTo(CloseStatus.SESSION_NOT_RELIABLE);
|
||||
}
|
||||
|
||||
private void sendBlockingMessage(ConcurrentWebSocketSessionDecorator session) throws InterruptedException {
|
||||
BlockingSession delegate = (BlockingSession) session.getDelegate();
|
||||
CountDownLatch sentMessageLatch = delegate.getSentMessageLatch();
|
||||
CountDownLatch latch = ((BlockingWebSocketSession) session.getDelegate()).initSendLatch();
|
||||
Executors.newSingleThreadExecutor().submit(() -> {
|
||||
TextMessage message = new TextMessage("slow message");
|
||||
try {
|
||||
|
|
@ -214,42 +213,7 @@ public class ConcurrentWebSocketSessionDecoratorTests {
|
|||
e.printStackTrace();
|
||||
}
|
||||
});
|
||||
assertThat(sentMessageLatch.await(5, TimeUnit.SECONDS)).isTrue();
|
||||
}
|
||||
|
||||
|
||||
|
||||
private static class BlockingSession extends TestWebSocketSession {
|
||||
|
||||
private final AtomicReference<CountDownLatch> nextMessageLatch = new AtomicReference<>();
|
||||
|
||||
private final AtomicReference<CountDownLatch> releaseLatch = new AtomicReference<>();
|
||||
|
||||
|
||||
public CountDownLatch getSentMessageLatch() {
|
||||
this.nextMessageLatch.set(new CountDownLatch(1));
|
||||
return this.nextMessageLatch.get();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void sendMessage(WebSocketMessage<?> message) throws IOException {
|
||||
super.sendMessage(message);
|
||||
if (this.nextMessageLatch != null) {
|
||||
this.nextMessageLatch.get().countDown();
|
||||
}
|
||||
block();
|
||||
}
|
||||
|
||||
private void block() {
|
||||
try {
|
||||
this.releaseLatch.set(new CountDownLatch(1));
|
||||
this.releaseLatch.get().await();
|
||||
}
|
||||
catch (InterruptedException e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
}
|
||||
|
||||
assertThat(latch.await(5, TimeUnit.SECONDS)).isTrue();
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,255 @@
|
|||
/*
|
||||
* Copyright 2002-2020 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.springframework.web.socket.messaging;
|
||||
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.List;
|
||||
import java.util.Queue;
|
||||
import java.util.concurrent.CountDownLatch;
|
||||
import java.util.concurrent.LinkedBlockingQueue;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.concurrent.atomic.AtomicReference;
|
||||
|
||||
import org.apache.commons.logging.Log;
|
||||
import org.apache.commons.logging.LogFactory;
|
||||
import org.junit.jupiter.api.AfterEach;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import org.springframework.lang.Nullable;
|
||||
import org.springframework.messaging.Message;
|
||||
import org.springframework.messaging.MessageHandler;
|
||||
import org.springframework.messaging.MessagingException;
|
||||
import org.springframework.messaging.simp.broker.OrderedMessageChannelDecorator;
|
||||
import org.springframework.messaging.simp.stomp.StompCommand;
|
||||
import org.springframework.messaging.simp.stomp.StompEncoder;
|
||||
import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
|
||||
import org.springframework.messaging.support.ExecutorSubscribableChannel;
|
||||
import org.springframework.messaging.support.MessageBuilder;
|
||||
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
|
||||
import org.springframework.web.socket.WebSocketSession;
|
||||
import org.springframework.web.socket.handler.BlockingWebSocketSession;
|
||||
import org.springframework.web.socket.handler.ConcurrentWebSocketSessionDecorator;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
/**
|
||||
* Tests to publish messages to an Executor backed channel wrapped with
|
||||
* {@link OrderedMessageChannelDecorator} and handled by
|
||||
* {@link StompSubProtocolHandler} delegating to a
|
||||
* {@link ConcurrentWebSocketSessionDecorator} wrapped session.
|
||||
*
|
||||
* <p>The tests verify that:
|
||||
* <ul>
|
||||
* <li>messages are executed in the same order as they are published.
|
||||
* <li>send buffer size and send time limits at the
|
||||
* {@link ConcurrentWebSocketSessionDecorator} level are enforced.
|
||||
* </ul>
|
||||
*
|
||||
* <p>The key is for {@link OrderedMessageChannelDecorator} to release the next
|
||||
* message when after the current one is queued for sending, and not after it is
|
||||
* sent, which may block and cause messages to accumulate in the
|
||||
* {@link OrderedMessageChannelDecorator} instead of in
|
||||
* {@link ConcurrentWebSocketSessionDecorator} where send limits are enforced.
|
||||
*
|
||||
* @author Rossen Stoyanchev
|
||||
*/
|
||||
public class OrderedMessageSendingIntegrationTests {
|
||||
|
||||
private static final Log logger = LogFactory.getLog(OrderedMessageSendingIntegrationTests.class);
|
||||
|
||||
private static final int MESSAGE_SIZE = new StompEncoder().encode(createMessage(0)).length;
|
||||
|
||||
|
||||
private BlockingWebSocketSession blockingSession;
|
||||
|
||||
private ExecutorSubscribableChannel subscribableChannel;
|
||||
|
||||
private OrderedMessageChannelDecorator orderedMessageChannel;
|
||||
|
||||
private ThreadPoolTaskExecutor executor;
|
||||
|
||||
|
||||
|
||||
@BeforeEach
|
||||
public void setup() {
|
||||
this.blockingSession = new BlockingWebSocketSession();
|
||||
this.blockingSession.setId("1");
|
||||
this.blockingSession.setOpen(true);
|
||||
|
||||
this.executor = new ThreadPoolTaskExecutor();
|
||||
this.executor.setCorePoolSize(Runtime.getRuntime().availableProcessors() * 2);
|
||||
this.executor.setAllowCoreThreadTimeOut(true);
|
||||
this.executor.afterPropertiesSet();
|
||||
|
||||
this.subscribableChannel = new ExecutorSubscribableChannel(this.executor);
|
||||
OrderedMessageChannelDecorator.configureInterceptor(this.subscribableChannel, true);
|
||||
|
||||
this.orderedMessageChannel = new OrderedMessageChannelDecorator(this.subscribableChannel, logger);
|
||||
}
|
||||
|
||||
@AfterEach
|
||||
public void tearDown() {
|
||||
this.executor.shutdown();
|
||||
}
|
||||
|
||||
@Test
|
||||
void sendAfterBlockedSend() throws InterruptedException {
|
||||
|
||||
int messageCount = 1000;
|
||||
|
||||
ConcurrentWebSocketSessionDecorator concurrentSessionDecorator =
|
||||
new ConcurrentWebSocketSessionDecorator(
|
||||
this.blockingSession, 60 * 1000, messageCount * MESSAGE_SIZE);
|
||||
|
||||
TestMessageHandler handler = new TestMessageHandler(concurrentSessionDecorator);
|
||||
subscribableChannel.subscribe(handler);
|
||||
|
||||
List<Message<?>> expectedMessages = new ArrayList<>(messageCount);
|
||||
|
||||
// Send one to block
|
||||
Message<byte[]> message = createMessage(0);
|
||||
expectedMessages.add(message);
|
||||
this.orderedMessageChannel.send(message);
|
||||
|
||||
CountDownLatch latch = new CountDownLatch(messageCount);
|
||||
handler.setMessageLatch(latch);
|
||||
|
||||
for (int i = 1; i <= messageCount; i++) {
|
||||
message = createMessage(i);
|
||||
expectedMessages.add(message);
|
||||
this.orderedMessageChannel.send(message);
|
||||
}
|
||||
|
||||
latch.await(5, TimeUnit.SECONDS);
|
||||
|
||||
assertThat(concurrentSessionDecorator.getTimeSinceSendStarted() > 0).isTrue();
|
||||
assertThat(concurrentSessionDecorator.getBufferSize()).isEqualTo((messageCount * MESSAGE_SIZE));
|
||||
assertThat(handler.getSavedMessages()).containsExactlyElementsOf(expectedMessages);
|
||||
assertThat(blockingSession.isOpen()).isTrue();
|
||||
}
|
||||
|
||||
@Test
|
||||
void exceedTimeLimit() throws InterruptedException {
|
||||
|
||||
ConcurrentWebSocketSessionDecorator concurrentSessionDecorator =
|
||||
new ConcurrentWebSocketSessionDecorator(this.blockingSession, 100, 1024);
|
||||
|
||||
TestMessageHandler messageHandler = new TestMessageHandler(concurrentSessionDecorator);
|
||||
subscribableChannel.subscribe(messageHandler);
|
||||
|
||||
// Send one to block
|
||||
this.orderedMessageChannel.send(createMessage(0));
|
||||
|
||||
// Exceed send time..
|
||||
Thread.sleep(200);
|
||||
|
||||
CountDownLatch messageLatch = new CountDownLatch(1);
|
||||
messageHandler.setMessageLatch(messageLatch);
|
||||
|
||||
// Send one more
|
||||
this.orderedMessageChannel.send(createMessage(1));
|
||||
|
||||
messageLatch.await(5, TimeUnit.SECONDS);
|
||||
|
||||
assertThat(messageHandler.getSavedException()).hasMessageMatching(
|
||||
"Send time [\\d]+ \\(ms\\) for session '1' exceeded the allowed limit 100");
|
||||
}
|
||||
|
||||
@Test
|
||||
void exceedBufferSizeLimit() throws InterruptedException {
|
||||
|
||||
ConcurrentWebSocketSessionDecorator concurrentSessionDecorator =
|
||||
new ConcurrentWebSocketSessionDecorator(this.blockingSession, 60 * 1000, 2 * MESSAGE_SIZE);
|
||||
|
||||
TestMessageHandler messageHandler = new TestMessageHandler(concurrentSessionDecorator);
|
||||
subscribableChannel.subscribe(messageHandler);
|
||||
|
||||
// Send one to block
|
||||
this.orderedMessageChannel.send(createMessage(0));
|
||||
|
||||
int messageCount = 3;
|
||||
CountDownLatch messageLatch = new CountDownLatch(messageCount);
|
||||
messageHandler.setMessageLatch(messageLatch);
|
||||
|
||||
for (int i = 1; i <= messageCount; i++) {
|
||||
this.orderedMessageChannel.send(createMessage(i));
|
||||
}
|
||||
|
||||
messageLatch.await(5, TimeUnit.SECONDS);
|
||||
|
||||
assertThat(messageHandler.getSavedException()).hasMessage(
|
||||
"Buffer size " + 3 * MESSAGE_SIZE + " bytes for session '1' exceeds the allowed limit " + 2 * MESSAGE_SIZE);
|
||||
}
|
||||
|
||||
private static Message<byte[]> createMessage(int index) {
|
||||
StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.MESSAGE);
|
||||
accessor.setHeader("index", index);
|
||||
accessor.setSubscriptionId("1");
|
||||
accessor.setLeaveMutable(true);
|
||||
byte[] bytes = "payload".getBytes(StandardCharsets.UTF_8);
|
||||
return MessageBuilder.createMessage(bytes, accessor.getMessageHeaders());
|
||||
|
||||
}
|
||||
|
||||
|
||||
private static class TestMessageHandler implements MessageHandler {
|
||||
|
||||
private final StompSubProtocolHandler subProtocolHandler = new StompSubProtocolHandler();
|
||||
|
||||
private final WebSocketSession session;
|
||||
|
||||
@Nullable
|
||||
private CountDownLatch messageLatch;
|
||||
|
||||
private Queue<Message<?>> messages = new LinkedBlockingQueue<>();
|
||||
|
||||
private AtomicReference<Exception> exception = new AtomicReference<>();
|
||||
|
||||
|
||||
public TestMessageHandler(WebSocketSession session) {
|
||||
this.session = session;
|
||||
}
|
||||
|
||||
public void setMessageLatch(CountDownLatch latch) {
|
||||
this.messageLatch = latch;
|
||||
}
|
||||
|
||||
public Collection<Message<?>> getSavedMessages() {
|
||||
return this.messages;
|
||||
}
|
||||
|
||||
public Exception getSavedException() {
|
||||
return this.exception.get();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void handleMessage(Message<?> message) throws MessagingException {
|
||||
this.messages.add(message);
|
||||
try {
|
||||
this.subProtocolHandler.handleMessageToClient(this.session, message);
|
||||
}
|
||||
catch (Exception ex) {
|
||||
this.exception.set(ex);
|
||||
}
|
||||
if (this.messageLatch != null) {
|
||||
this.messageLatch.countDown();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue