KAFKA-18913: Start state updater in task manager (#19889)

Updated the code to start the State Updater Thread only after the Stream
Thread is started.

Changes done :
1. Moved the starting of the StateUpdater thread to a new init method in
the TaskManager.
2. Called the init of TaskManager in the run method of the StreamThread.
3. Updated the test cases in the StreamThreadTest to mimic the
aforementioned behaviour.

Reviewers: Bruno Cadonna <cadonna@apache.org>
This commit is contained in:
Janindu Pathirana 2025-06-06 14:44:41 +05:30 committed by GitHub
parent aaed164be6
commit 4d6cf3efef
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 57 additions and 11 deletions

View File

@ -454,7 +454,7 @@ public class StreamThread extends Thread implements ProcessingThread {
final DefaultTaskManager schedulingTaskManager =
maybeCreateSchedulingTaskManager(processingThreadsEnabled, stateUpdaterEnabled, topologyMetadata, time, threadId, tasks);
final StateUpdater stateUpdater =
maybeCreateAndStartStateUpdater(
maybeCreateStateUpdater(
stateUpdaterEnabled,
streamsMetrics,
config,
@ -635,7 +635,7 @@ public class StreamThread extends Thread implements ProcessingThread {
return null;
}
private static StateUpdater maybeCreateAndStartStateUpdater(final boolean stateUpdaterEnabled,
private static StateUpdater maybeCreateStateUpdater(final boolean stateUpdaterEnabled,
final StreamsMetricsImpl streamsMetrics,
final StreamsConfig streamsConfig,
final Consumer<byte[], byte[]> restoreConsumer,
@ -646,7 +646,7 @@ public class StreamThread extends Thread implements ProcessingThread {
final int threadIdx) {
if (stateUpdaterEnabled) {
final String name = clientId + STATE_UPDATER_ID_SUBSTRING + threadIdx;
final StateUpdater stateUpdater = new DefaultStateUpdater(
return new DefaultStateUpdater(
name,
streamsMetrics.metricsRegistry(),
streamsConfig,
@ -655,8 +655,6 @@ public class StreamThread extends Thread implements ProcessingThread {
topologyMetadata,
time
);
stateUpdater.start();
return stateUpdater;
} else {
return null;
}
@ -885,6 +883,9 @@ public class StreamThread extends Thread implements ProcessingThread {
}
boolean cleanRun = false;
try {
if (stateUpdaterEnabled) {
taskManager.init();
}
cleanRun = runLoop();
} catch (final Throwable e) {
failedStreamThreadSensor.record();

View File

@ -149,6 +149,11 @@ public class TaskManager {
);
}
void init() {
if (stateUpdater != null) {
this.stateUpdater.start();
}
}
void setMainConsumer(final Consumer<byte[], byte[]> mainConsumer) {
this.mainConsumer = mainConsumer;
}

View File

@ -110,6 +110,7 @@ import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.InOrder;
import org.mockito.Mock;
import org.mockito.Mockito;
@ -917,6 +918,7 @@ public class StreamThreadTest {
thread = createStreamThread(CLIENT_ID, config);
thread.setState(StreamThread.State.STARTING);
thread.taskManager().init();
thread.setState(StreamThread.State.PARTITIONS_REVOKED);
final TaskId task1 = new TaskId(0, t1p1.partition());
@ -1291,6 +1293,7 @@ public class StreamThreadTest {
thread = createStreamThread(CLIENT_ID, new StreamsConfig(props));
thread.setState(StreamThread.State.STARTING);
thread.taskManager().init();
thread.rebalanceListener().onPartitionsRevoked(Collections.emptyList());
final Map<TaskId, Set<TopicPartition>> activeTasks = new HashMap<>();
@ -1548,6 +1551,7 @@ public class StreamThreadTest {
consumer.updatePartitions(topic1, Collections.singletonList(new PartitionInfo(topic1, 1, null, null, null)));
thread.setState(StreamThread.State.STARTING);
thread.taskManager().init();
thread.rebalanceListener().onPartitionsRevoked(Collections.emptySet());
final Map<TaskId, Set<TopicPartition>> activeTasks = new HashMap<>();
@ -1613,6 +1617,7 @@ public class StreamThreadTest {
internalTopologyBuilder.addSink("out", "output", null, null, null, "name");
thread.setState(StreamThread.State.STARTING);
thread.taskManager().init();
thread.rebalanceListener().onPartitionsRevoked(Collections.emptySet());
final Map<TaskId, Set<TopicPartition>> activeTasks = new HashMap<>();
@ -1695,6 +1700,7 @@ public class StreamThreadTest {
internalTopologyBuilder.buildTopology();
thread.setState(StreamThread.State.STARTING);
thread.taskManager().init();
thread.rebalanceListener().onPartitionsRevoked(Collections.emptySet());
final Map<TaskId, Set<TopicPartition>> activeTasks = new HashMap<>();
@ -1789,6 +1795,7 @@ public class StreamThreadTest {
consumer.updatePartitions(topic1, Collections.singletonList(new PartitionInfo(topic1, 1, null, null, null)));
thread.setState(StreamThread.State.STARTING);
thread.taskManager().init();
thread.rebalanceListener().onPartitionsRevoked(Collections.emptySet());
final Map<TaskId, Set<TopicPartition>> activeTasks = new HashMap<>();
@ -1853,6 +1860,7 @@ public class StreamThreadTest {
internalTopologyBuilder.addSink("out", "output", null, null, null, "name");
thread.setState(StreamThread.State.STARTING);
thread.taskManager().init();
thread.rebalanceListener().onPartitionsRevoked(Collections.emptySet());
final Map<TaskId, Set<TopicPartition>> activeTasks = new HashMap<>();
@ -1934,6 +1942,7 @@ public class StreamThreadTest {
);
thread.setState(StreamThread.State.STARTING);
thread.taskManager().init();
thread.rebalanceListener().onPartitionsRevoked(Collections.emptySet());
final Map<TaskId, Set<TopicPartition>> activeTasks = new HashMap<>();
@ -1994,6 +2003,7 @@ public class StreamThreadTest {
restoreConsumer.updateBeginningOffsets(offsets);
thread.setState(StreamThread.State.STARTING);
thread.taskManager().init();
thread.rebalanceListener().onPartitionsRevoked(Collections.emptySet());
final Map<TaskId, Set<TopicPartition>> standbyTasks = new HashMap<>();
@ -2257,6 +2267,7 @@ public class StreamThreadTest {
thread = createStreamThread(CLIENT_ID, config);
thread.setState(StreamThread.State.STARTING);
thread.taskManager().init();
thread.rebalanceListener().onPartitionsRevoked(Collections.emptySet());
final List<TopicPartition> assignedPartitions = new ArrayList<>();
@ -2336,6 +2347,7 @@ public class StreamThreadTest {
thread = createStreamThread(CLIENT_ID, stateUpdaterEnabled, processingThreadsEnabled);
thread.setState(StreamThread.State.STARTING);
thread.taskManager().init();
thread.rebalanceListener().onPartitionsRevoked(Collections.emptySet());
final List<TopicPartition> assignedPartitions = new ArrayList<>();
@ -2533,6 +2545,7 @@ public class StreamThreadTest {
thread = createStreamThread(CLIENT_ID, new StreamsConfig(properties));
thread.setState(StreamThread.State.STARTING);
thread.taskManager().init();
thread.setState(StreamThread.State.PARTITIONS_REVOKED);
final TaskId task1 = new TaskId(0, t1p1.partition());
@ -3019,6 +3032,7 @@ public class StreamThreadTest {
thread = createStreamThread(CLIENT_ID, config);
thread.setState(StreamThread.State.STARTING);
thread.taskManager().init();
thread.setState(StreamThread.State.PARTITIONS_REVOKED);
final TaskId task1 = new TaskId(0, t1p1.partition());
@ -3392,6 +3406,7 @@ public class StreamThreadTest {
thread = createStreamThread("clientId", stateUpdaterEnabled, processingThreadsEnabled);
thread.setState(State.STARTING);
thread.taskManager().init();
final Map<String, KafkaFuture<Uuid>> clientInstanceIdFutures = thread.clientInstanceIds(Duration.ZERO);
@ -3416,6 +3431,7 @@ public class StreamThreadTest {
public void shouldReturnErrorIfMainConsumerInstanceIdNotInitialized(final boolean stateUpdaterEnabled, final boolean processingThreadsEnabled) {
thread = createStreamThread("clientId", stateUpdaterEnabled, processingThreadsEnabled);
thread.setState(State.STARTING);
thread.taskManager().init();
final Map<String, KafkaFuture<Uuid>> consumerFutures = thread.clientInstanceIds(Duration.ZERO);
@ -3432,6 +3448,7 @@ public class StreamThreadTest {
public void shouldReturnErrorIfRestoreConsumerInstanceIdNotInitialized(final boolean stateUpdaterEnabled, final boolean processingThreadsEnabled) {
thread = createStreamThread("clientId", stateUpdaterEnabled, processingThreadsEnabled);
thread.setState(State.STARTING);
thread.taskManager().init();
final Map<String, KafkaFuture<Uuid>> consumerFutures = thread.clientInstanceIds(Duration.ZERO);
@ -3448,6 +3465,7 @@ public class StreamThreadTest {
public void shouldReturnErrorIfProducerInstanceIdNotInitialized(final boolean stateUpdaterEnabled, final boolean processingThreadsEnabled) {
thread = createStreamThread("clientId", stateUpdaterEnabled, processingThreadsEnabled);
thread.setState(State.STARTING);
thread.taskManager().init();
final Map<String, KafkaFuture<Uuid>> producerFutures = thread.clientInstanceIds(Duration.ZERO);
@ -3465,6 +3483,7 @@ public class StreamThreadTest {
clientSupplier.consumer.disableTelemetry();
thread = createStreamThread("clientId", stateUpdaterEnabled, processingThreadsEnabled);
thread.setState(State.STARTING);
thread.taskManager().init();
final Map<String, KafkaFuture<Uuid>> consumerFutures = thread.clientInstanceIds(Duration.ZERO);
@ -3482,6 +3501,7 @@ public class StreamThreadTest {
thread = createStreamThread("clientId", stateUpdaterEnabled, processingThreadsEnabled);
thread.setState(State.STARTING);
thread.taskManager().init();
final Map<String, KafkaFuture<Uuid>> consumerFutures = thread.clientInstanceIds(Duration.ZERO);
@ -3501,6 +3521,7 @@ public class StreamThreadTest {
thread = createStreamThread("clientId", stateUpdaterEnabled, processingThreadsEnabled);
thread.setState(State.STARTING);
thread.taskManager().init();
final Map<String, KafkaFuture<Uuid>> producerFutures = thread.clientInstanceIds(Duration.ZERO);
@ -3518,6 +3539,7 @@ public class StreamThreadTest {
clientSupplier.consumer.injectTimeoutException(-1);
thread = createStreamThread("clientId", stateUpdaterEnabled, processingThreadsEnabled);
thread.setState(State.STARTING);
thread.taskManager().init();
final Map<String, KafkaFuture<Uuid>> consumerFutures = thread.clientInstanceIds(Duration.ZERO);
@ -3542,6 +3564,7 @@ public class StreamThreadTest {
clientSupplier.restoreConsumer.injectTimeoutException(-1);
thread = createStreamThread("clientId", stateUpdaterEnabled, processingThreadsEnabled);
thread.setState(State.STARTING);
thread.taskManager().init();
final Map<String, KafkaFuture<Uuid>> consumerFutures = thread.clientInstanceIds(Duration.ZERO);
@ -3569,6 +3592,7 @@ public class StreamThreadTest {
thread = createStreamThread("clientId", stateUpdaterEnabled, processingThreadsEnabled);
thread.setState(State.STARTING);
thread.taskManager().init();
final Map<String, KafkaFuture<Uuid>> producerFutures = thread.clientInstanceIds(Duration.ZERO);
@ -3585,9 +3609,10 @@ public class StreamThreadTest {
);
}
@Test
public void testNamedTopologyWithStreamsProtocol() {
final Properties props = configProps(false, false, false);
@ParameterizedTest
@ValueSource(booleans = {true, false})
public void testNamedTopologyWithStreamsProtocol(final boolean stateUpdaterEnabled) {
final Properties props = configProps(false, stateUpdaterEnabled, false);
props.setProperty(StreamsConfig.GROUP_PROTOCOL_CONFIG, GroupProtocol.STREAMS.toString());
final StreamsConfig config = new StreamsConfig(props);
final InternalTopologyBuilder topologyBuilder = new InternalTopologyBuilder(
@ -3644,9 +3669,10 @@ public class StreamThreadTest {
assertTrue(thread.streamsRebalanceData().isEmpty());
}
@Test
public void testStreamsRebalanceDataWithExtraCopartition() {
final Properties props = configProps(false, false, false);
@ParameterizedTest
@ValueSource(booleans = {true, false})
public void testStreamsRebalanceDataWithExtraCopartition(final boolean stateUpdaterEnabled) {
final Properties props = configProps(false, stateUpdaterEnabled, false);
props.setProperty(StreamsConfig.GROUP_PROTOCOL_CONFIG, GroupProtocol.STREAMS.toString());
internalTopologyBuilder.addSource(null, "source1", null, null, null, topic1);

View File

@ -61,6 +61,8 @@ import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.api.io.TempDir;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.InOrder;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
@ -4837,6 +4839,18 @@ public class TaskManagerTest {
assertEquals(Collections.singletonMap(taskId00, startupTask), taskManager.standbyTaskMap());
}
@ParameterizedTest
@ValueSource(booleans = {true, false})
public void shouldStartStateUpdaterOnInit(final boolean stateUpdaterEnabled) {
final TaskManager taskManager = setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, stateUpdaterEnabled);
taskManager.init();
if (stateUpdaterEnabled) {
verify(stateUpdater).start();
} else {
verify(stateUpdater, never()).start();
}
}
private static KafkaFutureImpl<DeletedRecords> completedFuture() {
final KafkaFutureImpl<DeletedRecords> futureDeletedRecords = new KafkaFutureImpl<>();
futureDeletedRecords.complete(null);