diff --git a/clients/src/main/java/org/apache/kafka/clients/producer/internals/RecordAccumulator.java b/clients/src/main/java/org/apache/kafka/clients/producer/internals/RecordAccumulator.java
index 017d99f0d0f..063d117502f 100644
--- a/clients/src/main/java/org/apache/kafka/clients/producer/internals/RecordAccumulator.java
+++ b/clients/src/main/java/org/apache/kafka/clients/producer/internals/RecordAccumulator.java
@@ -81,7 +81,7 @@ public class RecordAccumulator {
     private final IncompleteBatches incomplete;
     // The following variables are only accessed by the sender thread, so we don't need to protect them.
     private final Set<TopicPartition> muted;
-    private int drainIndex;
+    private final Map<String, Integer> nodesDrainIndex;
     private final TransactionManager transactionManager;
     private long nextBatchExpiryTimeMs = Long.MAX_VALUE; // the earliest time (absolute) a batch will expire.
 
@@ -115,7 +115,6 @@ public class RecordAccumulator {
                              TransactionManager transactionManager,
                              BufferPool bufferPool) {
         this.log = logContext.logger(RecordAccumulator.class);
-        this.drainIndex = 0;
         this.closed = false;
         this.flushesInProgress = new AtomicInteger(0);
         this.appendsInProgress = new AtomicInteger(0);
@@ -130,6 +129,7 @@ public class RecordAccumulator {
         this.muted = new HashSet<>();
         this.time = time;
         this.apiVersions = apiVersions;
+        nodesDrainIndex = new HashMap<>();
        this.transactionManager = transactionManager;
         registerMetrics(metrics, metricGrpName);
     }
@@ -559,13 +559,14 @@ public class RecordAccumulator {
         int size = 0;
         List<PartitionInfo> parts = cluster.partitionsForNode(node.id());
         List<ProducerBatch> ready = new ArrayList<>();
-        /* to make starvation less likely this loop doesn't start at 0 */
+        /* to make starvation less likely, each node has its own drainIndex */
+        int drainIndex = getDrainIndex(node.idString());
         int start = drainIndex = drainIndex % parts.size();
         do {
             PartitionInfo part = parts.get(drainIndex);
             TopicPartition tp = new TopicPartition(part.topic(), part.partition());
-            this.drainIndex = (this.drainIndex + 1) % parts.size();
-
+            updateDrainIndex(node.idString(), drainIndex);
+            drainIndex = (drainIndex + 1) % parts.size();
             // Only proceed if the partition has no in-flight batches.
             if (isMuted(tp))
                 continue;
@@ -638,6 +639,14 @@ public class RecordAccumulator {
         return ready;
     }
 
+    private int getDrainIndex(String idString) {
+        return nodesDrainIndex.computeIfAbsent(idString, s -> 0);
+    }
+
+    private void updateDrainIndex(String idString, int drainIndex) {
+        nodesDrainIndex.put(idString, drainIndex);
+    }
+
     /**
      * Drain all the data for the given nodes and collate them into a list of batches that will fit within the specified
      * size on a per-node basis. This method attempts to avoid choosing the same topic-node over and over.
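
The core of the patch: the accumulator previously kept a single `drainIndex` shared by every node, so draining one node advanced the round-robin position of all the others. It now keeps one index per node in `nodesDrainIndex`. Below is a minimal standalone sketch of that bookkeeping (illustrative names, not Kafka source), assuming a drain that can visit at most `budget` partitions per request:

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Illustrative sketch, not Kafka source: per-node round-robin bookkeeping
// in the style of getDrainIndex/updateDrainIndex in the patch above.
class PerNodeDrainCursor {
    // node id -> partition index where that node's last drain stopped
    private final Map<String, Integer> nodesDrainIndex = new HashMap<>();

    // Visit up to 'budget' partitions of one node, resuming where the
    // previous call for the same node left off.
    List<Integer> drain(String nodeId, int partitionCount, int budget) {
        List<Integer> visited = new ArrayList<>();
        int drainIndex = nodesDrainIndex.computeIfAbsent(nodeId, s -> 0) % partitionCount;
        int start = drainIndex;
        do {
            // Store the cursor before deciding whether the request is full:
            // a partition that did not fit is retried first on the next drain.
            nodesDrainIndex.put(nodeId, drainIndex);
            if (visited.size() >= budget)
                break;
            visited.add(drainIndex);
            drainIndex = (drainIndex + 1) % partitionCount;
        } while (start != drainIndex);
        return visited;
    }
}
```

Note that the cursor is stored before the full-request check, mirroring how `updateDrainIndex` runs before the size check in `drainBatchesForOneNode`: when a request fills up, the partition that did not fit is recorded and becomes the starting point of the next drain.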
diff --git a/clients/src/test/java/org/apache/kafka/clients/producer/internals/RecordAccumulatorTest.java b/clients/src/test/java/org/apache/kafka/clients/producer/internals/RecordAccumulatorTest.java
index 06ed1ce1f12..7c3518a1367 100644
--- a/clients/src/test/java/org/apache/kafka/clients/producer/internals/RecordAccumulatorTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/producer/internals/RecordAccumulatorTest.java
@@ -52,6 +52,7 @@ import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.Deque;
+import java.util.HashSet;
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
@@ -60,7 +61,6 @@ import java.util.Set;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.Future;
 import java.util.concurrent.atomic.AtomicInteger;
-
 import static java.util.Arrays.asList;
 import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertFalse;
@@ -98,6 +98,66 @@ public class RecordAccumulatorTest {
         this.metrics.close();
     }
 
+    @Test
+    public void testDrainBatches() throws Exception {
+        // test case: node1(tp1, tp2), node2(tp3, tp4)
+        // add tp4
+        int partition4 = 3;
+        TopicPartition tp4 = new TopicPartition(topic, partition4);
+        PartitionInfo part4 = new PartitionInfo(topic, partition4, node2, null, null);
+
+        long batchSize = value.length + DefaultRecordBatch.RECORD_BATCH_OVERHEAD;
+        RecordAccumulator accum = createTestRecordAccumulator((int) batchSize, 1024, CompressionType.NONE, 10);
+        Cluster cluster = new Cluster(null, Arrays.asList(node1, node2), Arrays.asList(part1, part2, part3, part4),
+            Collections.emptySet(), Collections.emptySet());
+
+        // initial data
+        accum.append(tp1, 0L, key, value, Record.EMPTY_HEADERS, null, maxBlockTimeMs, false, time.milliseconds());
+        accum.append(tp2, 0L, key, value, Record.EMPTY_HEADERS, null, maxBlockTimeMs, false, time.milliseconds());
+        accum.append(tp3, 0L, key, value, Record.EMPTY_HEADERS, null, maxBlockTimeMs, false, time.milliseconds());
+        accum.append(tp4, 0L, key, value, Record.EMPTY_HEADERS, null, maxBlockTimeMs, false, time.milliseconds());
+
+        // drain batches from 2 nodes: node1 => tp1, node2 => tp3, because the max request size is full after the first batch drained
+        Map<Integer, List<ProducerBatch>> batches1 = accum.drain(cluster, new HashSet<>(Arrays.asList(node1, node2)), (int) batchSize, 0);
+        verifyTopicPartitionInBatches(batches1, tp1, tp3);
+
+        // add record for tp1, tp3
+        accum.append(tp1, 0L, key, value, Record.EMPTY_HEADERS, null, maxBlockTimeMs, false, time.milliseconds());
+        accum.append(tp3, 0L, key, value, Record.EMPTY_HEADERS, null, maxBlockTimeMs, false, time.milliseconds());
+
+        // drain batches from 2 nodes: node1 => tp2, node2 => tp4, because the max request size is full after the first batch drained
+        // The drain index should start from the next topic partition, that is, node1 => tp2, node2 => tp4
+        Map<Integer, List<ProducerBatch>> batches2 = accum.drain(cluster, new HashSet<>(Arrays.asList(node1, node2)), (int) batchSize, 0);
+        verifyTopicPartitionInBatches(batches2, tp2, tp4);
+
+        // make sure that in the next run, the drain index will start from the beginning
+        Map<Integer, List<ProducerBatch>> batches3 = accum.drain(cluster, new HashSet<>(Arrays.asList(node1, node2)), (int) batchSize, 0);
+        verifyTopicPartitionInBatches(batches3, tp1, tp3);
+
+        // add record for tp2, tp3, tp4 and mute tp4
+        accum.append(tp2, 0L, key, value, Record.EMPTY_HEADERS, null, maxBlockTimeMs, false, time.milliseconds());
+        accum.append(tp3, 0L, key, value, Record.EMPTY_HEADERS, null, maxBlockTimeMs, false, time.milliseconds());
+        accum.append(tp4, 0L, key, value, Record.EMPTY_HEADERS, null, maxBlockTimeMs, false, time.milliseconds());
+        accum.mutePartition(tp4);
+        // drain batches from 2 nodes: node1 => tp2, node2 => tp3 (because tp4 is muted)
+        Map<Integer, List<ProducerBatch>> batches4 = accum.drain(cluster, new HashSet<>(Arrays.asList(node1, node2)), (int) batchSize, 0);
+        verifyTopicPartitionInBatches(batches4, tp2, tp3);
+    }
+
+    private void verifyTopicPartitionInBatches(Map<Integer, List<ProducerBatch>> batches, TopicPartition... tp) {
+        assertEquals(tp.length, batches.size());
+        List<TopicPartition> topicPartitionsInBatch = new ArrayList<>();
+        for (Map.Entry<Integer, List<ProducerBatch>> entry : batches.entrySet()) {
+            List<ProducerBatch> batchList = entry.getValue();
+            assertEquals(1, batchList.size());
+            topicPartitionsInBatch.add(batchList.get(0).topicPartition);
+        }
+
+        for (int i = 0; i < tp.length; i++) {
+            assertEquals(tp[i], topicPartitionsInBatch.get(i));
+        }
+    }
+
     @Test
     public void testFull() throws Exception {
         long now = time.milliseconds();
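
The test pins down the per-node behaviour: each node's drain resumes one partition past where the previous request filled up, and muted partitions are skipped without losing the cursor. To see why the old shared index starved partitions, here is a hypothetical before/after simulation (assumed setup, not Kafka code: two nodes with two partitions each, one batch per request):

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical demo, not Kafka source: contrast the old shared drainIndex
// with the patched per-node indexes under a one-batch-per-request budget.
public final class DrainCursorDemo {
    public static void main(String[] args) {
        final int partitionsPerNode = 2;
        final String[] nodes = {"node1", "node2"};

        int sharedIndex = 0;                                  // old scheme: one cursor for all nodes
        Map<String, Integer> perNodeIndex = new HashMap<>();  // new scheme: one cursor per node

        for (int round = 0; round < 3; round++) {
            StringBuilder line = new StringBuilder("round " + round + ":");
            for (String node : nodes) {
                // old: every node's drain advances the same cursor
                int oldPick = sharedIndex;
                sharedIndex = (sharedIndex + 1) % partitionsPerNode;
                // new: each node reads and advances only its own cursor,
                // like getDrainIndex/updateDrainIndex in the patch
                int newPick = perNodeIndex.computeIfAbsent(node, s -> 0);
                perNodeIndex.put(node, (newPick + 1) % partitionsPerNode);
                line.append("  ").append(node)
                    .append(" old->p").append(oldPick)
                    .append(" new->p").append(newPick);
            }
            System.out.println(line);
        }
    }
}
```

With the shared cursor, node1's pick advances the index node2 starts from, so node1 only ever drains p0 and node2 only ever drains p1, starving the other two partitions; with per-node cursors both partitions of each node are served in alternation, which is the ordering testDrainBatches asserts.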