OrderedMessageSender throughput improvement

Before this change messages were sent serially across sessions but
ordering is important only within a session. This leads to head of
line blocking when a session is slow to send, and also enforcement of
send buffer size and time limits is precluded because it happens at
a lower level in the transport.

This change ensures messages are held up only if there is another
from the same session is being sent. This allows messages from each
session to flow independent of other.

See gh-25581
This commit is contained in:
Rossen Stoyanchev 2020-08-27 10:50:05 +01:00
parent 568b44eb9d
commit f5c287a6e6
2 changed files with 204 additions and 61 deletions

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2018 the original author or authors. * Copyright 2002-2020 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -16,8 +16,13 @@
package org.springframework.messaging.simp.broker; package org.springframework.messaging.simp.broker;
import java.util.Collection;
import java.util.HashSet;
import java.util.Queue; import java.util.Queue;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
@ -33,8 +38,14 @@ import org.springframework.messaging.support.MessageHeaderAccessor;
import org.springframework.util.Assert; import org.springframework.util.Assert;
/** /**
* Submit messages to an {@link ExecutorSubscribableChannel}, one at a time. * {@code MessageChannel} decorator that ensures messages from the same session
* The channel must have been configured with {@link #configureOutboundChannel}. * are sent and processed in the same order. This would not normally be the case
* with an {@code Executor} backed {@code MessageChannel} since the executor
* is free to submit tasks in any order.
*
* <p>To provide ordering, inbound messages are placed in a queue and sent one
* one at a time per session. Once a message is processed, a callback is used to
* notify that the next message from the same session can be sent through.
* *
* @author Rossen Stoyanchev * @author Rossen Stoyanchev
* @since 5.1 * @since 5.1
@ -48,9 +59,7 @@ class OrderedMessageSender implements MessageChannel {
private final Log logger; private final Log logger;
private final Queue<Message<?>> messages = new ConcurrentLinkedQueue<>(); private final Control control = new Control();
private final AtomicBoolean sendInProgress = new AtomicBoolean(false);
public OrderedMessageSender(MessageChannel channel, Log logger) { public OrderedMessageSender(MessageChannel channel, Log logger) {
@ -66,30 +75,40 @@ class OrderedMessageSender implements MessageChannel {
@Override @Override
public boolean send(Message<?> message, long timeout) { public boolean send(Message<?> message, long timeout) {
this.messages.add(message); this.control.addMessage(message);
trySend(); trySend();
return true; return true;
} }
private void trySend() { private void trySend() {
// Take sendInProgress flag only if queue is not empty if (this.control.acquireSendLock()) {
if (this.messages.isEmpty()) { sendMessages();
return;
}
if (this.sendInProgress.compareAndSet(false, true)) {
sendNextMessage();
} }
} }
private void sendNextMessage() { private void sendMessages() {
for (;;) { for ( ; ; ) {
Message<?> message = this.messages.poll(); Set<String> skipSet = new HashSet<>();
if (message != null) { for (Message<?> message : this.control.getMessagesToSend()) {
String sessionId = SimpMessageHeaderAccessor.getSessionId(message.getHeaders());
Assert.notNull(sessionId, () -> "No session id in " + message.getHeaders());
if (skipSet.contains(sessionId)) {
continue;
}
if (!this.control.acquireSessionLock(sessionId)) {
skipSet.add(sessionId);
continue;
}
this.control.removeMessage(message);
try { try {
addCompletionCallback(message); getMutableAccessor(message).setHeader(COMPLETION_TASK_HEADER, (Runnable) () -> {
this.control.releaseSessionLock(sessionId);
if (this.control.hasRemainingWork()) {
trySend();
}
});
if (this.channel.send(message)) { if (this.channel.send(message)) {
return; continue;
} }
} }
catch (Throwable ex) { catch (Throwable ex) {
@ -97,20 +116,24 @@ class OrderedMessageSender implements MessageChannel {
logger.error("Failed to send " + message, ex); logger.error("Failed to send " + message, ex);
} }
} }
// We didn't send
this.control.releaseSessionLock(sessionId);
} }
else {
// We ran out of messages.. if (this.control.shouldYield()) {
this.sendInProgress.set(false); this.control.releaseSendLock();
trySend(); if (!this.control.shouldYield()) {
break; trySend();
}
return;
} }
} }
} }
private void addCompletionCallback(Message<?> msg) { private SimpMessageHeaderAccessor getMutableAccessor(Message<?> message) {
SimpMessageHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(msg, SimpMessageHeaderAccessor.class); SimpMessageHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, SimpMessageHeaderAccessor.class);
Assert.isTrue(accessor != null && accessor.isMutable(), "Expected mutable SimpMessageHeaderAccessor"); Assert.isTrue(accessor != null && accessor.isMutable(), "Expected mutable SimpMessageHeaderAccessor");
accessor.setHeader(COMPLETION_TASK_HEADER, (Runnable) this::sendNextMessage); return accessor;
} }
@ -126,13 +149,13 @@ class OrderedMessageSender implements MessageChannel {
Assert.isInstanceOf(ExecutorSubscribableChannel.class, channel, Assert.isInstanceOf(ExecutorSubscribableChannel.class, channel,
"An ExecutorSubscribableChannel is required for `preservePublishOrder`"); "An ExecutorSubscribableChannel is required for `preservePublishOrder`");
ExecutorSubscribableChannel execChannel = (ExecutorSubscribableChannel) channel; ExecutorSubscribableChannel execChannel = (ExecutorSubscribableChannel) channel;
if (execChannel.getInterceptors().stream().noneMatch(i -> i instanceof CallbackInterceptor)) { if (execChannel.getInterceptors().stream().noneMatch(i -> i instanceof CompletionTaskInterceptor)) {
execChannel.addInterceptor(0, new CallbackInterceptor()); execChannel.addInterceptor(0, new CompletionTaskInterceptor());
} }
} }
else if (channel instanceof ExecutorSubscribableChannel) { else if (channel instanceof ExecutorSubscribableChannel) {
ExecutorSubscribableChannel execChannel = (ExecutorSubscribableChannel) channel; ExecutorSubscribableChannel execChannel = (ExecutorSubscribableChannel) channel;
execChannel.getInterceptors().stream().filter(i -> i instanceof CallbackInterceptor) execChannel.getInterceptors().stream().filter(i -> i instanceof CompletionTaskInterceptor)
.findFirst() .findFirst()
.map(execChannel::removeInterceptor); .map(execChannel::removeInterceptor);
@ -140,13 +163,71 @@ class OrderedMessageSender implements MessageChannel {
} }
private static class CallbackInterceptor implements ExecutorChannelInterceptor { /**
* Provides locks required for ordered message sending and execution within
* a session as well as storage for messages waiting to be sent.
*/
private static class Control {
private final Queue<Message<?>> messages = new ConcurrentLinkedQueue<>();
private final ConcurrentMap<String, Boolean> sessionsInProgress = new ConcurrentHashMap<>();
private final AtomicBoolean workInProgress = new AtomicBoolean(false);
public void addMessage(Message<?> message) {
this.messages.add(message);
}
public void removeMessage(Message<?> message) {
if (!this.messages.remove(message)) {
throw new IllegalStateException(
"Message " + message.getHeaders() + " was expected in the queue.");
}
}
public Collection<Message<?>> getMessagesToSend() {
return this.messages;
}
public boolean acquireSendLock() {
return this.workInProgress.compareAndSet(false, true);
}
public void releaseSendLock() {
this.workInProgress.set(false);
}
public boolean acquireSessionLock(String sessionId) {
if (this.sessionsInProgress.put(sessionId, Boolean.TRUE) != null) {
return false;
}
return true;
}
public void releaseSessionLock(String sessionId) {
this.sessionsInProgress.remove(sessionId);
}
public boolean hasRemainingWork() {
return !this.messages.isEmpty();
}
public boolean shouldYield() {
// No remaining work, or others can pick it up
return (!hasRemainingWork() || this.sessionsInProgress.size() > 0);
}
}
private static class CompletionTaskInterceptor implements ExecutorChannelInterceptor {
@Override @Override
public void afterMessageHandled( public void afterMessageHandled(
Message<?> msg, MessageChannel ch, MessageHandler handler, @Nullable Exception ex) { Message<?> message, MessageChannel ch, MessageHandler handler, @Nullable Exception ex) {
Runnable task = (Runnable) msg.getHeaders().get(OrderedMessageSender.COMPLETION_TASK_HEADER); Runnable task = (Runnable) message.getHeaders().get(OrderedMessageSender.COMPLETION_TASK_HEADER);
if (task != null) { if (task != null) {
task.run(); task.run();
} }

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2019 the original author or authors. * Copyright 2002-2020 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -16,6 +16,10 @@
package org.springframework.messaging.simp.broker; package org.springframework.messaging.simp.broker;
import java.time.Duration;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch; import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
@ -26,7 +30,12 @@ import org.apache.commons.logging.LogFactory;
import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.reactivestreams.Publisher;
import reactor.core.publisher.Flux;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageHandler;
import org.springframework.messaging.MessagingException;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor; import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.simp.SimpMessageType; import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.messaging.support.ExecutorSubscribableChannel; import org.springframework.messaging.support.ExecutorSubscribableChannel;
@ -43,6 +52,8 @@ public class OrderedMessageSenderTests {
private static final Log logger = LogFactory.getLog(OrderedMessageSenderTests.class); private static final Log logger = LogFactory.getLog(OrderedMessageSenderTests.class);
private static final Random random = new Random();
private OrderedMessageSender sender; private OrderedMessageSender sender;
@ -74,46 +85,97 @@ public class OrderedMessageSenderTests {
@Test @Test
public void test() throws InterruptedException { public void test() throws InterruptedException {
int start = 1; int sessionCount = 25;
int end = 1000; int messagesPerSessionCount = 500;
AtomicInteger index = new AtomicInteger(start); TestMessageHandler handler = new TestMessageHandler(sessionCount * messagesPerSessionCount);
AtomicReference<Object> result = new AtomicReference<>(); this.channel.subscribe(handler);
CountDownLatch latch = new CountDownLatch(1);
this.channel.subscribe(message -> { Publisher<Flux<Message<String>>> messageFluxes =
int expected = index.getAndIncrement(); Flux.range(1, sessionCount).map(sessionId ->
Integer actual = (Integer) message.getHeaders().getOrDefault("seq", -1); Flux.range(1, messagesPerSessionCount)
if (actual != expected) { .map(sequence -> createMessage(sessionId, sequence))
result.set("Expected: " + expected + ", but was: " + actual); .delayElements(Duration.ofMillis(Math.abs(random.nextLong()) % 5)));
Flux.merge(messageFluxes)
.doOnNext(message -> this.sender.send(message))
.blockLast();
handler.await(20, TimeUnit.SECONDS);
assertThat(handler.getDescription()).isEqualTo("Total processed: " + sessionCount * messagesPerSessionCount);
assertThat(handler.getSequenceBySession()).hasSize(sessionCount);
handler.getSequenceBySession().forEach((key, value) ->
assertThat(value.get()).as(key).isEqualTo(messagesPerSessionCount));
}
private static Message<String> createMessage(Integer sessionId, Integer sequence) {
SimpMessageHeaderAccessor accessor = SimpMessageHeaderAccessor.create(SimpMessageType.MESSAGE);
accessor.setSessionId("session" + sessionId);
accessor.setHeader("seq", sequence);
accessor.setLeaveMutable(true);
return MessageBuilder.createMessage("payload", accessor.getMessageHeaders());
}
private static class TestMessageHandler implements MessageHandler {
private final int totalExpected;
private final Map<String, AtomicInteger> sequenceBySession = new ConcurrentHashMap<>();
private final AtomicReference<String> description = new AtomicReference<>();
private final AtomicInteger totalReceived = new AtomicInteger();
private final CountDownLatch latch = new CountDownLatch(1);
TestMessageHandler(int totalExpected) {
this.totalExpected = totalExpected;
}
public void await(long timeout, TimeUnit timeUnit) throws InterruptedException {
latch.await(timeout, timeUnit);
}
public Map<String, AtomicInteger> getSequenceBySession() {
return sequenceBySession;
}
public String getDescription() {
return description.get();
}
@Override
public void handleMessage(Message<?> message) throws MessagingException {
String id = SimpMessageHeaderAccessor.getSessionId(message.getHeaders());
Integer seq = (Integer) message.getHeaders().getOrDefault("seq", -1);
AtomicInteger prev = sequenceBySession.computeIfAbsent(id, i -> new AtomicInteger(0));
if (!prev.compareAndSet(seq - 1, seq)) {
description.set("Out of order, session=" + id + ", prev=" + prev + ", next=" + seq);
latch.countDown(); latch.countDown();
return; return;
} }
if (actual == 100 || actual == 200) {
if (seq == 100) {
try { try {
Thread.sleep(200); // Processing delay to cause other session messages to queue up
Thread.sleep(50);
} }
catch (InterruptedException ex) { catch (InterruptedException ex) {
result.set(ex.toString()); description.set(ex.toString());
latch.countDown(); latch.countDown();
return;
} }
} }
if (actual == end) {
result.set("Done"); int total = totalReceived.incrementAndGet();
description.set("Total processed: " + total);
if (total == totalExpected) {
latch.countDown(); latch.countDown();
} }
});
for (int i = start; i <= end; i++) {
SimpMessageHeaderAccessor accessor = SimpMessageHeaderAccessor.create(SimpMessageType.MESSAGE);
accessor.setHeader("seq", i);
accessor.setLeaveMutable(true);
this.sender.send(MessageBuilder.createMessage("payload", accessor.getMessageHeaders()));
} }
latch.await(10, TimeUnit.SECONDS);
assertThat(result.get()).isEqualTo("Done");
} }
} }