From 4d6cf3efef15b3c6b1511c28b882b5486ff1245c Mon Sep 17 00:00:00 2001 From: Janindu Pathirana Date: Fri, 6 Jun 2025 14:44:41 +0530 Subject: [PATCH] KAFKA-18913: Start state updater in task manager (#19889) Updated the code to start the State Updater Thread only after the Stream Thread is started. Changes done : 1. Moved the starting of the StateUpdater thread to a new init method in the TaskManager. 2. Called the init of TaskManager in the run method of the StreamThread. 3. Updated the test cases in the StreamThreadTest to mimic the aforementioned behaviour. Reviewers: Bruno Cadonna --- .../processor/internals/StreamThread.java | 11 +++--- .../processor/internals/TaskManager.java | 5 +++ .../processor/internals/StreamThreadTest.java | 38 ++++++++++++++++--- .../processor/internals/TaskManagerTest.java | 14 +++++++ 4 files changed, 57 insertions(+), 11 deletions(-) diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java index fdc5e8df4bc..47775c53652 100644 --- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java @@ -454,7 +454,7 @@ public class StreamThread extends Thread implements ProcessingThread { final DefaultTaskManager schedulingTaskManager = maybeCreateSchedulingTaskManager(processingThreadsEnabled, stateUpdaterEnabled, topologyMetadata, time, threadId, tasks); final StateUpdater stateUpdater = - maybeCreateAndStartStateUpdater( + maybeCreateStateUpdater( stateUpdaterEnabled, streamsMetrics, config, @@ -635,7 +635,7 @@ public class StreamThread extends Thread implements ProcessingThread { return null; } - private static StateUpdater maybeCreateAndStartStateUpdater(final boolean stateUpdaterEnabled, + private static StateUpdater maybeCreateStateUpdater(final boolean stateUpdaterEnabled, final StreamsMetricsImpl streamsMetrics, final StreamsConfig streamsConfig, final Consumer restoreConsumer, @@ -646,7 +646,7 @@ public class StreamThread extends Thread implements ProcessingThread { final int threadIdx) { if (stateUpdaterEnabled) { final String name = clientId + STATE_UPDATER_ID_SUBSTRING + threadIdx; - final StateUpdater stateUpdater = new DefaultStateUpdater( + return new DefaultStateUpdater( name, streamsMetrics.metricsRegistry(), streamsConfig, @@ -655,8 +655,6 @@ public class StreamThread extends Thread implements ProcessingThread { topologyMetadata, time ); - stateUpdater.start(); - return stateUpdater; } else { return null; } @@ -885,6 +883,9 @@ public class StreamThread extends Thread implements ProcessingThread { } boolean cleanRun = false; try { + if (stateUpdaterEnabled) { + taskManager.init(); + } cleanRun = runLoop(); } catch (final Throwable e) { failedStreamThreadSensor.record(); diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java index 1eaf8298622..67d009b037f 100644 --- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java @@ -149,6 +149,11 @@ public class TaskManager { ); } + void init() { + if (stateUpdater != null) { + this.stateUpdater.start(); + } + } void setMainConsumer(final Consumer mainConsumer) { this.mainConsumer = mainConsumer; } diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java index 54230d11d3b..44cd11459f7 100644 --- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java @@ -110,6 +110,7 @@ import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; +import org.junit.jupiter.params.provider.ValueSource; import org.mockito.InOrder; import org.mockito.Mock; import org.mockito.Mockito; @@ -917,6 +918,7 @@ public class StreamThreadTest { thread = createStreamThread(CLIENT_ID, config); thread.setState(StreamThread.State.STARTING); + thread.taskManager().init(); thread.setState(StreamThread.State.PARTITIONS_REVOKED); final TaskId task1 = new TaskId(0, t1p1.partition()); @@ -1291,6 +1293,7 @@ public class StreamThreadTest { thread = createStreamThread(CLIENT_ID, new StreamsConfig(props)); thread.setState(StreamThread.State.STARTING); + thread.taskManager().init(); thread.rebalanceListener().onPartitionsRevoked(Collections.emptyList()); final Map> activeTasks = new HashMap<>(); @@ -1548,6 +1551,7 @@ public class StreamThreadTest { consumer.updatePartitions(topic1, Collections.singletonList(new PartitionInfo(topic1, 1, null, null, null))); thread.setState(StreamThread.State.STARTING); + thread.taskManager().init(); thread.rebalanceListener().onPartitionsRevoked(Collections.emptySet()); final Map> activeTasks = new HashMap<>(); @@ -1613,6 +1617,7 @@ public class StreamThreadTest { internalTopologyBuilder.addSink("out", "output", null, null, null, "name"); thread.setState(StreamThread.State.STARTING); + thread.taskManager().init(); thread.rebalanceListener().onPartitionsRevoked(Collections.emptySet()); final Map> activeTasks = new HashMap<>(); @@ -1695,6 +1700,7 @@ public class StreamThreadTest { internalTopologyBuilder.buildTopology(); thread.setState(StreamThread.State.STARTING); + thread.taskManager().init(); thread.rebalanceListener().onPartitionsRevoked(Collections.emptySet()); final Map> activeTasks = new HashMap<>(); @@ -1789,6 +1795,7 @@ public class StreamThreadTest { consumer.updatePartitions(topic1, Collections.singletonList(new PartitionInfo(topic1, 1, null, null, null))); thread.setState(StreamThread.State.STARTING); + thread.taskManager().init(); thread.rebalanceListener().onPartitionsRevoked(Collections.emptySet()); final Map> activeTasks = new HashMap<>(); @@ -1853,6 +1860,7 @@ public class StreamThreadTest { internalTopologyBuilder.addSink("out", "output", null, null, null, "name"); thread.setState(StreamThread.State.STARTING); + thread.taskManager().init(); thread.rebalanceListener().onPartitionsRevoked(Collections.emptySet()); final Map> activeTasks = new HashMap<>(); @@ -1934,6 +1942,7 @@ public class StreamThreadTest { ); thread.setState(StreamThread.State.STARTING); + thread.taskManager().init(); thread.rebalanceListener().onPartitionsRevoked(Collections.emptySet()); final Map> activeTasks = new HashMap<>(); @@ -1994,6 +2003,7 @@ public class StreamThreadTest { restoreConsumer.updateBeginningOffsets(offsets); thread.setState(StreamThread.State.STARTING); + thread.taskManager().init(); thread.rebalanceListener().onPartitionsRevoked(Collections.emptySet()); final Map> standbyTasks = new HashMap<>(); @@ -2257,6 +2267,7 @@ public class StreamThreadTest { thread = createStreamThread(CLIENT_ID, config); thread.setState(StreamThread.State.STARTING); + thread.taskManager().init(); thread.rebalanceListener().onPartitionsRevoked(Collections.emptySet()); final List assignedPartitions = new ArrayList<>(); @@ -2336,6 +2347,7 @@ public class StreamThreadTest { thread = createStreamThread(CLIENT_ID, stateUpdaterEnabled, processingThreadsEnabled); thread.setState(StreamThread.State.STARTING); + thread.taskManager().init(); thread.rebalanceListener().onPartitionsRevoked(Collections.emptySet()); final List assignedPartitions = new ArrayList<>(); @@ -2533,6 +2545,7 @@ public class StreamThreadTest { thread = createStreamThread(CLIENT_ID, new StreamsConfig(properties)); thread.setState(StreamThread.State.STARTING); + thread.taskManager().init(); thread.setState(StreamThread.State.PARTITIONS_REVOKED); final TaskId task1 = new TaskId(0, t1p1.partition()); @@ -3019,6 +3032,7 @@ public class StreamThreadTest { thread = createStreamThread(CLIENT_ID, config); thread.setState(StreamThread.State.STARTING); + thread.taskManager().init(); thread.setState(StreamThread.State.PARTITIONS_REVOKED); final TaskId task1 = new TaskId(0, t1p1.partition()); @@ -3392,6 +3406,7 @@ public class StreamThreadTest { thread = createStreamThread("clientId", stateUpdaterEnabled, processingThreadsEnabled); thread.setState(State.STARTING); + thread.taskManager().init(); final Map> clientInstanceIdFutures = thread.clientInstanceIds(Duration.ZERO); @@ -3416,6 +3431,7 @@ public class StreamThreadTest { public void shouldReturnErrorIfMainConsumerInstanceIdNotInitialized(final boolean stateUpdaterEnabled, final boolean processingThreadsEnabled) { thread = createStreamThread("clientId", stateUpdaterEnabled, processingThreadsEnabled); thread.setState(State.STARTING); + thread.taskManager().init(); final Map> consumerFutures = thread.clientInstanceIds(Duration.ZERO); @@ -3432,6 +3448,7 @@ public class StreamThreadTest { public void shouldReturnErrorIfRestoreConsumerInstanceIdNotInitialized(final boolean stateUpdaterEnabled, final boolean processingThreadsEnabled) { thread = createStreamThread("clientId", stateUpdaterEnabled, processingThreadsEnabled); thread.setState(State.STARTING); + thread.taskManager().init(); final Map> consumerFutures = thread.clientInstanceIds(Duration.ZERO); @@ -3448,6 +3465,7 @@ public class StreamThreadTest { public void shouldReturnErrorIfProducerInstanceIdNotInitialized(final boolean stateUpdaterEnabled, final boolean processingThreadsEnabled) { thread = createStreamThread("clientId", stateUpdaterEnabled, processingThreadsEnabled); thread.setState(State.STARTING); + thread.taskManager().init(); final Map> producerFutures = thread.clientInstanceIds(Duration.ZERO); @@ -3465,6 +3483,7 @@ public class StreamThreadTest { clientSupplier.consumer.disableTelemetry(); thread = createStreamThread("clientId", stateUpdaterEnabled, processingThreadsEnabled); thread.setState(State.STARTING); + thread.taskManager().init(); final Map> consumerFutures = thread.clientInstanceIds(Duration.ZERO); @@ -3482,6 +3501,7 @@ public class StreamThreadTest { thread = createStreamThread("clientId", stateUpdaterEnabled, processingThreadsEnabled); thread.setState(State.STARTING); + thread.taskManager().init(); final Map> consumerFutures = thread.clientInstanceIds(Duration.ZERO); @@ -3501,6 +3521,7 @@ public class StreamThreadTest { thread = createStreamThread("clientId", stateUpdaterEnabled, processingThreadsEnabled); thread.setState(State.STARTING); + thread.taskManager().init(); final Map> producerFutures = thread.clientInstanceIds(Duration.ZERO); @@ -3518,6 +3539,7 @@ public class StreamThreadTest { clientSupplier.consumer.injectTimeoutException(-1); thread = createStreamThread("clientId", stateUpdaterEnabled, processingThreadsEnabled); thread.setState(State.STARTING); + thread.taskManager().init(); final Map> consumerFutures = thread.clientInstanceIds(Duration.ZERO); @@ -3542,6 +3564,7 @@ public class StreamThreadTest { clientSupplier.restoreConsumer.injectTimeoutException(-1); thread = createStreamThread("clientId", stateUpdaterEnabled, processingThreadsEnabled); thread.setState(State.STARTING); + thread.taskManager().init(); final Map> consumerFutures = thread.clientInstanceIds(Duration.ZERO); @@ -3569,6 +3592,7 @@ public class StreamThreadTest { thread = createStreamThread("clientId", stateUpdaterEnabled, processingThreadsEnabled); thread.setState(State.STARTING); + thread.taskManager().init(); final Map> producerFutures = thread.clientInstanceIds(Duration.ZERO); @@ -3585,9 +3609,10 @@ public class StreamThreadTest { ); } - @Test - public void testNamedTopologyWithStreamsProtocol() { - final Properties props = configProps(false, false, false); + @ParameterizedTest + @ValueSource(booleans = {true, false}) + public void testNamedTopologyWithStreamsProtocol(final boolean stateUpdaterEnabled) { + final Properties props = configProps(false, stateUpdaterEnabled, false); props.setProperty(StreamsConfig.GROUP_PROTOCOL_CONFIG, GroupProtocol.STREAMS.toString()); final StreamsConfig config = new StreamsConfig(props); final InternalTopologyBuilder topologyBuilder = new InternalTopologyBuilder( @@ -3644,9 +3669,10 @@ public class StreamThreadTest { assertTrue(thread.streamsRebalanceData().isEmpty()); } - @Test - public void testStreamsRebalanceDataWithExtraCopartition() { - final Properties props = configProps(false, false, false); + @ParameterizedTest + @ValueSource(booleans = {true, false}) + public void testStreamsRebalanceDataWithExtraCopartition(final boolean stateUpdaterEnabled) { + final Properties props = configProps(false, stateUpdaterEnabled, false); props.setProperty(StreamsConfig.GROUP_PROTOCOL_CONFIG, GroupProtocol.STREAMS.toString()); internalTopologyBuilder.addSource(null, "source1", null, null, null, topic1); diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java index 68abf455a76..26a1523131b 100644 --- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java @@ -61,6 +61,8 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.io.TempDir; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; import org.mockito.InOrder; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; @@ -4837,6 +4839,18 @@ public class TaskManagerTest { assertEquals(Collections.singletonMap(taskId00, startupTask), taskManager.standbyTaskMap()); } + @ParameterizedTest + @ValueSource(booleans = {true, false}) + public void shouldStartStateUpdaterOnInit(final boolean stateUpdaterEnabled) { + final TaskManager taskManager = setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, stateUpdaterEnabled); + taskManager.init(); + if (stateUpdaterEnabled) { + verify(stateUpdater).start(); + } else { + verify(stateUpdater, never()).start(); + } + } + private static KafkaFutureImpl completedFuture() { final KafkaFutureImpl futureDeletedRecords = new KafkaFutureImpl<>(); futureDeletedRecords.complete(null);