KAFKA-14299: Initialize tasks in state updater (#12795)

The state updater code path puts tasks into an
"initialization queue", with created, but not initialized tasks.
These are later, during the event-loop, initialized and added
to the state updater. This might lead to losing track of those 
tasks - in particular it is possible to create
tasks twice, if we do not go once around `runLoop` to initialize
the task. This leads to `IllegalStateExceptions`. 

By handing the task to the state updater immediately and let the
state updater initialize the task, we can fulfil our promise to 
preserve the invariant "every task is owned by either the task 
registry or the state updater".

Reviewer: Bruno Cadonna <cadonna@apache.org>
This commit is contained in:
Lucas Brutschy 2022-11-14 10:00:29 +01:00 committed by GitHub
parent e422a67d3f
commit a55071a99f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 99 additions and 141 deletions

View File

@ -121,21 +121,27 @@ public class DefaultStateUpdater implements StateUpdater {
}
private void performActionsOnTasks() {
tasksAndActionsLock.lock();
try {
for (final TaskAndAction taskAndAction : getTasksAndActions()) {
while (!tasksAndActions.isEmpty()) {
Task toInitialize = null;
tasksAndActionsLock.lock();
try {
final TaskAndAction taskAndAction = tasksAndActions.remove();
final Action action = taskAndAction.getAction();
switch (action) {
case ADD:
addTask(taskAndAction.getTask());
toInitialize = taskAndAction.getTask();
break;
case REMOVE:
removeTask(taskAndAction.getTaskId());
break;
}
} finally {
tasksAndActionsLock.unlock();
}
if (toInitialize != null) {
initializeTask(toInitialize);
}
} finally {
tasksAndActionsLock.unlock();
}
}
@ -250,22 +256,32 @@ public class DefaultStateUpdater implements StateUpdater {
pausedTasks.clear();
}
private List<TaskAndAction> getTasksAndActions() {
final List<TaskAndAction> tasksAndActionsToProcess = new ArrayList<>(tasksAndActions);
tasksAndActions.clear();
return tasksAndActionsToProcess;
private void addTask(final Task task) {
final Task existingTask = updatingTasks.putIfAbsent(task.id(), task);
if (existingTask != null) {
throw new IllegalStateException((existingTask.isActive() ? "Active" : "Standby") + " task " + task.id() + " already exist, " +
"should not try to add another " + (task.isActive() ? "active" : "standby") + " task with the same id. " + BUG_ERROR_MESSAGE);
}
}
private void addTask(final Task task) {
private void initializeTask(final Task task) {
if (task.state() == Task.State.CREATED) {
try {
task.initializeIfNeeded();
} catch (final StreamsException streamsException) {
addToExceptionsAndFailedTasksThenRemoveFromUpdatingTasks(new ExceptionAndTasks(Collections.singleton(task), streamsException));
return;
}
}
postInitializeTask(task);
}
private void postInitializeTask(final Task task) {
if (isStateless(task)) {
addToRestoredTasks((StreamTask) task);
updatingTasks.remove(task.id());
log.info("Stateless active task " + task.id() + " was added to the restored tasks of the state updater");
} else {
final Task existingTask = updatingTasks.putIfAbsent(task.id(), task);
if (existingTask != null) {
throw new IllegalStateException((existingTask.isActive() ? "Active" : "Standby") + " task " + task.id() + " already exist, " +
"should not try to add another " + (task.isActive() ? "active" : "standby") + " task with the same id. " + BUG_ERROR_MESSAGE);
}
changelogReader.register(task.changelogPartitions(), task.stateManager());
if (task.isActive()) {
log.info("Stateful active task " + task.id() + " was added to the state updater");
@ -486,11 +502,11 @@ public class DefaultStateUpdater implements StateUpdater {
}
private void verifyStateFor(final Task task) {
if (task.isActive() && task.state() != State.RESTORING) {
throw new IllegalStateException("Active task " + task.id() + " is not in state RESTORING. " + BUG_ERROR_MESSAGE);
if (task.isActive() && task.state() != State.RESTORING && task.state() != State.CREATED) {
throw new IllegalStateException("Active task " + task.id() + " is not in state RESTORING or CREATED. " + BUG_ERROR_MESSAGE);
}
if (!task.isActive() && task.state() != State.RUNNING) {
throw new IllegalStateException("Standby task " + task.id() + " is not in state RUNNING. " + BUG_ERROR_MESSAGE);
if (!task.isActive() && task.state() != State.RUNNING && task.state() != State.CREATED) {
throw new IllegalStateException("Standby task " + task.id() + " is not in state RUNNING or CREATED. " + BUG_ERROR_MESSAGE);
}
}

View File

@ -400,8 +400,8 @@ public class TaskManager {
tasks.addActiveTasks(newActiveTasks);
tasks.addStandbyTasks(newStandbyTask);
} else {
tasks.addPendingTaskToInit(newActiveTasks);
tasks.addPendingTaskToInit(newStandbyTask);
Stream.concat(newActiveTasks.stream(), newStandbyTask.stream())
.forEach(stateUpdater::add);
}
}
@ -735,7 +735,6 @@ public class TaskManager {
public boolean checkStateUpdater(final long now,
final java.util.function.Consumer<Set<TopicPartition>> offsetResetter) {
addTasksToStateUpdater();
if (stateUpdater.hasExceptionsAndFailedTasks()) {
handleExceptionsFromStateUpdater();
}
@ -758,7 +757,6 @@ public class TaskManager {
newTask = task.isActive() ?
convertActiveToStandby((StreamTask) task, inputPartitions) :
convertStandbyToActive((StandbyTask) task, inputPartitions);
newTask.initializeIfNeeded();
stateUpdater.add(newTask);
} catch (final RuntimeException e) {
final TaskId taskId = task.id();
@ -818,22 +816,6 @@ public class TaskManager {
}
}
private void addTasksToStateUpdater() {
final Map<TaskId, RuntimeException> taskExceptions = new LinkedHashMap<>();
for (final Task task : tasks.drainPendingTaskToInit()) {
try {
task.initializeIfNeeded();
stateUpdater.add(task);
} catch (final RuntimeException e) {
// need to add task back to the bookkeeping to be handled by the stream thread
tasks.addTask(task);
taskExceptions.put(task.id(), e);
}
}
maybeThrowTaskExceptions(taskExceptions);
}
public void handleExceptionsFromStateUpdater() {
final Map<TaskId, RuntimeException> taskExceptions = new LinkedHashMap<>();

View File

@ -172,18 +172,6 @@ class Tasks implements TasksRegistry {
return pendingUpdateAction != null && pendingUpdateAction.getAction() == action;
}
@Override
public Set<Task> drainPendingTaskToInit() {
final Set<Task> result = new HashSet<>(pendingTasksToInit);
pendingTasksToInit.clear();
return result;
}
@Override
public void addPendingTaskToInit(final Collection<Task> tasks) {
pendingTasksToInit.addAll(tasks);
}
@Override
public void addActiveTasks(final Collection<Task> newTasks) {
if (!newTasks.isEmpty()) {

View File

@ -51,10 +51,6 @@ public interface TasksRegistry {
void addPendingTaskToCloseClean(final TaskId taskId);
Set<Task> drainPendingTaskToInit();
void addPendingTaskToInit(final Collection<Task> tasks);
boolean removePendingActiveTaskToSuspend(final TaskId taskId);
void addPendingActiveTaskToSuspend(final TaskId taskId);

View File

@ -213,18 +213,18 @@ class DefaultStateUpdaterTest {
}
private void shouldThrowIfActiveTaskNotInStateRestoring(final StreamTask task) {
shouldThrowIfTaskNotInGivenState(task, State.RESTORING);
shouldThrowIfTaskNotInGivenState(task, mkSet(State.RESTORING, State.CREATED));
}
@Test
public void shouldThrowIfStandbyTaskNotInStateRunning() {
final StandbyTask task = standbyTask(TASK_0_0, mkSet(TOPIC_PARTITION_B_0)).build();
shouldThrowIfTaskNotInGivenState(task, State.RUNNING);
shouldThrowIfTaskNotInGivenState(task, mkSet(State.RUNNING, State.CREATED));
}
private void shouldThrowIfTaskNotInGivenState(final Task task, final State correctState) {
private void shouldThrowIfTaskNotInGivenState(final Task task, final Set<State> correctStates) {
for (final State state : State.values()) {
if (state != correctState) {
if (!correctStates.contains(state)) {
when(task.state()).thenReturn(state);
assertThrows(IllegalStateException.class, () -> stateUpdater.add(task));
}
@ -371,6 +371,54 @@ class DefaultStateUpdaterTest {
assertTrue(stateUpdater.restoresActiveTasks());
}
@Test
public void shouldInitializeAddedTasksInCreatedState() throws Exception {
final StreamTask task1 = statefulTask(TASK_0_0, mkSet(TOPIC_PARTITION_A_0))
.inState(State.CREATED).build();
final StandbyTask task2 = standbyTask(TASK_0_1, mkSet(TOPIC_PARTITION_A_1))
.inState(State.CREATED).build();
final StandbyTask task3 = standbyTask(TASK_1_0, mkSet(TOPIC_PARTITION_B_0))
.inState(State.RUNNING).build();
final StreamTask task4 = statefulTask(TASK_A_0_0, mkSet(TOPIC_PARTITION_C_0))
.inState(State.RESTORING).build();
stateUpdater.start();
stateUpdater.add(task1);
stateUpdater.add(task2);
stateUpdater.add(task3);
stateUpdater.add(task4);
verifyUpdatingTasks(task1, task2, task3, task4);
verify(task1).initializeIfNeeded();
verify(task2).initializeIfNeeded();
verify(task3, never()).initializeIfNeeded();
verify(task3, never()).initializeIfNeeded();
}
@Test
public void shouldForwardExceptionsFromFailedInitializations() throws Exception {
final StreamTask task1 = statefulTask(TASK_0_0, mkSet(TOPIC_PARTITION_A_0))
.inState(State.CREATED).build();
final StandbyTask task2 = standbyTask(TASK_0_1, mkSet(TOPIC_PARTITION_A_1))
.inState(State.CREATED).build();
final StandbyTask task3 = standbyTask(TASK_1_1, mkSet(TOPIC_PARTITION_B_0))
.inState(State.CREATED).build();
final TaskCorruptedException taskCorruptedException = new TaskCorruptedException(Collections.singleton(TASK_0_0));
doThrow(taskCorruptedException).when(task1).initializeIfNeeded();
final StreamsException streamsException = new StreamsException("Kaboom!");
doThrow(streamsException).when(task2).initializeIfNeeded();
stateUpdater.start();
stateUpdater.add(task1);
stateUpdater.add(task2);
stateUpdater.add(task3);
verifyExceptionsAndFailedTasks(
new ExceptionAndTasks(mkSet(task1), taskCorruptedException),
new ExceptionAndTasks(mkSet(task2), streamsException)
);
verifyUpdatingTasks(task3);
verify(task3).initializeIfNeeded();
}
@Test
public void shouldReturnTrueForRestoreActiveTasksIfTaskUpdating() throws Exception {
final StreamTask task = statefulTask(TASK_0_0, mkSet(TOPIC_PARTITION_A_0, TOPIC_PARTITION_B_0))
@ -1693,7 +1741,10 @@ class DefaultStateUpdaterTest {
&& failedTasks.size() == expectedExceptionAndTasks.size();
},
VERIFICATION_TIMEOUT,
"Did not get all exceptions and failed tasks within the given timeout!"
() -> String.format(
"Did not get all exceptions and failed tasks within the given timeout! Got: %s",
failedTasks
)
);
}
@ -1711,7 +1762,10 @@ class DefaultStateUpdaterTest {
&& failedTasks.size() == expectedFailedTasks.size();
},
VERIFICATION_TIMEOUT,
"Did not get all exceptions and failed tasks within the given timeout!"
() -> String.format(
"Did not get all exceptions and failed tasks within the given timeout! Got: %s",
failedTasks
)
);
}

View File

@ -517,7 +517,7 @@ public class TaskManagerTest {
);
verify(activeTaskCreator, standbyTaskCreator);
Mockito.verify(tasks).addPendingTaskToInit(createdTasks);
Mockito.verify(stateUpdater).add(activeTaskToBeCreated);
}
@Test
@ -540,7 +540,7 @@ public class TaskManagerTest {
);
verify(activeTaskCreator, standbyTaskCreator);
Mockito.verify(tasks).addPendingTaskToInit(createdTasks);
Mockito.verify(stateUpdater).add(standbyTaskToBeCreated);
}
@Test
@ -750,26 +750,6 @@ public class TaskManagerTest {
Mockito.verify(activeTaskToClose).closeClean();
}
@Test
public void shouldAddTasksToStateUpdater() {
final StreamTask task00 = statefulTask(taskId00, taskId00ChangelogPartitions)
.withInputPartitions(taskId00Partitions)
.inState(State.RESTORING).build();
final StandbyTask task01 = standbyTask(taskId01, taskId01ChangelogPartitions)
.withInputPartitions(taskId01Partitions)
.inState(State.RUNNING).build();
final TasksRegistry tasks = mock(TasksRegistry.class);
when(tasks.drainPendingTaskToInit()).thenReturn(mkSet(task00, task01));
taskManager = setUpTaskManager(StreamsConfigUtils.ProcessingMode.AT_LEAST_ONCE, tasks, true);
taskManager.checkStateUpdater(time.milliseconds(), noOpResetter);
Mockito.verify(task00).initializeIfNeeded();
Mockito.verify(task01).initializeIfNeeded();
Mockito.verify(stateUpdater).add(task00);
Mockito.verify(stateUpdater).add(task01);
}
@Test
public void shouldRecycleTasksRemovedFromStateUpdater() {
final StreamTask task00 = statefulTask(taskId00, taskId00ChangelogPartitions)
@ -801,8 +781,6 @@ public class TaskManagerTest {
verify(activeTaskCreator, standbyTaskCreator);
Mockito.verify(task00).suspend();
Mockito.verify(task01).suspend();
Mockito.verify(task00Converted).initializeIfNeeded();
Mockito.verify(task01Converted).initializeIfNeeded();
Mockito.verify(stateUpdater).add(task00Converted);
Mockito.verify(stateUpdater).add(task01Converted);
}
@ -928,8 +906,6 @@ public class TaskManagerTest {
taskManager.checkStateUpdater(time.milliseconds(), noOpResetter -> { });
verify(activeTaskCreator, standbyTaskCreator, topologyBuilder, consumer);
Mockito.verify(convertedTask0).initializeIfNeeded();
Mockito.verify(convertedTask1).initializeIfNeeded();
Mockito.verify(stateUpdater).add(convertedTask0);
Mockito.verify(stateUpdater).add(convertedTask1);
Mockito.verify(taskToClose).closeClean();
@ -1105,7 +1081,6 @@ public class TaskManagerTest {
verify(activeTaskCreator, standbyTaskCreator);
Mockito.verify(statefulTask).suspend();
Mockito.verify(standbyTask).initializeIfNeeded();
Mockito.verify(stateUpdater).add(standbyTask);
}
@ -1129,30 +1104,6 @@ public class TaskManagerTest {
Mockito.verify(statefulTask).closeDirty();
}
@Test
public void shouldHandleExceptionThrownDuringTaskInitInRecycleRestoredTask() {
final StreamTask statefulTask = statefulTask(taskId00, taskId00ChangelogPartitions)
.inState(State.CLOSED)
.withInputPartitions(taskId00Partitions).build();
final StandbyTask standbyTask = standbyTask(taskId00, taskId00ChangelogPartitions)
.inState(State.CREATED)
.withInputPartitions(taskId00Partitions).build();
final TaskManager taskManager = setUpRecycleRestoredTask(statefulTask);
expect(standbyTaskCreator.createStandbyTaskFromActive(statefulTask, statefulTask.inputPartitions()))
.andStubReturn(standbyTask);
doThrow(StreamsException.class).when(standbyTask).initializeIfNeeded();
replay(standbyTaskCreator);
assertThrows(
StreamsException.class,
() -> taskManager.checkStateUpdater(time.milliseconds(), noOpResetter)
);
verify(standbyTaskCreator);
Mockito.verify(stateUpdater, never()).add(any());
Mockito.verify(standbyTask).closeDirty();
}
private TaskManager setUpRecycleRestoredTask(final StreamTask statefulTask) {
final TasksRegistry tasks = mock(TasksRegistry.class);
when(tasks.removePendingTaskToRecycle(statefulTask.id())).thenReturn(taskId00Partitions);
@ -1443,35 +1394,6 @@ public class TaskManagerTest {
assertEquals("Tasks [0_1, 0_0] are corrupted and hence need to be re-initialized", thrown.getMessage());
}
@Test
public void shouldRethrowTaskCorruptedExceptionFromInitialization() {
final StreamTask statefulTask0 = statefulTask(taskId00, taskId00ChangelogPartitions)
.inState(State.CREATED)
.withInputPartitions(taskId00Partitions).build();
final StreamTask statefulTask1 = statefulTask(taskId01, taskId01ChangelogPartitions)
.inState(State.CREATED)
.withInputPartitions(taskId01Partitions).build();
final StreamTask statefulTask2 = statefulTask(taskId02, taskId02ChangelogPartitions)
.inState(State.CREATED)
.withInputPartitions(taskId02Partitions).build();
final TasksRegistry tasks = mock(TasksRegistry.class);
final TaskManager taskManager = setUpTaskManager(ProcessingMode.EXACTLY_ONCE_V2, tasks, true);
when(tasks.drainPendingTaskToInit()).thenReturn(mkSet(statefulTask0, statefulTask1, statefulTask2));
doThrow(new TaskCorruptedException(Collections.singleton(statefulTask0.id))).when(statefulTask0).initializeIfNeeded();
doThrow(new TaskCorruptedException(Collections.singleton(statefulTask1.id))).when(statefulTask1).initializeIfNeeded();
final TaskCorruptedException thrown = assertThrows(
TaskCorruptedException.class,
() -> taskManager.checkStateUpdater(time.milliseconds(), noOpResetter)
);
Mockito.verify(tasks).addTask(statefulTask0);
Mockito.verify(tasks).addTask(statefulTask1);
Mockito.verify(stateUpdater).add(statefulTask2);
assertEquals(mkSet(taskId00, taskId01), thrown.corruptedTasks());
assertEquals("Tasks [0_1, 0_0] are corrupted and hence need to be re-initialized", thrown.getMessage());
}
@Test
public void shouldIdempotentlyUpdateSubscriptionFromActiveAssignment() {
final TopicPartition newTopicPartition = new TopicPartition("topic2", 1);