diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ChangelogReader.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ChangelogReader.java index 38b00232c82..03199d294ca 100644 --- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ChangelogReader.java +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ChangelogReader.java @@ -46,6 +46,15 @@ public interface ChangelogReader extends ChangelogRegister { */ Set completedChangelogs(); + /** + * Returns whether all changelog partitions were completely read. + * + * Since changelog partitions for standby tasks are never completely read, this method will always return + * {@code false} if the changelog reader registered changelog partitions for standby tasks. + * + * @return {@code true} if all changelog partitions were completely read and no standby changelog partitions are read, + * {@code false} otherwise + */ boolean allChangelogsCompleted(); /** diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdater.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdater.java index 0b6558d8acf..55935d3e21a 100644 --- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdater.java +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdater.java @@ -37,12 +37,14 @@ import java.util.Map; import java.util.Queue; import java.util.Set; import java.util.concurrent.BlockingQueue; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.locks.Condition; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; +import java.util.function.Consumer; import java.util.stream.Collectors; public class DefaultStateUpdater implements StateUpdater { @@ -54,13 +56,13 @@ public class DefaultStateUpdater implements StateUpdater { private final ChangelogReader changelogReader; private final AtomicBoolean isRunning = new AtomicBoolean(true); - private final java.util.function.Consumer> offsetResetter; + private final Consumer> offsetResetter; private final Map updatingTasks = new HashMap<>(); private final Logger log; public StateUpdaterThread(final String name, final ChangelogReader changelogReader, - final java.util.function.Consumer> offsetResetter) { + final Consumer> offsetResetter) { super(name); this.changelogReader = changelogReader; this.offsetResetter = offsetResetter; @@ -74,30 +76,44 @@ public class DefaultStateUpdater implements StateUpdater { return updatingTasks.values(); } + public Collection getUpdatingStandbyTasks() { + return updatingTasks.values().stream() + .filter(t -> !t.isActive()) + .map(t -> (StandbyTask) t) + .collect(Collectors.toList()); + } + + public boolean onlyStandbyTasksLeft() { + return !updatingTasks.isEmpty() && updatingTasks.values().stream().allMatch(t -> !t.isActive()); + } + @Override public void run() { + log.info("State updater thread started"); try { while (isRunning.get()) { try { - performActionsOnTasks(); - restoreTasks(); - waitIfAllChangelogsCompletelyRead(); + runOnce(); } catch (final InterruptedException interruptedException) { return; } } } catch (final RuntimeException anyOtherException) { - log.error("An unexpected error occurred within the state updater thread: " + anyOtherException); - final ExceptionAndTasks exceptionAndTasks = new ExceptionAndTasks(new HashSet<>(updatingTasks.values()), anyOtherException); - updatingTasks.clear(); - failedTasks.add(exceptionAndTasks); - isRunning.set(false); + handleRuntimeException(anyOtherException); } finally { clear(); + shutdownGate.countDown(); + log.info("State updater thread shutdown"); } } - private void performActionsOnTasks() throws InterruptedException { + private void runOnce() throws InterruptedException { + performActionsOnTasks(); + restoreTasks(); + waitIfAllChangelogsCompletelyRead(); + } + + private void performActionsOnTasks() { tasksAndActionsLock.lock(); try { for (final TaskAndAction taskAndAction : getTasksAndActions()) { @@ -114,10 +130,8 @@ public class DefaultStateUpdater implements StateUpdater { } } - private void restoreTasks() throws InterruptedException { + private void restoreTasks() { try { - // ToDo: Prioritize restoration of active tasks over standby tasks - // changelogReader.enforceRestoreActive(); changelogReader.restore(updatingTasks); } catch (final TaskCorruptedException taskCorruptedException) { handleTaskCorruptedException(taskCorruptedException); @@ -127,11 +141,20 @@ public class DefaultStateUpdater implements StateUpdater { final Set completedChangelogs = changelogReader.completedChangelogs(); final List activeTasks = updatingTasks.values().stream().filter(Task::isActive).collect(Collectors.toList()); for (final Task task : activeTasks) { - endRestorationIfChangelogsCompletelyRead(task, completedChangelogs); + maybeCompleteRestoration((StreamTask) task, completedChangelogs); } } + private void handleRuntimeException(final RuntimeException runtimeException) { + log.error("An unexpected error occurred within the state updater thread: " + runtimeException); + final ExceptionAndTasks exceptionAndTasks = new ExceptionAndTasks(new HashSet<>(updatingTasks.values()), runtimeException); + updatingTasks.clear(); + failedTasks.add(exceptionAndTasks); + isRunning.set(false); + } + private void handleTaskCorruptedException(final TaskCorruptedException taskCorruptedException) { + log.info("Encountered task corrupted exception: ", taskCorruptedException); final Set corruptedTaskIds = taskCorruptedException.corruptedTasks(); final Set corruptedTasks = new HashSet<>(); for (final TaskId taskId : corruptedTaskIds) { @@ -145,6 +168,7 @@ public class DefaultStateUpdater implements StateUpdater { } private void handleStreamsException(final StreamsException streamsException) { + log.info("Encountered streams exception: ", streamsException); final ExceptionAndTasks exceptionAndTasks; if (streamsException.taskId().isPresent()) { exceptionAndTasks = handleStreamsExceptionWithTask(streamsException); @@ -191,8 +215,8 @@ public class DefaultStateUpdater implements StateUpdater { tasksAndActions.clear(); restoredActiveTasks.clear(); } finally { - tasksAndActionsLock.unlock(); restoredActiveTasksLock.unlock(); + tasksAndActionsLock.unlock(); } changelogReader.clear(); updatingTasks.clear(); @@ -206,9 +230,20 @@ public class DefaultStateUpdater implements StateUpdater { private void addTask(final Task task) { if (isStateless(task)) { + log.debug("Stateless active task " + task.id() + " was added to the state updater"); addTaskToRestoredTasks((StreamTask) task); } else { - updatingTasks.put(task.id(), task); + if (task.isActive()) { + updatingTasks.put(task.id(), task); + log.debug("Stateful active task " + task.id() + " was added to the state updater"); + changelogReader.enforceRestoreActive(); + } else { + updatingTasks.put(task.id(), task); + log.debug("Standby task " + task.id() + " was added to the state updater"); + if (updatingTasks.size() == 1) { + changelogReader.transitToUpdateStandby(); + } + } } } @@ -216,13 +251,17 @@ public class DefaultStateUpdater implements StateUpdater { return task.changelogPartitions().isEmpty() && task.isActive(); } - private void endRestorationIfChangelogsCompletelyRead(final Task task, - final Set restoredChangelogs) { + private void maybeCompleteRestoration(final StreamTask task, + final Set restoredChangelogs) { final Collection taskChangelogPartitions = task.changelogPartitions(); if (restoredChangelogs.containsAll(taskChangelogPartitions)) { task.completeRestoration(offsetResetter); - addTaskToRestoredTasks((StreamTask) task); + log.debug("Stateful active task " + task.id() + " completed restoration"); + addTaskToRestoredTasks(task); updatingTasks.remove(task.id()); + if (onlyStandbyTasksLeft()) { + changelogReader.transitToUpdateStandby(); + } } } @@ -230,6 +269,7 @@ public class DefaultStateUpdater implements StateUpdater { restoredActiveTasksLock.lock(); try { restoredActiveTasks.add(task); + log.debug("Active task " + task.id() + " was added to the restored tasks"); restoredActiveTasksCondition.signalAll(); } finally { restoredActiveTasksLock.unlock(); @@ -253,7 +293,7 @@ public class DefaultStateUpdater implements StateUpdater { private final Time time; private final ChangelogReader changelogReader; - private final java.util.function.Consumer> offsetResetter; + private final Consumer> offsetResetter; private final Queue tasksAndActions = new LinkedList<>(); private final Lock tasksAndActionsLock = new ReentrantLock(); private final Condition tasksAndActionsCondition = tasksAndActionsLock.newCondition(); @@ -261,11 +301,12 @@ public class DefaultStateUpdater implements StateUpdater { private final Lock restoredActiveTasksLock = new ReentrantLock(); private final Condition restoredActiveTasksCondition = restoredActiveTasksLock.newCondition(); private final BlockingQueue failedTasks = new LinkedBlockingQueue<>(); + private CountDownLatch shutdownGate; private StateUpdaterThread stateUpdaterThread = null; public DefaultStateUpdater(final ChangelogReader changelogReader, - final java.util.function.Consumer> offsetResetter, + final Consumer> offsetResetter, final Time time) { this.changelogReader = changelogReader; this.offsetResetter = offsetResetter; @@ -277,6 +318,7 @@ public class DefaultStateUpdater implements StateUpdater { if (stateUpdaterThread == null) { stateUpdaterThread = new StateUpdaterThread("state-updater", changelogReader, offsetResetter); stateUpdaterThread.start(); + shutdownGate = new CountDownLatch(1); } verifyStateFor(task); @@ -294,6 +336,9 @@ public class DefaultStateUpdater implements StateUpdater { if (task.isActive() && task.state() != State.RESTORING) { throw new IllegalStateException("Active task " + task.id() + " is not in state RESTORING. " + BUG_ERROR_MESSAGE); } + if (!task.isActive() && task.state() != State.RUNNING) { + throw new IllegalStateException("Standby task " + task.id() + " is not in state RUNNING. " + BUG_ERROR_MESSAGE); + } } @Override @@ -315,9 +360,8 @@ public class DefaultStateUpdater implements StateUpdater { final boolean elapsed = restoredActiveTasksCondition.await(deadline - now, TimeUnit.MILLISECONDS); now = time.milliseconds(); } - while (!restoredActiveTasks.isEmpty()) { - result.add(restoredActiveTasks.poll()); - } + result.addAll(restoredActiveTasks); + restoredActiveTasks.clear(); } finally { restoredActiveTasksLock.unlock(); } @@ -352,21 +396,44 @@ public class DefaultStateUpdater implements StateUpdater { allTasks.addAll(restoredActiveTasks); return Collections.unmodifiableSet(allTasks); } finally { - tasksAndActionsLock.unlock(); restoredActiveTasksLock.unlock(); + tasksAndActionsLock.unlock(); } } + @Override + public Set getStandbyTasks() { + tasksAndActionsLock.lock(); + try { + final Set standbyTasks = new HashSet<>(); + standbyTasks.addAll(tasksAndActions.stream() + .filter(t -> t.action == Action.ADD) + .filter(t -> !t.task.isActive()) + .map(t -> (StandbyTask) t.task) + .collect(Collectors.toList()) + ); + standbyTasks.addAll(getUpdatingStandbyTasks()); + return Collections.unmodifiableSet(standbyTasks); + } finally { + tasksAndActionsLock.unlock(); + } + } + + public Set getUpdatingStandbyTasks() { + return Collections.unmodifiableSet(new HashSet<>(stateUpdaterThread.getUpdatingStandbyTasks())); + } + @Override public void shutdown(final Duration timeout) { if (stateUpdaterThread != null) { stateUpdaterThread.isRunning.set(false); stateUpdaterThread.interrupt(); try { - stateUpdaterThread.join(timeout.toMillis()); + if (!shutdownGate.await(timeout.toMillis(), TimeUnit.MILLISECONDS)) { + throw new StreamsException("State updater thread did not shutdown within the timeout"); + } stateUpdaterThread = null; - } catch (final InterruptedException e) { - // ignore + } catch (final InterruptedException ignored) { } } } diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StateUpdater.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StateUpdater.java index 9e98e0d2c9e..d2d4ab71ad3 100644 --- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StateUpdater.java +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StateUpdater.java @@ -65,14 +65,24 @@ public interface StateUpdater { /** * Get all tasks (active and standby) that are managed by the state updater. * - * @return list of tasks managed by the state updater + * @return set of tasks managed by the state updater */ Set getAllTasks(); + /** + * Get standby tasks that are managed by the state updater. + * + * @return set of standby tasks managed by the state updater + */ + Set getStandbyTasks(); + /** * Shuts down the state updater. * * @param timeout duration how long to wait until the state updater is shut down + * + * @throws + * org.apache.kafka.streams.errors.StreamsException if the state updater thread cannot shutdown within the timeout */ void shutdown(final Duration timeout); } diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdaterTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdaterTest.java index e94d8b14883..c9fa1abede3 100644 --- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdaterTest.java +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdaterTest.java @@ -25,6 +25,7 @@ import org.apache.kafka.streams.processor.internals.StateUpdater.ExceptionAndTas import org.apache.kafka.streams.processor.internals.Task.State; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Test; +import org.mockito.InOrder; import java.time.Duration; import java.util.ArrayList; @@ -49,7 +50,10 @@ import static org.mockito.ArgumentMatchers.anyMap; import static org.mockito.Mockito.atLeast; import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.inOrder; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.timeout; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -61,9 +65,11 @@ class DefaultStateUpdaterTest { private final static TopicPartition TOPIC_PARTITION_A_0 = new TopicPartition("topicA", 0); private final static TopicPartition TOPIC_PARTITION_B_0 = new TopicPartition("topicB", 0); private final static TopicPartition TOPIC_PARTITION_C_0 = new TopicPartition("topicC", 0); + private final static TopicPartition TOPIC_PARTITION_D_0 = new TopicPartition("topicD", 0); private final static TaskId TASK_0_0 = new TaskId(0, 0); private final static TaskId TASK_0_2 = new TaskId(0, 2); private final static TaskId TASK_1_0 = new TaskId(1, 0); + private final static TaskId TASK_1_1 = new TaskId(1, 1); private final ChangelogReader changelogReader = mock(ChangelogReader.class); private final java.util.function.Consumer> offsetResetter = topicPartitions -> { }; @@ -101,17 +107,31 @@ class DefaultStateUpdaterTest { @Test public void shouldThrowIfStatelessTaskNotInStateRestoring() { - shouldThrowIfTaskNotInStateRestoring(createStatelessTask(TASK_0_0)); + shouldThrowIfActiveTaskNotInStateRestoring(createStatelessTask(TASK_0_0)); } @Test public void shouldThrowIfStatefulTaskNotInStateRestoring() { - shouldThrowIfTaskNotInStateRestoring(createActiveStatefulTask(TASK_0_0, Collections.singletonList(TOPIC_PARTITION_A_0))); + shouldThrowIfActiveTaskNotInStateRestoring(createActiveStatefulTask(TASK_0_0, Collections.singletonList(TOPIC_PARTITION_A_0))); } - private void shouldThrowIfTaskNotInStateRestoring(final StreamTask task) { - when(task.state()).thenReturn(State.CREATED); - assertThrows(IllegalStateException.class, () -> stateUpdater.add(task)); + private void shouldThrowIfActiveTaskNotInStateRestoring(final StreamTask task) { + shouldThrowIfTaskNotInGivenState(task, State.RESTORING); + } + + @Test + public void shouldThrowIfStandbyTaskNotInStateRunning() { + final StandbyTask task = createStandbyTask(TASK_0_0, Collections.singletonList(TOPIC_PARTITION_B_0)); + shouldThrowIfTaskNotInGivenState(task, State.RUNNING); + } + + private void shouldThrowIfTaskNotInGivenState(final Task task, final State correctState) { + for (final State state : State.values()) { + if (state != correctState) { + when(task.state()).thenReturn(state); + assertThrows(IllegalStateException.class, () -> stateUpdater.add(task)); + } + } } @Test @@ -133,18 +153,7 @@ class DefaultStateUpdaterTest { stateUpdater.add(task); } - final Set expectedRestoredTasks = mkSet(tasks); - final Set restoredTasks = new HashSet<>(); - waitForCondition( - () -> { - restoredTasks.addAll(stateUpdater.getRestoredActiveTasks(Duration.ofMillis(CALL_TIMEOUT))); - return restoredTasks.size() == expectedRestoredTasks.size(); - }, - VERIFICATION_TIMEOUT, - "Did not get any restored active task within the given timeout!" - ); - assertTrue(restoredTasks.containsAll(expectedRestoredTasks)); - assertEquals(expectedRestoredTasks.size(), restoredTasks.stream().filter(task -> task.state() == State.RESTORING).count()); + verifyRestoredActiveTasks(tasks); assertTrue(stateUpdater.getAllTasks().isEmpty()); } @@ -163,21 +172,12 @@ class DefaultStateUpdaterTest { stateUpdater.add(task); - final Set expectedRestoredTasks = Collections.singleton(task); - final Set restoredTasks = new HashSet<>(); - waitForCondition( - () -> { - restoredTasks.addAll(stateUpdater.getRestoredActiveTasks(Duration.ofMillis(CALL_TIMEOUT))); - return restoredTasks.size() == expectedRestoredTasks.size(); - }, - VERIFICATION_TIMEOUT, - "Did not get any restored active task within the given timeout!" - ); - assertTrue(restoredTasks.containsAll(expectedRestoredTasks)); - assertEquals(expectedRestoredTasks.size(), restoredTasks.stream().filter(t -> t.state() == State.RESTORING).count()); + verifyRestoredActiveTasks(task); assertTrue(stateUpdater.getAllTasks().isEmpty()); - verify(changelogReader, atLeast(3)).restore(anyMap()); + verify(changelogReader, times(1)).enforceRestoreActive(); + verify(changelogReader, atLeast(1)).restore(anyMap()); verify(task).completeRestoration(offsetResetter); + verify(changelogReader, never()).transitToUpdateStandby(); } @Test @@ -200,48 +200,125 @@ class DefaultStateUpdaterTest { stateUpdater.add(task2); stateUpdater.add(task3); - final Set expectedRestoredTasks = mkSet(task3, task1, task2); - final Set restoredTasks = new HashSet<>(); - waitForCondition( - () -> { - restoredTasks.addAll(stateUpdater.getRestoredActiveTasks(Duration.ofMillis(CALL_TIMEOUT))); - return restoredTasks.size() == expectedRestoredTasks.size(); - }, - VERIFICATION_TIMEOUT, - "Did not get any restored active task within the given timeout!" - ); - assertTrue(restoredTasks.containsAll(expectedRestoredTasks)); - assertEquals(expectedRestoredTasks.size(), restoredTasks.stream().filter(t -> t.state() == State.RESTORING).count()); + verifyRestoredActiveTasks(task3, task1, task2); assertTrue(stateUpdater.getAllTasks().isEmpty()); + verify(changelogReader, times(3)).enforceRestoreActive(); verify(changelogReader, atLeast(4)).restore(anyMap()); verify(task3).completeRestoration(offsetResetter); verify(task1).completeRestoration(offsetResetter); verify(task2).completeRestoration(offsetResetter); + verify(changelogReader, never()).transitToUpdateStandby(); + } + + @Test + public void shouldUpdateSingleStandbyTask() throws Exception { + final StandbyTask task = createStandbyTaskInStateRunning( + TASK_0_0, + Arrays.asList(TOPIC_PARTITION_A_0, TOPIC_PARTITION_B_0) + ); + shouldUpdateStandbyTasks(task); + } + + @Test + public void shouldUpdateMultipleStandbyTasks() throws Exception { + final StandbyTask task1 = createStandbyTaskInStateRunning(TASK_0_0, Collections.singletonList(TOPIC_PARTITION_A_0)); + final StandbyTask task2 = createStandbyTaskInStateRunning(TASK_0_2, Collections.singletonList(TOPIC_PARTITION_B_0)); + final StandbyTask task3 = createStandbyTaskInStateRunning(TASK_1_0, Collections.singletonList(TOPIC_PARTITION_C_0)); + shouldUpdateStandbyTasks(task1, task2, task3); + } + + private void shouldUpdateStandbyTasks(final StandbyTask... tasks) throws Exception { + when(changelogReader.completedChangelogs()).thenReturn(Collections.emptySet()); + when(changelogReader.allChangelogsCompleted()).thenReturn(false); + + for (final StandbyTask task : tasks) { + stateUpdater.add(task); + } + + verifyUpdatingStandbyTasks(tasks); + verify(changelogReader, times(1)).transitToUpdateStandby(); + verify(changelogReader, timeout(VERIFICATION_TIMEOUT).atLeast(1)).restore(anyMap()); + verify(changelogReader, never()).enforceRestoreActive(); + } + + @Test + public void shouldRestoreActiveStatefulTasksAndUpdateStandbyTasks() throws Exception { + final StreamTask task1 = createActiveStatefulTaskInStateRestoring(TASK_0_0, Collections.singletonList(TOPIC_PARTITION_A_0)); + final StreamTask task2 = createActiveStatefulTaskInStateRestoring(TASK_0_2, Collections.singletonList(TOPIC_PARTITION_B_0)); + final StandbyTask task3 = createStandbyTaskInStateRunning(TASK_1_0, Collections.singletonList(TOPIC_PARTITION_C_0)); + final StandbyTask task4 = createStandbyTaskInStateRunning(TASK_1_1, Collections.singletonList(TOPIC_PARTITION_D_0)); + when(changelogReader.completedChangelogs()) + .thenReturn(Collections.emptySet()) + .thenReturn(mkSet(TOPIC_PARTITION_A_0)) + .thenReturn(mkSet(TOPIC_PARTITION_A_0, TOPIC_PARTITION_B_0)); + when(changelogReader.allChangelogsCompleted()) + .thenReturn(false); + + stateUpdater.add(task1); + stateUpdater.add(task2); + stateUpdater.add(task3); + stateUpdater.add(task4); + + verifyRestoredActiveTasks(task2, task1); + verify(task1).completeRestoration(offsetResetter); + verify(task2).completeRestoration(offsetResetter); + verify(changelogReader, atLeast(3)).restore(anyMap()); + verifyUpdatingStandbyTasks(task4, task3); + final InOrder orderVerifier = inOrder(changelogReader, task1, task2); + orderVerifier.verify(changelogReader, times(2)).enforceRestoreActive(); + orderVerifier.verify(changelogReader, times(1)).transitToUpdateStandby(); + } + + @Test + public void shouldRestoreActiveStatefulTaskThenUpdateStandbyTaskAndAgainRestoreActiveStatefulTask() throws Exception { + final StreamTask task1 = createActiveStatefulTaskInStateRestoring(TASK_0_0, Collections.singletonList(TOPIC_PARTITION_A_0)); + final StandbyTask task2 = createStandbyTaskInStateRunning(TASK_1_0, Collections.singletonList(TOPIC_PARTITION_C_0)); + final StreamTask task3 = createActiveStatefulTaskInStateRestoring(TASK_0_2, Collections.singletonList(TOPIC_PARTITION_B_0)); + when(changelogReader.completedChangelogs()) + .thenReturn(Collections.emptySet()) + .thenReturn(mkSet(TOPIC_PARTITION_A_0)) + .thenReturn(mkSet(TOPIC_PARTITION_B_0)); + when(changelogReader.allChangelogsCompleted()) + .thenReturn(false); + + stateUpdater.add(task1); + stateUpdater.add(task2); + + verifyRestoredActiveTasks(task1); + verify(task1).completeRestoration(offsetResetter); + verifyUpdatingStandbyTasks(task2); + final InOrder orderVerifier = inOrder(changelogReader); + orderVerifier.verify(changelogReader, times(1)).enforceRestoreActive(); + orderVerifier.verify(changelogReader, times(1)).transitToUpdateStandby(); + + stateUpdater.add(task3); + + verifyRestoredActiveTasks(task3); + verify(task3).completeRestoration(offsetResetter); + orderVerifier.verify(changelogReader, times(1)).enforceRestoreActive(); + orderVerifier.verify(changelogReader, times(1)).transitToUpdateStandby(); } @Test public void shouldAddFailedTasksToQueueWhenRestoreThrowsStreamsExceptionWithoutTask() throws Exception { final StreamTask task1 = createActiveStatefulTaskInStateRestoring(TASK_0_0, Collections.singletonList(TOPIC_PARTITION_A_0)); - final StreamTask task2 = createActiveStatefulTaskInStateRestoring(TASK_0_2, Collections.singletonList(TOPIC_PARTITION_B_0)); - final StreamTask task3 = createActiveStatefulTaskInStateRestoring(TASK_1_0, Collections.singletonList(TOPIC_PARTITION_C_0)); + final StandbyTask task2 = createStandbyTaskInStateRunning(TASK_0_2, Collections.singletonList(TOPIC_PARTITION_B_0)); final String expectedMessage = "The Streams were crossed!"; final StreamsException expectedStreamsException = new StreamsException(expectedMessage); final Map updatingTasks = mkMap( mkEntry(task1.id(), task1), - mkEntry(task2.id(), task2), - mkEntry(task3.id(), task3) + mkEntry(task2.id(), task2) ); doNothing().doThrow(expectedStreamsException).doNothing().when(changelogReader).restore(updatingTasks); stateUpdater.add(task1); stateUpdater.add(task2); - stateUpdater.add(task3); final List failedTasks = getFailedTasks(1); assertEquals(1, failedTasks.size()); final ExceptionAndTasks actualFailedTasks = failedTasks.get(0); - assertEquals(3, actualFailedTasks.tasks.size()); - assertTrue(actualFailedTasks.tasks.containsAll(Arrays.asList(task1, task2, task3))); + assertEquals(2, actualFailedTasks.tasks.size()); + assertTrue(actualFailedTasks.tasks.containsAll(Arrays.asList(task1, task2))); assertTrue(actualFailedTasks.exception instanceof StreamsException); final StreamsException actualException = (StreamsException) actualFailedTasks.exception; assertFalse(actualException.taskId().isPresent()); @@ -253,7 +330,7 @@ class DefaultStateUpdaterTest { public void shouldAddFailedTasksToQueueWhenRestoreThrowsStreamsExceptionWithTask() throws Exception { final StreamTask task1 = createActiveStatefulTaskInStateRestoring(TASK_0_0, Collections.singletonList(TOPIC_PARTITION_A_0)); final StreamTask task2 = createActiveStatefulTaskInStateRestoring(TASK_0_2, Collections.singletonList(TOPIC_PARTITION_B_0)); - final StreamTask task3 = createActiveStatefulTaskInStateRestoring(TASK_1_0, Collections.singletonList(TOPIC_PARTITION_C_0)); + final StandbyTask task3 = createStandbyTaskInStateRunning(TASK_1_0, Collections.singletonList(TOPIC_PARTITION_C_0)); final String expectedMessage = "The Streams were crossed!"; final StreamsException expectedStreamsException1 = new StreamsException(expectedMessage, task1.id()); final StreamsException expectedStreamsException2 = new StreamsException(expectedMessage, task3.id()); @@ -302,7 +379,7 @@ class DefaultStateUpdaterTest { @Test public void shouldAddFailedTasksToQueueWhenRestoreThrowsTaskCorruptedException() throws Exception { final StreamTask task1 = createActiveStatefulTaskInStateRestoring(TASK_0_0, Collections.singletonList(TOPIC_PARTITION_A_0)); - final StreamTask task2 = createActiveStatefulTaskInStateRestoring(TASK_0_2, Collections.singletonList(TOPIC_PARTITION_B_0)); + final StandbyTask task2 = createStandbyTaskInStateRunning(TASK_0_2, Collections.singletonList(TOPIC_PARTITION_B_0)); final StreamTask task3 = createActiveStatefulTaskInStateRestoring(TASK_1_0, Collections.singletonList(TOPIC_PARTITION_C_0)); final Set expectedTaskIds = mkSet(task1.id(), task2.id()); final TaskCorruptedException taskCorruptedException = new TaskCorruptedException(expectedTaskIds); @@ -334,7 +411,7 @@ class DefaultStateUpdaterTest { @Test public void shouldAddFailedTasksToQueueWhenUncaughtExceptionIsThrown() throws Exception { final StreamTask task1 = createActiveStatefulTaskInStateRestoring(TASK_0_0, Collections.singletonList(TOPIC_PARTITION_A_0)); - final StreamTask task2 = createActiveStatefulTaskInStateRestoring(TASK_0_2, Collections.singletonList(TOPIC_PARTITION_B_0)); + final StandbyTask task2 = createStandbyTaskInStateRunning(TASK_0_2, Collections.singletonList(TOPIC_PARTITION_B_0)); final IllegalStateException illegalStateException = new IllegalStateException("Nobody expects the Spanish inquisition!"); final Map updatingTasks = mkMap( mkEntry(task1.id(), task1), @@ -356,6 +433,36 @@ class DefaultStateUpdaterTest { assertTrue(stateUpdater.getAllTasks().isEmpty()); } + private void verifyRestoredActiveTasks(final StreamTask... tasks) throws Exception { + final Set expectedRestoredTasks = mkSet(tasks); + final Set restoredTasks = new HashSet<>(); + waitForCondition( + () -> { + restoredTasks.addAll(stateUpdater.getRestoredActiveTasks(Duration.ofMillis(CALL_TIMEOUT))); + return restoredTasks.size() == expectedRestoredTasks.size(); + }, + VERIFICATION_TIMEOUT, + "Did not get any restored active task within the given timeout!" + ); + assertTrue(restoredTasks.containsAll(expectedRestoredTasks)); + assertEquals(expectedRestoredTasks.size(), restoredTasks.stream().filter(task -> task.state() == State.RESTORING).count()); + } + + private void verifyUpdatingStandbyTasks(final StandbyTask... tasks) throws Exception { + final Set expectedStandbyTasks = mkSet(tasks); + final Set standbyTasks = new HashSet<>(); + waitForCondition( + () -> { + standbyTasks.addAll(stateUpdater.getUpdatingStandbyTasks()); + return standbyTasks.size() == expectedStandbyTasks.size(); + }, + VERIFICATION_TIMEOUT, + "Did not see all standby task within the given timeout!" + ); + assertTrue(standbyTasks.containsAll(expectedStandbyTasks)); + assertEquals(expectedStandbyTasks.size(), standbyTasks.stream().filter(t -> t.state() == State.RUNNING).count()); + } + private List getFailedTasks(final int expectedCount) throws Exception { final List failedTasks = new ArrayList<>(); waitForCondition( @@ -399,6 +506,21 @@ class DefaultStateUpdaterTest { return task; } + private StandbyTask createStandbyTaskInStateRunning(final TaskId taskId, + final Collection changelogPartitions) { + final StandbyTask task = createStandbyTask(taskId, changelogPartitions); + when(task.state()).thenReturn(State.RUNNING); + return task; + } + + private StandbyTask createStandbyTask(final TaskId taskId, + final Collection changelogPartitions) { + final StandbyTask task = mock(StandbyTask.class); + setupStatefulTask(task, taskId, changelogPartitions); + when(task.isActive()).thenReturn(false); + return task; + } + private void setupStatefulTask(final Task task, final TaskId taskId, final Collection changelogPartitions) {