KAFKA-10199: Implement adding standby tasks to the state updater (#12200)

This PR adds support for adding standby tasks to the default implementation of the state updater.

Reviewers: Guozhang Wang <wangguoz@gmail.com>
This commit is contained in:
Bruno Cadonna 2022-05-25 01:59:14 +02:00 committed by GitHub
parent 9dc332f5ca
commit 286bae4251
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 290 additions and 82 deletions

View File

@ -46,6 +46,15 @@ public interface ChangelogReader extends ChangelogRegister {
*/
Set<TopicPartition> completedChangelogs();
/**
* Returns whether all changelog partitions were completely read.
*
* Since changelog partitions for standby tasks are never completely read, this method will always return
* {@code false} if the changelog reader registered changelog partitions for standby tasks.
*
* @return {@code true} if all changelog partitions were completely read and no standby changelog partitions are read,
* {@code false} otherwise
*/
boolean allChangelogsCompleted();
/**

View File

@ -37,12 +37,14 @@ import java.util.Map;
import java.util.Queue;
import java.util.Set;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.Consumer;
import java.util.stream.Collectors;
public class DefaultStateUpdater implements StateUpdater {
@ -54,13 +56,13 @@ public class DefaultStateUpdater implements StateUpdater {
private final ChangelogReader changelogReader;
private final AtomicBoolean isRunning = new AtomicBoolean(true);
private final java.util.function.Consumer<Set<TopicPartition>> offsetResetter;
private final Consumer<Set<TopicPartition>> offsetResetter;
private final Map<TaskId, Task> updatingTasks = new HashMap<>();
private final Logger log;
public StateUpdaterThread(final String name,
final ChangelogReader changelogReader,
final java.util.function.Consumer<Set<TopicPartition>> offsetResetter) {
final Consumer<Set<TopicPartition>> offsetResetter) {
super(name);
this.changelogReader = changelogReader;
this.offsetResetter = offsetResetter;
@ -74,30 +76,44 @@ public class DefaultStateUpdater implements StateUpdater {
return updatingTasks.values();
}
/**
 * Returns the standby tasks among the tasks this thread is currently updating.
 *
 * @return the non-active entries of {@code updatingTasks}, cast to {@code StandbyTask}
 */
public Collection<StandbyTask> getUpdatingStandbyTasks() {
    final Collection<Task> allUpdatingTasks = updatingTasks.values();
    return allUpdatingTasks.stream()
        .filter(task -> !task.isActive())
        .map(StandbyTask.class::cast)
        .collect(Collectors.toList());
}
/**
 * Tells whether the updating tasks consist exclusively of standby tasks.
 *
 * @return {@code true} if there is at least one updating task and none of the updating tasks is active,
 *         {@code false} otherwise (including when there are no updating tasks at all)
 */
public boolean onlyStandbyTasksLeft() {
    if (updatingTasks.isEmpty()) {
        return false;
    }
    return updatingTasks.values().stream().noneMatch(Task::isActive);
}
@Override
public void run() {
log.info("State updater thread started");
try {
while (isRunning.get()) {
try {
performActionsOnTasks();
restoreTasks();
waitIfAllChangelogsCompletelyRead();
runOnce();
} catch (final InterruptedException interruptedException) {
return;
}
}
} catch (final RuntimeException anyOtherException) {
log.error("An unexpected error occurred within the state updater thread: " + anyOtherException);
final ExceptionAndTasks exceptionAndTasks = new ExceptionAndTasks(new HashSet<>(updatingTasks.values()), anyOtherException);
updatingTasks.clear();
failedTasks.add(exceptionAndTasks);
isRunning.set(false);
handleRuntimeException(anyOtherException);
} finally {
clear();
shutdownGate.countDown();
log.info("State updater thread shutdown");
}
}
private void performActionsOnTasks() throws InterruptedException {
/**
* Executes one iteration of the state updater loop: applies the pending actions on tasks, restores the
* updating tasks, and then calls {@code waitIfAllChangelogsCompletelyRead()}.
*
* @throws InterruptedException presumably propagated from {@code waitIfAllChangelogsCompletelyRead()}, since the
*                              other two steps do not declare it -- TODO confirm; {@code run()} leaves its loop
*                              when this exception is thrown
*/
private void runOnce() throws InterruptedException {
performActionsOnTasks();
restoreTasks();
waitIfAllChangelogsCompletelyRead();
}
private void performActionsOnTasks() {
tasksAndActionsLock.lock();
try {
for (final TaskAndAction taskAndAction : getTasksAndActions()) {
@ -114,10 +130,8 @@ public class DefaultStateUpdater implements StateUpdater {
}
}
private void restoreTasks() throws InterruptedException {
private void restoreTasks() {
try {
// ToDo: Prioritize restoration of active tasks over standby tasks
// changelogReader.enforceRestoreActive();
changelogReader.restore(updatingTasks);
} catch (final TaskCorruptedException taskCorruptedException) {
handleTaskCorruptedException(taskCorruptedException);
@ -127,11 +141,20 @@ public class DefaultStateUpdater implements StateUpdater {
final Set<TopicPartition> completedChangelogs = changelogReader.completedChangelogs();
final List<Task> activeTasks = updatingTasks.values().stream().filter(Task::isActive).collect(Collectors.toList());
for (final Task task : activeTasks) {
endRestorationIfChangelogsCompletelyRead(task, completedChangelogs);
maybeCompleteRestoration((StreamTask) task, completedChangelogs);
}
}
/**
 * Handles a runtime exception that was not handled inside the state updater loop.
 *
 * All tasks that are currently being updated are recorded together with the exception in
 * {@code failedTasks}, the updating tasks are cleared, and the thread is flagged to stop running.
 *
 * @param runtimeException the exception that escaped the state updater loop
 */
private void handleRuntimeException(final RuntimeException runtimeException) {
    // Pass the exception as the trailing argument (as the other handlers in this class do) so that
    // the logger records the full stack trace instead of only the exception's toString().
    log.error("An unexpected error occurred within the state updater thread: ", runtimeException);
    // Snapshot the updating tasks before clearing them so the failure record keeps its own copy
    final ExceptionAndTasks exceptionAndTasks = new ExceptionAndTasks(new HashSet<>(updatingTasks.values()), runtimeException);
    updatingTasks.clear();
    failedTasks.add(exceptionAndTasks);
    isRunning.set(false);
}
private void handleTaskCorruptedException(final TaskCorruptedException taskCorruptedException) {
log.info("Encountered task corrupted exception: ", taskCorruptedException);
final Set<TaskId> corruptedTaskIds = taskCorruptedException.corruptedTasks();
final Set<Task> corruptedTasks = new HashSet<>();
for (final TaskId taskId : corruptedTaskIds) {
@ -145,6 +168,7 @@ public class DefaultStateUpdater implements StateUpdater {
}
private void handleStreamsException(final StreamsException streamsException) {
log.info("Encountered streams exception: ", streamsException);
final ExceptionAndTasks exceptionAndTasks;
if (streamsException.taskId().isPresent()) {
exceptionAndTasks = handleStreamsExceptionWithTask(streamsException);
@ -191,8 +215,8 @@ public class DefaultStateUpdater implements StateUpdater {
tasksAndActions.clear();
restoredActiveTasks.clear();
} finally {
tasksAndActionsLock.unlock();
restoredActiveTasksLock.unlock();
tasksAndActionsLock.unlock();
}
changelogReader.clear();
updatingTasks.clear();
@ -206,9 +230,20 @@ public class DefaultStateUpdater implements StateUpdater {
private void addTask(final Task task) {
if (isStateless(task)) {
log.debug("Stateless active task " + task.id() + " was added to the state updater");
addTaskToRestoredTasks((StreamTask) task);
} else {
updatingTasks.put(task.id(), task);
if (task.isActive()) {
updatingTasks.put(task.id(), task);
log.debug("Stateful active task " + task.id() + " was added to the state updater");
changelogReader.enforceRestoreActive();
} else {
updatingTasks.put(task.id(), task);
log.debug("Standby task " + task.id() + " was added to the state updater");
if (updatingTasks.size() == 1) {
changelogReader.transitToUpdateStandby();
}
}
}
}
@ -216,13 +251,17 @@ public class DefaultStateUpdater implements StateUpdater {
return task.changelogPartitions().isEmpty() && task.isActive();
}
private void endRestorationIfChangelogsCompletelyRead(final Task task,
final Set<TopicPartition> restoredChangelogs) {
private void maybeCompleteRestoration(final StreamTask task,
final Set<TopicPartition> restoredChangelogs) {
final Collection<TopicPartition> taskChangelogPartitions = task.changelogPartitions();
if (restoredChangelogs.containsAll(taskChangelogPartitions)) {
task.completeRestoration(offsetResetter);
addTaskToRestoredTasks((StreamTask) task);
log.debug("Stateful active task " + task.id() + " completed restoration");
addTaskToRestoredTasks(task);
updatingTasks.remove(task.id());
if (onlyStandbyTasksLeft()) {
changelogReader.transitToUpdateStandby();
}
}
}
@ -230,6 +269,7 @@ public class DefaultStateUpdater implements StateUpdater {
restoredActiveTasksLock.lock();
try {
restoredActiveTasks.add(task);
log.debug("Active task " + task.id() + " was added to the restored tasks");
restoredActiveTasksCondition.signalAll();
} finally {
restoredActiveTasksLock.unlock();
@ -253,7 +293,7 @@ public class DefaultStateUpdater implements StateUpdater {
private final Time time;
private final ChangelogReader changelogReader;
private final java.util.function.Consumer<Set<TopicPartition>> offsetResetter;
private final Consumer<Set<TopicPartition>> offsetResetter;
private final Queue<TaskAndAction> tasksAndActions = new LinkedList<>();
private final Lock tasksAndActionsLock = new ReentrantLock();
private final Condition tasksAndActionsCondition = tasksAndActionsLock.newCondition();
@ -261,11 +301,12 @@ public class DefaultStateUpdater implements StateUpdater {
private final Lock restoredActiveTasksLock = new ReentrantLock();
private final Condition restoredActiveTasksCondition = restoredActiveTasksLock.newCondition();
private final BlockingQueue<ExceptionAndTasks> failedTasks = new LinkedBlockingQueue<>();
private CountDownLatch shutdownGate;
private StateUpdaterThread stateUpdaterThread = null;
public DefaultStateUpdater(final ChangelogReader changelogReader,
final java.util.function.Consumer<Set<TopicPartition>> offsetResetter,
final Consumer<Set<TopicPartition>> offsetResetter,
final Time time) {
this.changelogReader = changelogReader;
this.offsetResetter = offsetResetter;
@ -277,6 +318,7 @@ public class DefaultStateUpdater implements StateUpdater {
if (stateUpdaterThread == null) {
stateUpdaterThread = new StateUpdaterThread("state-updater", changelogReader, offsetResetter);
stateUpdaterThread.start();
shutdownGate = new CountDownLatch(1);
}
verifyStateFor(task);
@ -294,6 +336,9 @@ public class DefaultStateUpdater implements StateUpdater {
if (task.isActive() && task.state() != State.RESTORING) {
throw new IllegalStateException("Active task " + task.id() + " is not in state RESTORING. " + BUG_ERROR_MESSAGE);
}
if (!task.isActive() && task.state() != State.RUNNING) {
throw new IllegalStateException("Standby task " + task.id() + " is not in state RUNNING. " + BUG_ERROR_MESSAGE);
}
}
@Override
@ -315,9 +360,8 @@ public class DefaultStateUpdater implements StateUpdater {
final boolean elapsed = restoredActiveTasksCondition.await(deadline - now, TimeUnit.MILLISECONDS);
now = time.milliseconds();
}
while (!restoredActiveTasks.isEmpty()) {
result.add(restoredActiveTasks.poll());
}
result.addAll(restoredActiveTasks);
restoredActiveTasks.clear();
} finally {
restoredActiveTasksLock.unlock();
}
@ -352,21 +396,44 @@ public class DefaultStateUpdater implements StateUpdater {
allTasks.addAll(restoredActiveTasks);
return Collections.unmodifiableSet(allTasks);
} finally {
tasksAndActionsLock.unlock();
restoredActiveTasksLock.unlock();
tasksAndActionsLock.unlock();
}
}
/**
 * Collects the standby tasks known to the state updater: those still queued with an {@code ADD} action
 * plus those reported by {@code getUpdatingStandbyTasks()}. The queue is read while holding
 * {@code tasksAndActionsLock}.
 *
 * @return an unmodifiable set of the collected standby tasks
 */
@Override
public Set<StandbyTask> getStandbyTasks() {
    tasksAndActionsLock.lock();
    try {
        final Set<StandbyTask> result = new HashSet<>();
        for (final TaskAndAction queued : tasksAndActions) {
            if (queued.action == Action.ADD && !queued.task.isActive()) {
                result.add((StandbyTask) queued.task);
            }
        }
        result.addAll(getUpdatingStandbyTasks());
        return Collections.unmodifiableSet(result);
    } finally {
        tasksAndActionsLock.unlock();
    }
}
/**
 * Returns a snapshot of the standby tasks that the state updater thread is currently updating.
 *
 * @return an unmodifiable copy of the updating standby tasks, or an empty set if no state updater
 *         thread exists
 */
public Set<StandbyTask> getUpdatingStandbyTasks() {
    // The thread field starts out as null, is only assigned once the thread is created, and is reset
    // to null on shutdown; without this guard the call would fail with a NullPointerException.
    if (stateUpdaterThread == null) {
        return Collections.emptySet();
    }
    return Collections.unmodifiableSet(new HashSet<>(stateUpdaterThread.getUpdatingStandbyTasks()));
}
@Override
public void shutdown(final Duration timeout) {
if (stateUpdaterThread != null) {
stateUpdaterThread.isRunning.set(false);
stateUpdaterThread.interrupt();
try {
stateUpdaterThread.join(timeout.toMillis());
if (!shutdownGate.await(timeout.toMillis(), TimeUnit.MILLISECONDS)) {
throw new StreamsException("State updater thread did not shutdown within the timeout");
}
stateUpdaterThread = null;
} catch (final InterruptedException e) {
// ignore
} catch (final InterruptedException ignored) {
}
}
}

View File

@ -65,14 +65,24 @@ public interface StateUpdater {
/**
* Get all tasks (active and standby) that are managed by the state updater.
*
* @return list of tasks managed by the state updater
* @return set of tasks managed by the state updater
*/
Set<Task> getAllTasks();
/**
* Get standby tasks that are managed by the state updater.
*
* @return set of standby tasks managed by the state updater
*/
Set<StandbyTask> getStandbyTasks();
/**
* Shuts down the state updater.
*
* @param timeout duration how long to wait until the state updater is shut down
*
* @throws
* org.apache.kafka.streams.errors.StreamsException if the state updater thread cannot shutdown within the timeout
*/
void shutdown(final Duration timeout);
}

View File

@ -25,6 +25,7 @@ import org.apache.kafka.streams.processor.internals.StateUpdater.ExceptionAndTas
import org.apache.kafka.streams.processor.internals.Task.State;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test;
import org.mockito.InOrder;
import java.time.Duration;
import java.util.ArrayList;
@ -49,7 +50,10 @@ import static org.mockito.ArgumentMatchers.anyMap;
import static org.mockito.Mockito.atLeast;
import static org.mockito.Mockito.doNothing;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.inOrder;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.timeout;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
@ -61,9 +65,11 @@ class DefaultStateUpdaterTest {
private final static TopicPartition TOPIC_PARTITION_A_0 = new TopicPartition("topicA", 0);
private final static TopicPartition TOPIC_PARTITION_B_0 = new TopicPartition("topicB", 0);
private final static TopicPartition TOPIC_PARTITION_C_0 = new TopicPartition("topicC", 0);
private final static TopicPartition TOPIC_PARTITION_D_0 = new TopicPartition("topicD", 0);
private final static TaskId TASK_0_0 = new TaskId(0, 0);
private final static TaskId TASK_0_2 = new TaskId(0, 2);
private final static TaskId TASK_1_0 = new TaskId(1, 0);
private final static TaskId TASK_1_1 = new TaskId(1, 1);
private final ChangelogReader changelogReader = mock(ChangelogReader.class);
private final java.util.function.Consumer<Set<TopicPartition>> offsetResetter = topicPartitions -> { };
@ -101,17 +107,31 @@ class DefaultStateUpdaterTest {
@Test
public void shouldThrowIfStatelessTaskNotInStateRestoring() {
shouldThrowIfTaskNotInStateRestoring(createStatelessTask(TASK_0_0));
shouldThrowIfActiveTaskNotInStateRestoring(createStatelessTask(TASK_0_0));
}
@Test
public void shouldThrowIfStatefulTaskNotInStateRestoring() {
shouldThrowIfTaskNotInStateRestoring(createActiveStatefulTask(TASK_0_0, Collections.singletonList(TOPIC_PARTITION_A_0)));
shouldThrowIfActiveTaskNotInStateRestoring(createActiveStatefulTask(TASK_0_0, Collections.singletonList(TOPIC_PARTITION_A_0)));
}
private void shouldThrowIfTaskNotInStateRestoring(final StreamTask task) {
when(task.state()).thenReturn(State.CREATED);
assertThrows(IllegalStateException.class, () -> stateUpdater.add(task));
// Helper: checks that adding the given active task is rejected in every state other than RESTORING.
private void shouldThrowIfActiveTaskNotInStateRestoring(final StreamTask task) {
shouldThrowIfTaskNotInGivenState(task, State.RESTORING);
}
// Verifies that adding a standby task is rejected in every state other than RUNNING.
@Test
public void shouldThrowIfStandbyTaskNotInStateRunning() {
final StandbyTask task = createStandbyTask(TASK_0_0, Collections.singletonList(TOPIC_PARTITION_B_0));
shouldThrowIfTaskNotInGivenState(task, State.RUNNING);
}
/**
 * Helper: puts the task into every state except {@code correctState}, one after the other, and asserts
 * that adding it to the state updater throws an {@link IllegalStateException} each time.
 *
 * @param task         the mocked task whose {@code state()} is stubbed
 * @param correctState the only state that is not exercised
 */
private void shouldThrowIfTaskNotInGivenState(final Task task, final State correctState) {
    for (final State candidateState : State.values()) {
        if (candidateState == correctState) {
            continue;
        }
        when(task.state()).thenReturn(candidateState);
        assertThrows(IllegalStateException.class, () -> stateUpdater.add(task));
    }
}
@Test
@ -133,18 +153,7 @@ class DefaultStateUpdaterTest {
stateUpdater.add(task);
}
final Set<StreamTask> expectedRestoredTasks = mkSet(tasks);
final Set<StreamTask> restoredTasks = new HashSet<>();
waitForCondition(
() -> {
restoredTasks.addAll(stateUpdater.getRestoredActiveTasks(Duration.ofMillis(CALL_TIMEOUT)));
return restoredTasks.size() == expectedRestoredTasks.size();
},
VERIFICATION_TIMEOUT,
"Did not get any restored active task within the given timeout!"
);
assertTrue(restoredTasks.containsAll(expectedRestoredTasks));
assertEquals(expectedRestoredTasks.size(), restoredTasks.stream().filter(task -> task.state() == State.RESTORING).count());
verifyRestoredActiveTasks(tasks);
assertTrue(stateUpdater.getAllTasks().isEmpty());
}
@ -163,21 +172,12 @@ class DefaultStateUpdaterTest {
stateUpdater.add(task);
final Set<StreamTask> expectedRestoredTasks = Collections.singleton(task);
final Set<StreamTask> restoredTasks = new HashSet<>();
waitForCondition(
() -> {
restoredTasks.addAll(stateUpdater.getRestoredActiveTasks(Duration.ofMillis(CALL_TIMEOUT)));
return restoredTasks.size() == expectedRestoredTasks.size();
},
VERIFICATION_TIMEOUT,
"Did not get any restored active task within the given timeout!"
);
assertTrue(restoredTasks.containsAll(expectedRestoredTasks));
assertEquals(expectedRestoredTasks.size(), restoredTasks.stream().filter(t -> t.state() == State.RESTORING).count());
verifyRestoredActiveTasks(task);
assertTrue(stateUpdater.getAllTasks().isEmpty());
verify(changelogReader, atLeast(3)).restore(anyMap());
verify(changelogReader, times(1)).enforceRestoreActive();
verify(changelogReader, atLeast(1)).restore(anyMap());
verify(task).completeRestoration(offsetResetter);
verify(changelogReader, never()).transitToUpdateStandby();
}
@Test
@ -200,48 +200,125 @@ class DefaultStateUpdaterTest {
stateUpdater.add(task2);
stateUpdater.add(task3);
final Set<StreamTask> expectedRestoredTasks = mkSet(task3, task1, task2);
final Set<StreamTask> restoredTasks = new HashSet<>();
waitForCondition(
() -> {
restoredTasks.addAll(stateUpdater.getRestoredActiveTasks(Duration.ofMillis(CALL_TIMEOUT)));
return restoredTasks.size() == expectedRestoredTasks.size();
},
VERIFICATION_TIMEOUT,
"Did not get any restored active task within the given timeout!"
);
assertTrue(restoredTasks.containsAll(expectedRestoredTasks));
assertEquals(expectedRestoredTasks.size(), restoredTasks.stream().filter(t -> t.state() == State.RESTORING).count());
verifyRestoredActiveTasks(task3, task1, task2);
assertTrue(stateUpdater.getAllTasks().isEmpty());
verify(changelogReader, times(3)).enforceRestoreActive();
verify(changelogReader, atLeast(4)).restore(anyMap());
verify(task3).completeRestoration(offsetResetter);
verify(task1).completeRestoration(offsetResetter);
verify(task2).completeRestoration(offsetResetter);
verify(changelogReader, never()).transitToUpdateStandby();
}
// Runs the shared standby-update scenario with a single standby task that has two changelog partitions.
@Test
public void shouldUpdateSingleStandbyTask() throws Exception {
final StandbyTask task = createStandbyTaskInStateRunning(
TASK_0_0,
Arrays.asList(TOPIC_PARTITION_A_0, TOPIC_PARTITION_B_0)
);
shouldUpdateStandbyTasks(task);
}
// Runs the shared standby-update scenario with three standby tasks, each with its own changelog partition.
@Test
public void shouldUpdateMultipleStandbyTasks() throws Exception {
final StandbyTask task1 = createStandbyTaskInStateRunning(TASK_0_0, Collections.singletonList(TOPIC_PARTITION_A_0));
final StandbyTask task2 = createStandbyTaskInStateRunning(TASK_0_2, Collections.singletonList(TOPIC_PARTITION_B_0));
final StandbyTask task3 = createStandbyTaskInStateRunning(TASK_1_0, Collections.singletonList(TOPIC_PARTITION_C_0));
shouldUpdateStandbyTasks(task1, task2, task3);
}
// Shared scenario: adds only standby tasks and checks how the changelog reader is driven.
private void shouldUpdateStandbyTasks(final StandbyTask... tasks) throws Exception {
// The mocked changelog reader never reports any changelog as completed
when(changelogReader.completedChangelogs()).thenReturn(Collections.emptySet());
when(changelogReader.allChangelogsCompleted()).thenReturn(false);
for (final StandbyTask task : tasks) {
stateUpdater.add(task);
}
// All added standby tasks must show up as updating
verifyUpdatingStandbyTasks(tasks);
// The reader must be switched to standby updating exactly once, regardless of the number of tasks
verify(changelogReader, times(1)).transitToUpdateStandby();
verify(changelogReader, timeout(VERIFICATION_TIMEOUT).atLeast(1)).restore(anyMap());
// No active task was added, so active restoration must never be enforced
verify(changelogReader, never()).enforceRestoreActive();
}
// Verifies that with two active stateful tasks and two standby tasks added, the active tasks complete
// restoration, the standby tasks remain updating, and the changelog reader is switched to standby
// updating only after active restoration was enforced.
@Test
public void shouldRestoreActiveStatefulTasksAndUpdateStandbyTasks() throws Exception {
final StreamTask task1 = createActiveStatefulTaskInStateRestoring(TASK_0_0, Collections.singletonList(TOPIC_PARTITION_A_0));
final StreamTask task2 = createActiveStatefulTaskInStateRestoring(TASK_0_2, Collections.singletonList(TOPIC_PARTITION_B_0));
final StandbyTask task3 = createStandbyTaskInStateRunning(TASK_1_0, Collections.singletonList(TOPIC_PARTITION_C_0));
final StandbyTask task4 = createStandbyTaskInStateRunning(TASK_1_1, Collections.singletonList(TOPIC_PARTITION_D_0));
// Consecutive calls report no completed changelog, then topicA, then topicA and topicB
when(changelogReader.completedChangelogs())
.thenReturn(Collections.emptySet())
.thenReturn(mkSet(TOPIC_PARTITION_A_0))
.thenReturn(mkSet(TOPIC_PARTITION_A_0, TOPIC_PARTITION_B_0));
when(changelogReader.allChangelogsCompleted())
.thenReturn(false);
stateUpdater.add(task1);
stateUpdater.add(task2);
stateUpdater.add(task3);
stateUpdater.add(task4);
// Both active tasks must be handed out as restored
verifyRestoredActiveTasks(task2, task1);
verify(task1).completeRestoration(offsetResetter);
verify(task2).completeRestoration(offsetResetter);
verify(changelogReader, atLeast(3)).restore(anyMap());
// Both standby tasks must still be updating
verifyUpdatingStandbyTasks(task4, task3);
// Active restoration is enforced twice before the single transition to standby updating
final InOrder orderVerifier = inOrder(changelogReader, task1, task2);
orderVerifier.verify(changelogReader, times(2)).enforceRestoreActive();
orderVerifier.verify(changelogReader, times(1)).transitToUpdateStandby();
}
// Verifies the alternation between restoring and standby updating: an active task is restored while a
// standby task is updating, then a second active task is added later and restored as well. Each phase
// must enforce active restoration once and afterwards transit to standby updating once.
@Test
public void shouldRestoreActiveStatefulTaskThenUpdateStandbyTaskAndAgainRestoreActiveStatefulTask() throws Exception {
final StreamTask task1 = createActiveStatefulTaskInStateRestoring(TASK_0_0, Collections.singletonList(TOPIC_PARTITION_A_0));
final StandbyTask task2 = createStandbyTaskInStateRunning(TASK_1_0, Collections.singletonList(TOPIC_PARTITION_C_0));
final StreamTask task3 = createActiveStatefulTaskInStateRestoring(TASK_0_2, Collections.singletonList(TOPIC_PARTITION_B_0));
// Consecutive calls report no completed changelog, then topicA, then topicB
when(changelogReader.completedChangelogs())
.thenReturn(Collections.emptySet())
.thenReturn(mkSet(TOPIC_PARTITION_A_0))
.thenReturn(mkSet(TOPIC_PARTITION_B_0));
when(changelogReader.allChangelogsCompleted())
.thenReturn(false);
// First phase: one active task and one standby task
stateUpdater.add(task1);
stateUpdater.add(task2);
verifyRestoredActiveTasks(task1);
verify(task1).completeRestoration(offsetResetter);
verifyUpdatingStandbyTasks(task2);
final InOrder orderVerifier = inOrder(changelogReader);
orderVerifier.verify(changelogReader, times(1)).enforceRestoreActive();
orderVerifier.verify(changelogReader, times(1)).transitToUpdateStandby();
// Second phase: another active task is added while the standby task is still updating
stateUpdater.add(task3);
verifyRestoredActiveTasks(task3);
verify(task3).completeRestoration(offsetResetter);
orderVerifier.verify(changelogReader, times(1)).enforceRestoreActive();
orderVerifier.verify(changelogReader, times(1)).transitToUpdateStandby();
}
@Test
public void shouldAddFailedTasksToQueueWhenRestoreThrowsStreamsExceptionWithoutTask() throws Exception {
final StreamTask task1 = createActiveStatefulTaskInStateRestoring(TASK_0_0, Collections.singletonList(TOPIC_PARTITION_A_0));
final StreamTask task2 = createActiveStatefulTaskInStateRestoring(TASK_0_2, Collections.singletonList(TOPIC_PARTITION_B_0));
final StreamTask task3 = createActiveStatefulTaskInStateRestoring(TASK_1_0, Collections.singletonList(TOPIC_PARTITION_C_0));
final StandbyTask task2 = createStandbyTaskInStateRunning(TASK_0_2, Collections.singletonList(TOPIC_PARTITION_B_0));
final String expectedMessage = "The Streams were crossed!";
final StreamsException expectedStreamsException = new StreamsException(expectedMessage);
final Map<TaskId, Task> updatingTasks = mkMap(
mkEntry(task1.id(), task1),
mkEntry(task2.id(), task2),
mkEntry(task3.id(), task3)
mkEntry(task2.id(), task2)
);
doNothing().doThrow(expectedStreamsException).doNothing().when(changelogReader).restore(updatingTasks);
stateUpdater.add(task1);
stateUpdater.add(task2);
stateUpdater.add(task3);
final List<ExceptionAndTasks> failedTasks = getFailedTasks(1);
assertEquals(1, failedTasks.size());
final ExceptionAndTasks actualFailedTasks = failedTasks.get(0);
assertEquals(3, actualFailedTasks.tasks.size());
assertTrue(actualFailedTasks.tasks.containsAll(Arrays.asList(task1, task2, task3)));
assertEquals(2, actualFailedTasks.tasks.size());
assertTrue(actualFailedTasks.tasks.containsAll(Arrays.asList(task1, task2)));
assertTrue(actualFailedTasks.exception instanceof StreamsException);
final StreamsException actualException = (StreamsException) actualFailedTasks.exception;
assertFalse(actualException.taskId().isPresent());
@ -253,7 +330,7 @@ class DefaultStateUpdaterTest {
public void shouldAddFailedTasksToQueueWhenRestoreThrowsStreamsExceptionWithTask() throws Exception {
final StreamTask task1 = createActiveStatefulTaskInStateRestoring(TASK_0_0, Collections.singletonList(TOPIC_PARTITION_A_0));
final StreamTask task2 = createActiveStatefulTaskInStateRestoring(TASK_0_2, Collections.singletonList(TOPIC_PARTITION_B_0));
final StreamTask task3 = createActiveStatefulTaskInStateRestoring(TASK_1_0, Collections.singletonList(TOPIC_PARTITION_C_0));
final StandbyTask task3 = createStandbyTaskInStateRunning(TASK_1_0, Collections.singletonList(TOPIC_PARTITION_C_0));
final String expectedMessage = "The Streams were crossed!";
final StreamsException expectedStreamsException1 = new StreamsException(expectedMessage, task1.id());
final StreamsException expectedStreamsException2 = new StreamsException(expectedMessage, task3.id());
@ -302,7 +379,7 @@ class DefaultStateUpdaterTest {
@Test
public void shouldAddFailedTasksToQueueWhenRestoreThrowsTaskCorruptedException() throws Exception {
final StreamTask task1 = createActiveStatefulTaskInStateRestoring(TASK_0_0, Collections.singletonList(TOPIC_PARTITION_A_0));
final StreamTask task2 = createActiveStatefulTaskInStateRestoring(TASK_0_2, Collections.singletonList(TOPIC_PARTITION_B_0));
final StandbyTask task2 = createStandbyTaskInStateRunning(TASK_0_2, Collections.singletonList(TOPIC_PARTITION_B_0));
final StreamTask task3 = createActiveStatefulTaskInStateRestoring(TASK_1_0, Collections.singletonList(TOPIC_PARTITION_C_0));
final Set<TaskId> expectedTaskIds = mkSet(task1.id(), task2.id());
final TaskCorruptedException taskCorruptedException = new TaskCorruptedException(expectedTaskIds);
@ -334,7 +411,7 @@ class DefaultStateUpdaterTest {
@Test
public void shouldAddFailedTasksToQueueWhenUncaughtExceptionIsThrown() throws Exception {
final StreamTask task1 = createActiveStatefulTaskInStateRestoring(TASK_0_0, Collections.singletonList(TOPIC_PARTITION_A_0));
final StreamTask task2 = createActiveStatefulTaskInStateRestoring(TASK_0_2, Collections.singletonList(TOPIC_PARTITION_B_0));
final StandbyTask task2 = createStandbyTaskInStateRunning(TASK_0_2, Collections.singletonList(TOPIC_PARTITION_B_0));
final IllegalStateException illegalStateException = new IllegalStateException("Nobody expects the Spanish inquisition!");
final Map<TaskId, Task> updatingTasks = mkMap(
mkEntry(task1.id(), task1),
@ -356,6 +433,36 @@ class DefaultStateUpdaterTest {
assertTrue(stateUpdater.getAllTasks().isEmpty());
}
/**
 * Helper: polls the state updater for restored active tasks until as many tasks as expected were
 * collected, then asserts that all expected tasks were collected and that they report state RESTORING.
 *
 * @param tasks the active tasks expected to be returned as restored
 * @throws Exception if waiting for the condition fails
 */
private void verifyRestoredActiveTasks(final StreamTask... tasks) throws Exception {
    final Set<StreamTask> expected = mkSet(tasks);
    final Set<StreamTask> collected = new HashSet<>();
    waitForCondition(
        () -> {
            collected.addAll(stateUpdater.getRestoredActiveTasks(Duration.ofMillis(CALL_TIMEOUT)));
            return expected.size() == collected.size();
        },
        VERIFICATION_TIMEOUT,
        "Did not get any restored active task within the given timeout!"
    );
    assertTrue(collected.containsAll(expected));
    final long tasksInStateRestoring = collected.stream()
        .filter(restoredTask -> restoredTask.state() == State.RESTORING)
        .count();
    assertEquals(expected.size(), tasksInStateRestoring);
}
/**
 * Helper: polls the state updater for updating standby tasks until as many tasks as expected were
 * collected, then asserts that all expected tasks were collected and that they report state RUNNING.
 *
 * @param tasks the standby tasks expected to be updating
 * @throws Exception if waiting for the condition fails
 */
private void verifyUpdatingStandbyTasks(final StandbyTask... tasks) throws Exception {
    final Set<StandbyTask> expected = mkSet(tasks);
    final Set<StandbyTask> collected = new HashSet<>();
    waitForCondition(
        () -> {
            collected.addAll(stateUpdater.getUpdatingStandbyTasks());
            return expected.size() == collected.size();
        },
        VERIFICATION_TIMEOUT,
        "Did not see all standby task within the given timeout!"
    );
    assertTrue(collected.containsAll(expected));
    final long tasksInStateRunning = collected.stream()
        .filter(standbyTask -> standbyTask.state() == State.RUNNING)
        .count();
    assertEquals(expected.size(), tasksInStateRunning);
}
private List<ExceptionAndTasks> getFailedTasks(final int expectedCount) throws Exception {
final List<ExceptionAndTasks> failedTasks = new ArrayList<>();
waitForCondition(
@ -399,6 +506,21 @@ class DefaultStateUpdaterTest {
return task;
}
// Helper: creates a mocked standby task via createStandbyTask() whose state() is stubbed to return RUNNING.
private StandbyTask createStandbyTaskInStateRunning(final TaskId taskId,
final Collection<TopicPartition> changelogPartitions) {
final StandbyTask task = createStandbyTask(taskId, changelogPartitions);
when(task.state()).thenReturn(State.RUNNING);
return task;
}
// Helper: creates a mocked standby task, applies the common stateful-task setup for the given ID and
// changelog partitions, and stubs isActive() to return false. The task state is left unstubbed.
private StandbyTask createStandbyTask(final TaskId taskId,
final Collection<TopicPartition> changelogPartitions) {
final StandbyTask task = mock(StandbyTask.class);
setupStatefulTask(task, taskId, changelogPartitions);
when(task.isActive()).thenReturn(false);
return task;
}
private void setupStatefulTask(final Task task,
final TaskId taskId,
final Collection<TopicPartition> changelogPartitions) {