KAFKA-15625: Do not flush global state store at each commit (#15361)

Global state stores are currently flushed at each commit, which may impact performance, especially under EOS (which commits every 200ms).
The goal of this improvement is to flush global state stores only when the delta between the current offset and the last checkpointed offset exceeds a threshold.
This is the same logic we apply to local state stores, with a threshold of 10000 records.
The implementation only flushes if the commit interval has elapsed and the threshold of 10000 records is exceeded.

Reviewers: Jeff Kim <jeff.kim@confluent.io>, Bruno Cadonna <cadonna@apache.org>
This commit is contained in:
Ayoub Omari 2024-03-04 10:19:59 +01:00 committed by Bruno Cadonna
parent 4ee66bb269
commit 64845b9b07
7 changed files with 147 additions and 56 deletions

View File

@ -34,4 +34,6 @@ interface GlobalStateMaintainer {
void close(final boolean wipeStateStore) throws IOException;
void update(ConsumerRecord<byte[], byte[]> record);
void maybeCheckpoint();
}

View File

@ -19,6 +19,7 @@ package org.apache.kafka.streams.processor.internals;
import org.apache.kafka.clients.consumer.ConsumerRecord;
import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.common.utils.LogContext;
import org.apache.kafka.common.utils.Time;
import org.apache.kafka.common.utils.Utils;
import org.apache.kafka.streams.errors.DeserializationExceptionHandler;
import org.apache.kafka.streams.errors.StreamsException;
@ -45,18 +46,26 @@ public class GlobalStateUpdateTask implements GlobalStateMaintainer {
private final Map<String, RecordDeserializer> deserializers = new HashMap<>();
private final GlobalStateManager stateMgr;
private final DeserializationExceptionHandler deserializationExceptionHandler;
private final Time time;
private final long flushInterval;
private long lastFlush;
public GlobalStateUpdateTask(final LogContext logContext,
final ProcessorTopology topology,
final InternalProcessorContext processorContext,
final GlobalStateManager stateMgr,
final DeserializationExceptionHandler deserializationExceptionHandler) {
final DeserializationExceptionHandler deserializationExceptionHandler,
final Time time,
final long flushInterval
) {
this.logContext = logContext;
this.log = logContext.logger(getClass());
this.topology = topology;
this.stateMgr = stateMgr;
this.processorContext = processorContext;
this.deserializationExceptionHandler = deserializationExceptionHandler;
this.time = time;
this.flushInterval = flushInterval;
}
/**
@ -86,6 +95,7 @@ public class GlobalStateUpdateTask implements GlobalStateMaintainer {
}
initTopology();
processorContext.initialize();
lastFlush = time.milliseconds();
return stateMgr.changelogOffsets();
}
@ -150,5 +160,13 @@ public class GlobalStateUpdateTask implements GlobalStateMaintainer {
}
}
/**
 * Flushes the global state only when both gating conditions hold:
 * at least {@code flushInterval} milliseconds have passed since {@code lastFlush}, and
 * {@link StateManagerUtil#checkpointNeeded} reports that a checkpoint is warranted when
 * comparing the state manager's changelog offsets against this task's {@code offsets}.
 * {@code lastFlush} is only advanced when a flush actually happens.
 */
@Override
public void maybeCheckpoint() {
    final long currentTimeMs = time.milliseconds();

    // Not enough wall-clock time since the previous flush: skip without consulting the offsets.
    if (currentTimeMs - flushInterval < lastFlush) {
        return;
    }

    // Interval elapsed, but checkpointing is never enforced here (first argument is false),
    // so the decision rests on the offset comparison alone.
    if (!StateManagerUtil.checkpointNeeded(false, stateMgr.changelogOffsets(), offsets)) {
        return;
    }

    flushState();
    lastFlush = currentTimeMs;
}
}

View File

@ -228,25 +228,17 @@ public class GlobalStreamThread extends Thread {
static class StateConsumer {
private final Consumer<byte[], byte[]> globalConsumer;
private final GlobalStateMaintainer stateMaintainer;
private final Time time;
private final Duration pollTime;
private final long flushInterval;
private final Logger log;
private long lastFlush;
StateConsumer(final LogContext logContext,
final Consumer<byte[], byte[]> globalConsumer,
final GlobalStateMaintainer stateMaintainer,
final Time time,
final Duration pollTime,
final long flushInterval) {
final Duration pollTime) {
this.log = logContext.logger(getClass());
this.globalConsumer = globalConsumer;
this.stateMaintainer = stateMaintainer;
this.time = time;
this.pollTime = pollTime;
this.flushInterval = flushInterval;
}
/**
@ -259,7 +251,6 @@ public class GlobalStreamThread extends Thread {
for (final Map.Entry<TopicPartition, Long> entry : partitionOffsets.entrySet()) {
globalConsumer.seek(entry.getKey(), entry.getValue());
}
lastFlush = time.milliseconds();
}
void pollAndUpdate() {
@ -267,11 +258,7 @@ public class GlobalStreamThread extends Thread {
for (final ConsumerRecord<byte[], byte[]> record : received) {
stateMaintainer.update(record);
}
final long now = time.milliseconds();
if (now - flushInterval >= lastFlush) {
stateMaintainer.flushState();
lastFlush = now;
}
stateMaintainer.maybeCheckpoint();
}
public void close(final boolean wipeStateStore) throws IOException {
@ -418,11 +405,11 @@ public class GlobalStreamThread extends Thread {
topology,
globalProcessorContext,
stateMgr,
config.defaultDeserializationExceptionHandler()
config.defaultDeserializationExceptionHandler(),
time,
config.getLong(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG)
),
time,
Duration.ofMillis(config.getLong(StreamsConfig.POLL_MS_CONFIG)),
config.getLong(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG)
Duration.ofMillis(config.getLong(StreamsConfig.POLL_MS_CONFIG))
);
try {

View File

@ -25,6 +25,7 @@ import org.apache.kafka.common.serialization.IntegerSerializer;
import org.apache.kafka.common.serialization.LongSerializer;
import org.apache.kafka.common.serialization.StringDeserializer;
import org.apache.kafka.common.utils.LogContext;
import org.apache.kafka.common.utils.MockTime;
import org.apache.kafka.common.utils.Utils;
import org.apache.kafka.streams.errors.LogAndContinueExceptionHandler;
import org.apache.kafka.streams.errors.LogAndFailExceptionHandler;
@ -46,8 +47,6 @@ import java.util.Set;
import static java.util.Arrays.asList;
import static org.apache.kafka.streams.processor.internals.testutil.ConsumerRecordUtil.record;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
@ -71,8 +70,12 @@ public class GlobalStateTaskTest {
private final MockProcessorNode<?, ?, ?, ?> processorTwo = new MockProcessorNode<>();
private final Map<TopicPartition, Long> offsets = new HashMap<>();
private File testDirectory = TestUtils.tempDirectory("global-store");
private final File testDirectory = TestUtils.tempDirectory("global-store");
private final NoOpProcessorContext context = new NoOpProcessorContext();
private final MockTime time = new MockTime();
private final long flushInterval = 1000L;
private final long currentOffsetT1 = 50;
private final long currentOffsetT2 = 100;
private ProcessorTopology topology;
private GlobalStateManagerStub stateMgr;
@ -101,7 +104,9 @@ public class GlobalStateTaskTest {
topology,
context,
stateMgr,
new LogAndFailExceptionHandler()
new LogAndFailExceptionHandler(),
time,
flushInterval
);
}
@ -188,7 +193,9 @@ public class GlobalStateTaskTest {
topology,
context,
stateMgr,
new LogAndContinueExceptionHandler()
new LogAndContinueExceptionHandler(),
time,
flushInterval
);
final byte[] key = new LongSerializer().serialize(topic2, 1L);
final byte[] recordValue = new IntegerSerializer().serialize(topic2, 10);
@ -203,7 +210,9 @@ public class GlobalStateTaskTest {
topology,
context,
stateMgr,
new LogAndContinueExceptionHandler()
new LogAndContinueExceptionHandler(),
time,
flushInterval
);
final byte[] key = new IntegerSerializer().serialize(topic2, 1);
final byte[] recordValue = new LongSerializer().serialize(topic2, 10L);
@ -217,10 +226,13 @@ public class GlobalStateTaskTest {
final Map<TopicPartition, Long> expectedOffsets = new HashMap<>();
expectedOffsets.put(t1, 52L);
expectedOffsets.put(t2, 100L);
globalStateTask.initialize();
globalStateTask.update(record(topic1, 1, 51, "foo".getBytes(), "foo".getBytes()));
globalStateTask.update(record(topic1, 1, currentOffsetT1 + 1, "foo".getBytes(), "foo".getBytes()));
globalStateTask.flushState();
assertEquals(expectedOffsets, stateMgr.changelogOffsets());
assertTrue(stateMgr.flushed);
}
@Test
@ -228,12 +240,93 @@ public class GlobalStateTaskTest {
final Map<TopicPartition, Long> expectedOffsets = new HashMap<>();
expectedOffsets.put(t1, 102L);
expectedOffsets.put(t2, 100L);
globalStateTask.initialize();
globalStateTask.update(record(topic1, 1, 101, "foo".getBytes(), "foo".getBytes()));
globalStateTask.update(record(topic1, 1, currentOffsetT1 + 51L, "foo".getBytes(), "foo".getBytes()));
globalStateTask.flushState();
assertThat(stateMgr.changelogOffsets(), equalTo(expectedOffsets));
assertEquals(expectedOffsets, stateMgr.changelogOffsets());
assertTrue(stateMgr.checkpointWritten);
}
// Verifies that an elapsed flush interval alone does not trigger a flush/checkpoint
// when the offset progress stays below the record threshold.
@Test
public void shouldNotCheckpointIfNotReceivedEnoughRecords() {
globalStateTask.initialize();
// topic1 jumps ahead by roughly 9000 offsets, which is presumably still short of the
// 10000-record threshold used by the checkpoint check (confirm against StateManagerUtil)
globalStateTask.update(record(topic1, 1, currentOffsetT1 + 9000L, "foo".getBytes(), "foo".getBytes()));
time.sleep(flushInterval); // flush interval elapsed
globalStateTask.maybeCheckpoint();
// the state manager still reports the initial offsets and saw neither flush() nor checkpoint()
assertEquals(offsets, stateMgr.changelogOffsets());
assertFalse(stateMgr.flushed);
assertFalse(stateMgr.checkpointWritten);
}
// Verifies that sufficient offset progress alone does not trigger a flush/checkpoint
// while the flush interval has not yet fully elapsed.
@Test
public void shouldNotCheckpointWhenFlushIntervalHasNotLapsed() {
globalStateTask.initialize();
// offset delta exceeded
globalStateTask.update(record(topic1, 1, currentOffsetT1 + 10000L, "foo".getBytes(), "foo".getBytes()));
// advance the mock clock by only half of the flush interval
time.sleep(flushInterval / 2);
globalStateTask.maybeCheckpoint();
// the state manager still reports the initial offsets and saw neither flush() nor checkpoint()
assertEquals(offsets, stateMgr.changelogOffsets());
assertFalse(stateMgr.flushed);
assertFalse(stateMgr.checkpointWritten);
}
// Exercises the boundary of the record threshold once the flush interval has elapsed:
// a delta of exactly 10000 records must not flush, one more record must.
@Test
public void shouldCheckpointIfReceivedEnoughRecordsAndFlushIntervalHasElapsed() {
final Map<TopicPartition, Long> expectedOffsets = new HashMap<>();
expectedOffsets.put(t1, 10051L); // topic1 advanced with 10001 records
expectedOffsets.put(t2, 100L);
globalStateTask.initialize();
time.sleep(flushInterval); // flush interval elapsed
// exactly 10000 records received since last flush => threshold reached but not exceeded, do not flush
globalStateTask.update(record(topic1, 1, currentOffsetT1 + 9999L, "foo".getBytes(), "foo".getBytes()));
globalStateTask.maybeCheckpoint();
assertEquals(offsets, stateMgr.changelogOffsets());
assertFalse(stateMgr.flushed);
assertFalse(stateMgr.checkpointWritten);
// 1 more record received => threshold exceeded, triggers the flush
globalStateTask.update(record(topic1, 1, currentOffsetT1 + 10000L, "foo".getBytes(), "foo".getBytes()));
globalStateTask.maybeCheckpoint();
assertEquals(expectedOffsets, stateMgr.changelogOffsets());
assertTrue(stateMgr.flushed);
assertTrue(stateMgr.checkpointWritten);
}
// Verifies that offset progress is counted across topics: neither topic alone advances by
// more than 10000 records, but their combined progress (9000 + 1001) does, so a flush and
// checkpoint are expected once the flush interval has elapsed.
@Test
public void shouldCheckpointIfReceivedEnoughRecordsFromMultipleTopicsAndFlushIntervalElapsed() {
final byte[] integerBytes = new IntegerSerializer().serialize(topic2, 1);
final Map<TopicPartition, Long> expectedOffsets = new HashMap<>();
expectedOffsets.put(t1, 9050L); // topic1 advanced with 9000 records
expectedOffsets.put(t2, 1101L); // topic2 advanced with 1001 records
globalStateTask.initialize();
time.sleep(flushInterval); // flush interval elapsed
// received 9000 records in topic1
globalStateTask.update(record(topic1, 1, currentOffsetT1 + 8999L, "foo".getBytes(), "foo".getBytes()));
// received 1001 records in topic2
globalStateTask.update(record(topic2, 1, currentOffsetT2 + 1000L, integerBytes, integerBytes));
globalStateTask.maybeCheckpoint();
assertEquals(expectedOffsets, stateMgr.changelogOffsets());
assertTrue(stateMgr.flushed);
assertTrue(stateMgr.checkpointWritten);
}
@Test
public void shouldWipeGlobalStateDirectory() throws Exception {
assertTrue(stateMgr.baseDir().exists());

View File

@ -21,7 +21,6 @@ import org.apache.kafka.clients.consumer.MockConsumer;
import org.apache.kafka.clients.consumer.OffsetResetStrategy;
import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.common.utils.LogContext;
import org.apache.kafka.common.utils.MockTime;
import org.apache.kafka.common.utils.Utils;
import org.junit.Before;
import org.junit.Test;
@ -32,16 +31,13 @@ import java.util.HashMap;
import java.util.Map;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
public class StateConsumerTest {
private static final long FLUSH_INTERVAL = 1000L;
private final TopicPartition topicOne = new TopicPartition("topic-one", 1);
private final TopicPartition topicTwo = new TopicPartition("topic-two", 1);
private final MockTime time = new MockTime();
private final MockConsumer<byte[], byte[]> consumer = new MockConsumer<>(OffsetResetStrategy.EARLIEST);
private final Map<TopicPartition, Long> partitionOffsets = new HashMap<>();
private final LogContext logContext = new LogContext("test ");
@ -53,7 +49,7 @@ public class StateConsumerTest {
partitionOffsets.put(topicOne, 20L);
partitionOffsets.put(topicTwo, 30L);
stateMaintainer = new TaskStub(partitionOffsets);
stateConsumer = new GlobalStreamThread.StateConsumer(logContext, consumer, stateMaintainer, time, Duration.ofMillis(10L), FLUSH_INTERVAL);
stateConsumer = new GlobalStreamThread.StateConsumer(logContext, consumer, stateMaintainer, Duration.ofMillis(10L));
}
@Test
@ -76,6 +72,7 @@ public class StateConsumerTest {
consumer.addRecord(new ConsumerRecord<>("topic-one", 1, 21L, new byte[0], new byte[0]));
stateConsumer.pollAndUpdate();
assertEquals(2, stateMaintainer.updatedPartitions.get(topicOne).intValue());
assertTrue(stateMaintainer.flushed);
}
@Test
@ -87,27 +84,9 @@ public class StateConsumerTest {
stateConsumer.pollAndUpdate();
assertEquals(1, stateMaintainer.updatedPartitions.get(topicOne).intValue());
assertEquals(2, stateMaintainer.updatedPartitions.get(topicTwo).intValue());
}
@Test
public void shouldFlushStoreWhenFlushIntervalHasLapsed() {
stateConsumer.initialize();
consumer.addRecord(new ConsumerRecord<>("topic-one", 1, 20L, new byte[0], new byte[0]));
time.sleep(FLUSH_INTERVAL);
stateConsumer.pollAndUpdate();
assertTrue(stateMaintainer.flushed);
}
@Test
public void shouldNotFlushOffsetsWhenFlushIntervalHasNotLapsed() {
stateConsumer.initialize();
consumer.addRecord(new ConsumerRecord<>("topic-one", 1, 20L, new byte[0], new byte[0]));
time.sleep(FLUSH_INTERVAL / 2);
stateConsumer.pollAndUpdate();
assertFalse(stateMaintainer.flushed);
}
@Test
public void shouldCloseConsumer() throws IOException {
stateConsumer.close(false);
@ -161,6 +140,10 @@ public class StateConsumerTest {
updatedPartitions.put(tp, updatedPartitions.get(tp) + 1);
}
// Test stub: applies no interval or offset gating and simply delegates to flushState(),
// so every call is observable as a flush.
@Override
public void maybeCheckpoint() {
flushState();
}
}
}
}

View File

@ -35,6 +35,8 @@ public class GlobalStateManagerStub implements GlobalStateManager {
private final File baseDirectory;
public boolean initialized;
public boolean closed;
public boolean flushed;
public boolean checkpointWritten;
public GlobalStateManagerStub(final Set<String> storeNames,
final Map<TopicPartition, Long> offsets,
@ -64,7 +66,9 @@ public class GlobalStateManagerStub implements GlobalStateManager {
final CommitCallback checkpoint) {}
@Override
public void flush() {}
public void flush() {
// record the invocation so tests can assert whether a flush happened
flushed = true;
}
@Override
public void close() {
@ -77,7 +81,9 @@ public class GlobalStateManagerStub implements GlobalStateManager {
}
@Override
public void checkpoint() {}
public void checkpoint() {
// record the invocation so tests can assert whether a checkpoint was written
checkpointWritten = true;
}
@Override
public StateStore getStore(final String name) {

View File

@ -459,7 +459,9 @@ public class TopologyTestDriver implements Closeable {
globalTopology,
globalProcessorContext,
globalStateManager,
new LogAndContinueExceptionHandler()
new LogAndContinueExceptionHandler(),
mockWallClockTime,
streamsConfig.getLong(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG)
);
globalStateTask.initialize();
globalProcessorContext.setRecordContext(null);