You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@kafka.apache.org by gu...@apache.org on 2022/05/05 23:01:00 UTC
[kafka] branch trunk updated: KAFKA-10199: Implement adding active tasks to the state updater (#12128)
This is an automated email from the ASF dual-hosted git repository.
guozhang pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git
The following commit(s) were added to refs/heads/trunk by this push:
new ced5989ff6 KAFKA-10199: Implement adding active tasks to the state updater (#12128)
ced5989ff6 is described below
commit ced5989ff69f8a5e76518fdeb39f41ab20b2574f
Author: Bruno Cadonna <ca...@apache.org>
AuthorDate: Fri May 6 01:00:35 2022 +0200
KAFKA-10199: Implement adding active tasks to the state updater (#12128)
This PR adds the default implementation of the state updater. For now, the implementation supports only adding active tasks to the state updater.
Reviewers: Guozhang Wang <wa...@gmail.com>
---
.../processor/internals/ChangelogReader.java | 2 +
.../processor/internals/DefaultStateUpdater.java | 373 +++++++++++++++++++
.../streams/processor/internals/StateUpdater.java | 19 +-
.../processor/internals/StoreChangelogReader.java | 3 +-
.../internals/DefaultStateUpdaterTest.java | 408 +++++++++++++++++++++
.../processor/internals/MockChangelogReader.java | 5 +
6 files changed, 804 insertions(+), 6 deletions(-)
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ChangelogReader.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ChangelogReader.java
index 9c62dd182e..38b00232c8 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ChangelogReader.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ChangelogReader.java
@@ -46,6 +46,8 @@ public interface ChangelogReader extends ChangelogRegister {
*/
Set<TopicPartition> completedChangelogs();
+ boolean allChangelogsCompleted();
+
/**
* Clear all partitions
*/
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdater.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdater.java
new file mode 100644
index 0000000000..0b6558d8ac
--- /dev/null
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdater.java
@@ -0,0 +1,373 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.processor.internals;
+
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.utils.LogContext;
+import org.apache.kafka.common.utils.Time;
+import org.apache.kafka.streams.errors.StreamsException;
+import org.apache.kafka.streams.errors.TaskCorruptedException;
+import org.apache.kafka.streams.processor.TaskId;
+import org.apache.kafka.streams.processor.internals.Task.State;
+import org.slf4j.Logger;
+
+import java.time.Duration;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Queue;
+import java.util.Set;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.locks.Condition;
+import java.util.concurrent.locks.Lock;
+import java.util.concurrent.locks.ReentrantLock;
+import java.util.stream.Collectors;
+
+public class DefaultStateUpdater implements StateUpdater {
+
+ private final static String BUG_ERROR_MESSAGE = "This indicates a bug. " +
+ "Please report at https://issues.apache.org/jira/projects/KAFKA/issues or to the dev-mailing list (https://kafka.apache.org/contact).";
+
+ private class StateUpdaterThread extends Thread {
+
+ private final ChangelogReader changelogReader;
+ private final AtomicBoolean isRunning = new AtomicBoolean(true);
+ private final java.util.function.Consumer<Set<TopicPartition>> offsetResetter;
+ private final Map<TaskId, Task> updatingTasks = new HashMap<>();
+ private final Logger log;
+
+ public StateUpdaterThread(final String name,
+ final ChangelogReader changelogReader,
+ final java.util.function.Consumer<Set<TopicPartition>> offsetResetter) {
+ super(name);
+ this.changelogReader = changelogReader;
+ this.offsetResetter = offsetResetter;
+
+ final String logPrefix = String.format("%s ", name);
+ final LogContext logContext = new LogContext(logPrefix);
+ log = logContext.logger(DefaultStateUpdater.class);
+ }
+
+ public Collection<Task> getAllUpdatingTasks() {
+ return updatingTasks.values();
+ }
+
+ @Override
+ public void run() {
+ try {
+ while (isRunning.get()) {
+ try {
+ performActionsOnTasks();
+ restoreTasks();
+ waitIfAllChangelogsCompletelyRead();
+ } catch (final InterruptedException interruptedException) {
+ return;
+ }
+ }
+ } catch (final RuntimeException anyOtherException) {
+ log.error("An unexpected error occurred within the state updater thread: " + anyOtherException);
+ final ExceptionAndTasks exceptionAndTasks = new ExceptionAndTasks(new HashSet<>(updatingTasks.values()), anyOtherException);
+ updatingTasks.clear();
+ failedTasks.add(exceptionAndTasks);
+ isRunning.set(false);
+ } finally {
+ clear();
+ }
+ }
+
+ private void performActionsOnTasks() throws InterruptedException {
+ tasksAndActionsLock.lock();
+ try {
+ for (final TaskAndAction taskAndAction : getTasksAndActions()) {
+ final Task task = taskAndAction.task;
+ final Action action = taskAndAction.action;
+ switch (action) {
+ case ADD:
+ addTask(task);
+ break;
+ }
+ }
+ } finally {
+ tasksAndActionsLock.unlock();
+ }
+ }
+
+ private void restoreTasks() throws InterruptedException {
+ try {
+ // ToDo: Prioritize restoration of active tasks over standby tasks
+ // changelogReader.enforceRestoreActive();
+ changelogReader.restore(updatingTasks);
+ } catch (final TaskCorruptedException taskCorruptedException) {
+ handleTaskCorruptedException(taskCorruptedException);
+ } catch (final StreamsException streamsException) {
+ handleStreamsException(streamsException);
+ }
+ final Set<TopicPartition> completedChangelogs = changelogReader.completedChangelogs();
+ final List<Task> activeTasks = updatingTasks.values().stream().filter(Task::isActive).collect(Collectors.toList());
+ for (final Task task : activeTasks) {
+ endRestorationIfChangelogsCompletelyRead(task, completedChangelogs);
+ }
+ }
+
+ private void handleTaskCorruptedException(final TaskCorruptedException taskCorruptedException) {
+ final Set<TaskId> corruptedTaskIds = taskCorruptedException.corruptedTasks();
+ final Set<Task> corruptedTasks = new HashSet<>();
+ for (final TaskId taskId : corruptedTaskIds) {
+ final Task corruptedTask = updatingTasks.remove(taskId);
+ if (corruptedTask == null) {
+ throw new IllegalStateException("Task " + taskId + " is corrupted but is not updating. " + BUG_ERROR_MESSAGE);
+ }
+ corruptedTasks.add(corruptedTask);
+ }
+ failedTasks.add(new ExceptionAndTasks(corruptedTasks, taskCorruptedException));
+ }
+
+ private void handleStreamsException(final StreamsException streamsException) {
+ final ExceptionAndTasks exceptionAndTasks;
+ if (streamsException.taskId().isPresent()) {
+ exceptionAndTasks = handleStreamsExceptionWithTask(streamsException);
+ } else {
+ exceptionAndTasks = handleStreamsExceptionWithoutTask(streamsException);
+ }
+ failedTasks.add(exceptionAndTasks);
+ }
+
+ private ExceptionAndTasks handleStreamsExceptionWithTask(final StreamsException streamsException) {
+ final TaskId failedTaskId = streamsException.taskId().get();
+ if (!updatingTasks.containsKey(failedTaskId)) {
+ throw new IllegalStateException("Task " + failedTaskId + " failed but is not updating. " + BUG_ERROR_MESSAGE);
+ }
+ final Set<Task> failedTask = new HashSet<>();
+ failedTask.add(updatingTasks.get(failedTaskId));
+ updatingTasks.remove(failedTaskId);
+ return new ExceptionAndTasks(failedTask, streamsException);
+ }
+
+ private ExceptionAndTasks handleStreamsExceptionWithoutTask(final StreamsException streamsException) {
+ final ExceptionAndTasks exceptionAndTasks = new ExceptionAndTasks(new HashSet<>(updatingTasks.values()), streamsException);
+ updatingTasks.clear();
+ return exceptionAndTasks;
+ }
+
+ private void waitIfAllChangelogsCompletelyRead() throws InterruptedException {
+ if (isRunning.get() && changelogReader.allChangelogsCompleted()) {
+ tasksAndActionsLock.lock();
+ try {
+ while (tasksAndActions.isEmpty()) {
+ tasksAndActionsCondition.await();
+ }
+ } finally {
+ tasksAndActionsLock.unlock();
+ }
+ }
+ }
+
+ private void clear() {
+ tasksAndActionsLock.lock();
+ restoredActiveTasksLock.lock();
+ try {
+ tasksAndActions.clear();
+ restoredActiveTasks.clear();
+ } finally {
+ tasksAndActionsLock.unlock();
+ restoredActiveTasksLock.unlock();
+ }
+ changelogReader.clear();
+ updatingTasks.clear();
+ }
+
+ private List<TaskAndAction> getTasksAndActions() {
+ final List<TaskAndAction> tasksAndActionsToProcess = new ArrayList<>(tasksAndActions);
+ tasksAndActions.clear();
+ return tasksAndActionsToProcess;
+ }
+
+ private void addTask(final Task task) {
+ if (isStateless(task)) {
+ addTaskToRestoredTasks((StreamTask) task);
+ } else {
+ updatingTasks.put(task.id(), task);
+ }
+ }
+
+ private boolean isStateless(final Task task) {
+ return task.changelogPartitions().isEmpty() && task.isActive();
+ }
+
+ private void endRestorationIfChangelogsCompletelyRead(final Task task,
+ final Set<TopicPartition> restoredChangelogs) {
+ final Collection<TopicPartition> taskChangelogPartitions = task.changelogPartitions();
+ if (restoredChangelogs.containsAll(taskChangelogPartitions)) {
+ task.completeRestoration(offsetResetter);
+ addTaskToRestoredTasks((StreamTask) task);
+ updatingTasks.remove(task.id());
+ }
+ }
+
+ private void addTaskToRestoredTasks(final StreamTask task) {
+ restoredActiveTasksLock.lock();
+ try {
+ restoredActiveTasks.add(task);
+ restoredActiveTasksCondition.signalAll();
+ } finally {
+ restoredActiveTasksLock.unlock();
+ }
+ }
+ }
+
+ enum Action {
+ ADD
+ }
+
+ private static class TaskAndAction {
+ public final Task task;
+ public final Action action;
+
+ public TaskAndAction(final Task task, final Action action) {
+ this.task = task;
+ this.action = action;
+ }
+ }
+
+ private final Time time;
+ private final ChangelogReader changelogReader;
+ private final java.util.function.Consumer<Set<TopicPartition>> offsetResetter;
+ private final Queue<TaskAndAction> tasksAndActions = new LinkedList<>();
+ private final Lock tasksAndActionsLock = new ReentrantLock();
+ private final Condition tasksAndActionsCondition = tasksAndActionsLock.newCondition();
+ private final Queue<StreamTask> restoredActiveTasks = new LinkedList<>();
+ private final Lock restoredActiveTasksLock = new ReentrantLock();
+ private final Condition restoredActiveTasksCondition = restoredActiveTasksLock.newCondition();
+ private final BlockingQueue<ExceptionAndTasks> failedTasks = new LinkedBlockingQueue<>();
+
+ private StateUpdaterThread stateUpdaterThread = null;
+
+ public DefaultStateUpdater(final ChangelogReader changelogReader,
+ final java.util.function.Consumer<Set<TopicPartition>> offsetResetter,
+ final Time time) {
+ this.changelogReader = changelogReader;
+ this.offsetResetter = offsetResetter;
+ this.time = time;
+ }
+
+ @Override
+ public void add(final Task task) {
+ if (stateUpdaterThread == null) {
+ stateUpdaterThread = new StateUpdaterThread("state-updater", changelogReader, offsetResetter);
+ stateUpdaterThread.start();
+ }
+
+ verifyStateFor(task);
+
+ tasksAndActionsLock.lock();
+ try {
+ tasksAndActions.add(new TaskAndAction(task, Action.ADD));
+ tasksAndActionsCondition.signalAll();
+ } finally {
+ tasksAndActionsLock.unlock();
+ }
+ }
+
+ private void verifyStateFor(final Task task) {
+ if (task.isActive() && task.state() != State.RESTORING) {
+ throw new IllegalStateException("Active task " + task.id() + " is not in state RESTORING. " + BUG_ERROR_MESSAGE);
+ }
+ }
+
+ @Override
+ public void remove(final Task task) {
+ }
+
+ @Override
+ public Set<StreamTask> getRestoredActiveTasks(final Duration timeout) {
+ final long timeoutMs = timeout.toMillis();
+ final long startTime = time.milliseconds();
+ final long deadline = startTime + timeoutMs;
+ long now = startTime;
+ final Set<StreamTask> result = new HashSet<>();
+ try {
+ while (now <= deadline && result.isEmpty()) {
+ restoredActiveTasksLock.lock();
+ try {
+ while (restoredActiveTasks.isEmpty() && now <= deadline) {
+ final boolean elapsed = restoredActiveTasksCondition.await(deadline - now, TimeUnit.MILLISECONDS);
+ now = time.milliseconds();
+ }
+ while (!restoredActiveTasks.isEmpty()) {
+ result.add(restoredActiveTasks.poll());
+ }
+ } finally {
+ restoredActiveTasksLock.unlock();
+ }
+ now = time.milliseconds();
+ }
+ return result;
+ } catch (final InterruptedException e) {
+ // ignore
+ }
+ return result;
+ }
+
+ @Override
+ public List<ExceptionAndTasks> getFailedTasksAndExceptions() {
+ final List<ExceptionAndTasks> result = new ArrayList<>();
+ failedTasks.drainTo(result);
+ return result;
+ }
+
+ @Override
+ public Set<Task> getAllTasks() {
+ tasksAndActionsLock.lock();
+ restoredActiveTasksLock.lock();
+ try {
+ final Set<Task> allTasks = new HashSet<>();
+ allTasks.addAll(tasksAndActions.stream()
+ .filter(t -> t.action == Action.ADD)
+ .map(t -> t.task)
+ .collect(Collectors.toList())
+ );
+ allTasks.addAll(stateUpdaterThread.getAllUpdatingTasks());
+ allTasks.addAll(restoredActiveTasks);
+ return Collections.unmodifiableSet(allTasks);
+ } finally {
+ tasksAndActionsLock.unlock();
+ restoredActiveTasksLock.unlock();
+ }
+ }
+
+ @Override
+ public void shutdown(final Duration timeout) {
+ if (stateUpdaterThread != null) {
+ stateUpdaterThread.isRunning.set(false);
+ stateUpdaterThread.interrupt();
+ try {
+ stateUpdaterThread.join(timeout.toMillis());
+ stateUpdaterThread = null;
+ } catch (final InterruptedException e) {
+ // ignore
+ }
+ }
+ }
+}
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StateUpdater.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StateUpdater.java
index 8965abfbe9..9e98e0d2c9 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StateUpdater.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StateUpdater.java
@@ -22,6 +22,16 @@ import java.util.Set;
public interface StateUpdater {
+ /**
+ * Pairs the tasks that failed during restoration with the exception that caused the failure.
+ * NOTE(review): {@code tasks} is exposed as a mutable public field — presumably callers treat
+ * it as read-only; consider a defensive copy (verify against callers before changing).
+ */
+ class ExceptionAndTasks {
+ public final Set<Task> tasks;
+ public final RuntimeException exception;
+
+ public ExceptionAndTasks(final Set<Task> tasks, final RuntimeException exception) {
+ this.tasks = tasks;
+ this.exception = exception;
+ }
+ }
+
/**
* Adds a task (active or standby) to the state updater.
*
@@ -41,17 +51,16 @@ public interface StateUpdater {
*
* @param timeout duration how long the calling thread should wait for restored active tasks
*
- * @return list of active tasks with up-to-date states
+ * @return set of active tasks with up-to-date states
*/
Set<StreamTask> getRestoredActiveTasks(final Duration timeout);
/**
- * Gets a list of exceptions thrown during restoration.
+ * Gets failed tasks and the corresponding exceptions
*
- * @return exceptions
+ * @return list of failed tasks and the corresponding exceptions
*/
- List<RuntimeException> getExceptions();
-
+ List<ExceptionAndTasks> getFailedTasksAndExceptions();
/**
* Get all tasks (active and standby) that are managed by the state updater.
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StoreChangelogReader.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StoreChangelogReader.java
index fdf027f2be..756bf11b0a 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StoreChangelogReader.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StoreChangelogReader.java
@@ -394,7 +394,8 @@ public class StoreChangelogReader implements ChangelogReader {
.collect(Collectors.toSet());
}
- private boolean allChangelogsCompleted() {
+ @Override
+ public boolean allChangelogsCompleted() {
return changelogs.values().stream()
.allMatch(metadata -> metadata.changelogState == ChangelogState.COMPLETED);
}
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdaterTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdaterTest.java
new file mode 100644
index 0000000000..e94d8b1488
--- /dev/null
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdaterTest.java
@@ -0,0 +1,408 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.processor.internals;
+
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.utils.Time;
+import org.apache.kafka.streams.errors.StreamsException;
+import org.apache.kafka.streams.errors.TaskCorruptedException;
+import org.apache.kafka.streams.processor.TaskId;
+import org.apache.kafka.streams.processor.internals.StateUpdater.ExceptionAndTasks;
+import org.apache.kafka.streams.processor.internals.Task.State;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.Test;
+
+import java.time.Duration;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+import static org.apache.kafka.common.utils.Utils.mkEntry;
+import static org.apache.kafka.common.utils.Utils.mkMap;
+import static org.apache.kafka.common.utils.Utils.mkSet;
+import static org.apache.kafka.test.TestUtils.waitForCondition;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.mockito.ArgumentMatchers.anyMap;
+import static org.mockito.Mockito.atLeast;
+import static org.mockito.Mockito.doNothing;
+import static org.mockito.Mockito.doThrow;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+class DefaultStateUpdaterTest {
+
+ private final static long CALL_TIMEOUT = 1000;
+ private final static long VERIFICATION_TIMEOUT = 15000;
+ private final static TopicPartition TOPIC_PARTITION_A_0 = new TopicPartition("topicA", 0);
+ private final static TopicPartition TOPIC_PARTITION_B_0 = new TopicPartition("topicB", 0);
+ private final static TopicPartition TOPIC_PARTITION_C_0 = new TopicPartition("topicC", 0);
+ private final static TaskId TASK_0_0 = new TaskId(0, 0);
+ private final static TaskId TASK_0_2 = new TaskId(0, 2);
+ private final static TaskId TASK_1_0 = new TaskId(1, 0);
+
+ private final ChangelogReader changelogReader = mock(ChangelogReader.class);
+ private final java.util.function.Consumer<Set<TopicPartition>> offsetResetter = topicPartitions -> { };
+ private final DefaultStateUpdater stateUpdater = new DefaultStateUpdater(changelogReader, offsetResetter, Time.SYSTEM);
+
+ // Stop the state updater thread after each test so no threads leak between tests.
+ @AfterEach
+ public void tearDown() {
+ stateUpdater.shutdown(Duration.ofMinutes(1));
+ }
+
+ // Shutting down the state updater must clear the changelog reader.
+ @Test
+ public void shouldShutdownStateUpdater() {
+ final StreamTask task = createStatelessTaskInStateRestoring(TASK_0_0);
+ stateUpdater.add(task);
+
+ stateUpdater.shutdown(Duration.ofMinutes(1));
+
+ verify(changelogReader).clear();
+ }
+
+ // Adding a task after a shutdown starts a fresh state updater thread, so a second
+ // shutdown clears the changelog reader a second time.
+ @Test
+ public void shouldShutdownStateUpdaterAndRestart() {
+ final StreamTask task1 = createStatelessTaskInStateRestoring(TASK_0_0);
+ stateUpdater.add(task1);
+
+ stateUpdater.shutdown(Duration.ofMinutes(1));
+
+ final StreamTask task2 = createStatelessTaskInStateRestoring(TASK_1_0);
+ stateUpdater.add(task2);
+
+ stateUpdater.shutdown(Duration.ofMinutes(1));
+
+ verify(changelogReader, times(2)).clear();
+ }
+
+ // An active stateless task that is not in state RESTORING must be rejected by add().
+ @Test
+ public void shouldThrowIfStatelessTaskNotInStateRestoring() {
+ shouldThrowIfTaskNotInStateRestoring(createStatelessTask(TASK_0_0));
+ }
+
+ // An active stateful task that is not in state RESTORING must be rejected by add().
+ @Test
+ public void shouldThrowIfStatefulTaskNotInStateRestoring() {
+ shouldThrowIfTaskNotInStateRestoring(createActiveStatefulTask(TASK_0_0, Collections.singletonList(TOPIC_PARTITION_A_0)));
+ }
+
+ // Helper: stubs the task's state to CREATED and expects add() to throw.
+ private void shouldThrowIfTaskNotInStateRestoring(final StreamTask task) {
+ when(task.state()).thenReturn(State.CREATED);
+ assertThrows(IllegalStateException.class, () -> stateUpdater.add(task));
+ }
+
+ // A single stateless task must appear in the restored-task queue without restoration.
+ @Test
+ public void shouldImmediatelyAddSingleStatelessTaskToRestoredTasks() throws Exception {
+ final StreamTask task1 = createStatelessTaskInStateRestoring(TASK_0_0);
+ shouldImmediatelyAddStatelessTasksToRestoredTasks(task1);
+ }
+
+ // Multiple stateless tasks must all appear in the restored-task queue without restoration.
+ @Test
+ public void shouldImmediatelyAddMultipleStatelessTasksToRestoredTasks() throws Exception {
+ final StreamTask task1 = createStatelessTaskInStateRestoring(TASK_0_0);
+ final StreamTask task2 = createStatelessTaskInStateRestoring(TASK_0_2);
+ final StreamTask task3 = createStatelessTaskInStateRestoring(TASK_1_0);
+ shouldImmediatelyAddStatelessTasksToRestoredTasks(task1, task2, task3);
+ }
+
+ // Helper: stateless tasks have no changelogs to read, so they must show up in the
+ // restored-task queue in state RESTORING and leave the state updater empty.
+ private void shouldImmediatelyAddStatelessTasksToRestoredTasks(final StreamTask... tasks) throws Exception {
+ for (final StreamTask task : tasks) {
+ stateUpdater.add(task);
+ }
+
+ final Set<StreamTask> expectedRestoredTasks = mkSet(tasks);
+ final Set<StreamTask> restoredTasks = new HashSet<>();
+ waitForCondition(
+ () -> {
+ restoredTasks.addAll(stateUpdater.getRestoredActiveTasks(Duration.ofMillis(CALL_TIMEOUT)));
+ return restoredTasks.size() == expectedRestoredTasks.size();
+ },
+ VERIFICATION_TIMEOUT,
+ "Did not get any restored active task within the given timeout!"
+ );
+ assertTrue(restoredTasks.containsAll(expectedRestoredTasks));
+ assertEquals(expectedRestoredTasks.size(), restoredTasks.stream().filter(task -> task.state() == State.RESTORING).count());
+ assertTrue(stateUpdater.getAllTasks().isEmpty());
+ }
+
+ // An active stateful task is restored once all of its changelog partitions are completed:
+ // completeRestoration() must be invoked and the task moved to the restored-task queue.
+ @Test
+ public void shouldRestoreSingleActiveStatefulTask() throws Exception {
+ final StreamTask task =
+ createActiveStatefulTaskInStateRestoring(TASK_0_0, Arrays.asList(TOPIC_PARTITION_A_0, TOPIC_PARTITION_B_0));
+ when(changelogReader.completedChangelogs())
+ .thenReturn(Collections.emptySet())
+ .thenReturn(mkSet(TOPIC_PARTITION_A_0))
+ .thenReturn(mkSet(TOPIC_PARTITION_A_0, TOPIC_PARTITION_B_0));
+ when(changelogReader.allChangelogsCompleted())
+ .thenReturn(false)
+ .thenReturn(false)
+ .thenReturn(true);
+
+ stateUpdater.add(task);
+
+ final Set<StreamTask> expectedRestoredTasks = Collections.singleton(task);
+ final Set<StreamTask> restoredTasks = new HashSet<>();
+ waitForCondition(
+ () -> {
+ restoredTasks.addAll(stateUpdater.getRestoredActiveTasks(Duration.ofMillis(CALL_TIMEOUT)));
+ return restoredTasks.size() == expectedRestoredTasks.size();
+ },
+ VERIFICATION_TIMEOUT,
+ "Did not get any restored active task within the given timeout!"
+ );
+ assertTrue(restoredTasks.containsAll(expectedRestoredTasks));
+ assertEquals(expectedRestoredTasks.size(), restoredTasks.stream().filter(t -> t.state() == State.RESTORING).count());
+ assertTrue(stateUpdater.getAllTasks().isEmpty());
+ verify(changelogReader, atLeast(3)).restore(anyMap());
+ verify(task).completeRestoration(offsetResetter);
+ }
+
+ // Multiple active stateful tasks are restored as their changelogs complete, in any order;
+ // each task's completeRestoration() must be invoked exactly once.
+ @Test
+ public void shouldRestoreMultipleActiveStatefulTasks() throws Exception {
+ final StreamTask task1 = createActiveStatefulTaskInStateRestoring(TASK_0_0, Collections.singletonList(TOPIC_PARTITION_A_0));
+ final StreamTask task2 = createActiveStatefulTaskInStateRestoring(TASK_0_2, Collections.singletonList(TOPIC_PARTITION_B_0));
+ final StreamTask task3 = createActiveStatefulTaskInStateRestoring(TASK_1_0, Collections.singletonList(TOPIC_PARTITION_C_0));
+ when(changelogReader.completedChangelogs())
+ .thenReturn(Collections.emptySet())
+ .thenReturn(mkSet(TOPIC_PARTITION_C_0))
+ .thenReturn(mkSet(TOPIC_PARTITION_C_0, TOPIC_PARTITION_A_0))
+ .thenReturn(mkSet(TOPIC_PARTITION_C_0, TOPIC_PARTITION_A_0, TOPIC_PARTITION_B_0));
+ when(changelogReader.allChangelogsCompleted())
+ .thenReturn(false)
+ .thenReturn(false)
+ .thenReturn(false)
+ .thenReturn(true);
+
+ stateUpdater.add(task1);
+ stateUpdater.add(task2);
+ stateUpdater.add(task3);
+
+ final Set<StreamTask> expectedRestoredTasks = mkSet(task3, task1, task2);
+ final Set<StreamTask> restoredTasks = new HashSet<>();
+ waitForCondition(
+ () -> {
+ restoredTasks.addAll(stateUpdater.getRestoredActiveTasks(Duration.ofMillis(CALL_TIMEOUT)));
+ return restoredTasks.size() == expectedRestoredTasks.size();
+ },
+ VERIFICATION_TIMEOUT,
+ "Did not get any restored active task within the given timeout!"
+ );
+ assertTrue(restoredTasks.containsAll(expectedRestoredTasks));
+ assertEquals(expectedRestoredTasks.size(), restoredTasks.stream().filter(t -> t.state() == State.RESTORING).count());
+ assertTrue(stateUpdater.getAllTasks().isEmpty());
+ verify(changelogReader, atLeast(4)).restore(anyMap());
+ verify(task3).completeRestoration(offsetResetter);
+ verify(task1).completeRestoration(offsetResetter);
+ verify(task2).completeRestoration(offsetResetter);
+ }
+
+ // A StreamsException without a task ID must fail all currently updating tasks at once,
+ // leaving the state updater empty.
+ @Test
+ public void shouldAddFailedTasksToQueueWhenRestoreThrowsStreamsExceptionWithoutTask() throws Exception {
+ final StreamTask task1 = createActiveStatefulTaskInStateRestoring(TASK_0_0, Collections.singletonList(TOPIC_PARTITION_A_0));
+ final StreamTask task2 = createActiveStatefulTaskInStateRestoring(TASK_0_2, Collections.singletonList(TOPIC_PARTITION_B_0));
+ final StreamTask task3 = createActiveStatefulTaskInStateRestoring(TASK_1_0, Collections.singletonList(TOPIC_PARTITION_C_0));
+ final String expectedMessage = "The Streams were crossed!";
+ final StreamsException expectedStreamsException = new StreamsException(expectedMessage);
+ final Map<TaskId, Task> updatingTasks = mkMap(
+ mkEntry(task1.id(), task1),
+ mkEntry(task2.id(), task2),
+ mkEntry(task3.id(), task3)
+ );
+ doNothing().doThrow(expectedStreamsException).doNothing().when(changelogReader).restore(updatingTasks);
+
+ stateUpdater.add(task1);
+ stateUpdater.add(task2);
+ stateUpdater.add(task3);
+
+ final List<ExceptionAndTasks> failedTasks = getFailedTasks(1);
+ assertEquals(1, failedTasks.size());
+ final ExceptionAndTasks actualFailedTasks = failedTasks.get(0);
+ assertEquals(3, actualFailedTasks.tasks.size());
+ assertTrue(actualFailedTasks.tasks.containsAll(Arrays.asList(task1, task2, task3)));
+ assertTrue(actualFailedTasks.exception instanceof StreamsException);
+ final StreamsException actualException = (StreamsException) actualFailedTasks.exception;
+ assertFalse(actualException.taskId().isPresent());
+ assertEquals(expectedMessage, actualException.getMessage());
+ assertTrue(stateUpdater.getAllTasks().isEmpty());
+ }
+
+ // A StreamsException carrying a task ID must fail only that task; the remaining tasks
+ // keep updating. Two successive failures leave only the unaffected task in the updater.
+ @Test
+ public void shouldAddFailedTasksToQueueWhenRestoreThrowsStreamsExceptionWithTask() throws Exception {
+ final StreamTask task1 = createActiveStatefulTaskInStateRestoring(TASK_0_0, Collections.singletonList(TOPIC_PARTITION_A_0));
+ final StreamTask task2 = createActiveStatefulTaskInStateRestoring(TASK_0_2, Collections.singletonList(TOPIC_PARTITION_B_0));
+ final StreamTask task3 = createActiveStatefulTaskInStateRestoring(TASK_1_0, Collections.singletonList(TOPIC_PARTITION_C_0));
+ final String expectedMessage = "The Streams were crossed!";
+ final StreamsException expectedStreamsException1 = new StreamsException(expectedMessage, task1.id());
+ final StreamsException expectedStreamsException2 = new StreamsException(expectedMessage, task3.id());
+ final Map<TaskId, Task> updatingTasksBeforeFirstThrow = mkMap(
+ mkEntry(task1.id(), task1),
+ mkEntry(task2.id(), task2),
+ mkEntry(task3.id(), task3)
+ );
+ final Map<TaskId, Task> updatingTasksBeforeSecondThrow = mkMap(
+ mkEntry(task2.id(), task2),
+ mkEntry(task3.id(), task3)
+ );
+ doNothing()
+ .doThrow(expectedStreamsException1)
+ .when(changelogReader).restore(updatingTasksBeforeFirstThrow);
+ doNothing()
+ .doThrow(expectedStreamsException2)
+ .when(changelogReader).restore(updatingTasksBeforeSecondThrow);
+
+ stateUpdater.add(task1);
+ stateUpdater.add(task2);
+ stateUpdater.add(task3);
+
+ final List<ExceptionAndTasks> failedTasks = getFailedTasks(2);
+ assertEquals(2, failedTasks.size());
+ final ExceptionAndTasks actualFailedTasks1 = failedTasks.get(0);
+ assertEquals(1, actualFailedTasks1.tasks.size());
+ assertTrue(actualFailedTasks1.tasks.contains(task1));
+ assertTrue(actualFailedTasks1.exception instanceof StreamsException);
+ final StreamsException actualException1 = (StreamsException) actualFailedTasks1.exception;
+ assertTrue(actualException1.taskId().isPresent());
+ assertEquals(task1.id(), actualException1.taskId().get());
+ assertEquals(expectedMessage, actualException1.getMessage());
+ final ExceptionAndTasks actualFailedTasks2 = failedTasks.get(1);
+ assertEquals(1, actualFailedTasks2.tasks.size());
+ assertTrue(actualFailedTasks2.tasks.contains(task3));
+ assertTrue(actualFailedTasks2.exception instanceof StreamsException);
+ final StreamsException actualException2 = (StreamsException) actualFailedTasks2.exception;
+ assertTrue(actualException2.taskId().isPresent());
+ assertEquals(task3.id(), actualException2.taskId().get());
+ assertEquals(expectedMessage, actualException2.getMessage());
+ assertEquals(1, stateUpdater.getAllTasks().size());
+ assertTrue(stateUpdater.getAllTasks().contains(task2));
+ }
+
+ // A TaskCorruptedException must fail exactly the corrupted tasks; tasks not named in the
+ // exception stay in the state updater.
+ @Test
+ public void shouldAddFailedTasksToQueueWhenRestoreThrowsTaskCorruptedException() throws Exception {
+ final StreamTask task1 = createActiveStatefulTaskInStateRestoring(TASK_0_0, Collections.singletonList(TOPIC_PARTITION_A_0));
+ final StreamTask task2 = createActiveStatefulTaskInStateRestoring(TASK_0_2, Collections.singletonList(TOPIC_PARTITION_B_0));
+ final StreamTask task3 = createActiveStatefulTaskInStateRestoring(TASK_1_0, Collections.singletonList(TOPIC_PARTITION_C_0));
+ final Set<TaskId> expectedTaskIds = mkSet(task1.id(), task2.id());
+ final TaskCorruptedException taskCorruptedException = new TaskCorruptedException(expectedTaskIds);
+ final Map<TaskId, Task> updatingTasks = mkMap(
+ mkEntry(task1.id(), task1),
+ mkEntry(task2.id(), task2),
+ mkEntry(task3.id(), task3)
+ );
+ doNothing().doThrow(taskCorruptedException).doNothing().when(changelogReader).restore(updatingTasks);
+
+ stateUpdater.add(task1);
+ stateUpdater.add(task2);
+ stateUpdater.add(task3);
+
+ final List<ExceptionAndTasks> failedTasks = getFailedTasks(1);
+ assertEquals(1, failedTasks.size());
+ final List<Task> expectedFailedTasks = Arrays.asList(task1, task2);
+ final ExceptionAndTasks actualFailedTasks = failedTasks.get(0);
+ assertEquals(2, actualFailedTasks.tasks.size());
+ assertTrue(actualFailedTasks.tasks.containsAll(expectedFailedTasks));
+ assertTrue(actualFailedTasks.exception instanceof TaskCorruptedException);
+ final TaskCorruptedException actualException = (TaskCorruptedException) actualFailedTasks.exception;
+ final Set<TaskId> corruptedTasks = actualException.corruptedTasks();
+ assertTrue(corruptedTasks.containsAll(expectedFailedTasks.stream().map(Task::id).collect(Collectors.toList())));
+ assertEquals(1, stateUpdater.getAllTasks().size());
+ assertTrue(stateUpdater.getAllTasks().contains(task3));
+ }
+
+ @Test
+ public void shouldAddFailedTasksToQueueWhenUncaughtExceptionIsThrown() throws Exception {
+ // An unexpected (non-Streams) exception thrown from restore() should fail ALL
+ // currently updating tasks as one group and leave the state updater empty.
+ final StreamTask task1 = createActiveStatefulTaskInStateRestoring(TASK_0_0, Collections.singletonList(TOPIC_PARTITION_A_0));
+ final StreamTask task2 = createActiveStatefulTaskInStateRestoring(TASK_0_2, Collections.singletonList(TOPIC_PARTITION_B_0));
+ final IllegalStateException illegalStateException = new IllegalStateException("Nobody expects the Spanish inquisition!");
+ final Map<TaskId, Task> updatingTasks = mkMap(
+ mkEntry(task1.id(), task1),
+ mkEntry(task2.id(), task2)
+ );
+ // Only throw once both tasks are in the updating set, so both end up in the group.
+ doThrow(illegalStateException).when(changelogReader).restore(updatingTasks);
+
+ stateUpdater.add(task1);
+ stateUpdater.add(task2);
+
+ final List<ExceptionAndTasks> failedTasks = getFailedTasks(1);
+ final List<Task> expectedFailedTasks = Arrays.asList(task1, task2);
+ final ExceptionAndTasks actualFailedTasks = failedTasks.get(0);
+ assertEquals(2, actualFailedTasks.tasks.size());
+ assertTrue(actualFailedTasks.tasks.containsAll(expectedFailedTasks));
+ assertTrue(actualFailedTasks.exception instanceof IllegalStateException);
+ final IllegalStateException actualException = (IllegalStateException) actualFailedTasks.exception;
+ // assertEquals takes (expected, actual); keep the stubbed exception's message
+ // as the expected value so a failure prints a correct diagnostic.
+ assertEquals(illegalStateException.getMessage(), actualException.getMessage());
+ assertTrue(stateUpdater.getAllTasks().isEmpty());
+ }
+
+ // Drains the state updater's failed-task queue until at least {@code expectedCount}
+ // ExceptionAndTasks entries have accumulated, retrying via waitForCondition up to
+ // VERIFICATION_TIMEOUT (which throws on timeout). Returns everything drained so far.
+ private List<ExceptionAndTasks> getFailedTasks(final int expectedCount) throws Exception {
+ final List<ExceptionAndTasks> failedTasks = new ArrayList<>();
+ waitForCondition(
+ () -> {
+ // Accumulate across polls: each call to getFailedTasksAndExceptions()
+ // removes entries from the updater, so they must be collected here.
+ failedTasks.addAll(stateUpdater.getFailedTasksAndExceptions());
+ return failedTasks.size() >= expectedCount;
+ },
+ VERIFICATION_TIMEOUT,
+ "Did not get enough failed tasks within the given timeout!"
+ );
+
+ return failedTasks;
+ }
+
+ // Creates a mocked active stateful StreamTask whose state() reports RESTORING.
+ private StreamTask createActiveStatefulTaskInStateRestoring(final TaskId taskId,
+ final Collection<TopicPartition> changelogPartitions) {
+ final StreamTask task = createActiveStatefulTask(taskId, changelogPartitions);
+ when(task.state()).thenReturn(State.RESTORING);
+ return task;
+ }
+
+ // Creates a mocked active StreamTask with the given id and changelog partitions;
+ // no state() stubbing is done here.
+ private StreamTask createActiveStatefulTask(final TaskId taskId,
+ final Collection<TopicPartition> changelogPartitions) {
+ final StreamTask task = mock(StreamTask.class);
+ setupStatefulTask(task, taskId, changelogPartitions);
+ when(task.isActive()).thenReturn(true);
+ return task;
+ }
+
+ // Creates a mocked stateless (no changelog partitions) StreamTask in RESTORING state.
+ private StreamTask createStatelessTaskInStateRestoring(final TaskId taskId) {
+ final StreamTask task = createStatelessTask(taskId);
+ when(task.state()).thenReturn(State.RESTORING);
+ return task;
+ }
+
+ // Creates a mocked active StreamTask with an empty changelog-partition set,
+ // i.e. a task that has nothing to restore.
+ private StreamTask createStatelessTask(final TaskId taskId) {
+ final StreamTask task = mock(StreamTask.class);
+ when(task.changelogPartitions()).thenReturn(Collections.emptySet());
+ when(task.isActive()).thenReturn(true);
+ when(task.id()).thenReturn(taskId);
+ return task;
+ }
+
+ // Stubs the id and changelog partitions shared by all stateful task mocks.
+ private void setupStatefulTask(final Task task,
+ final TaskId taskId,
+ final Collection<TopicPartition> changelogPartitions) {
+ when(task.changelogPartitions()).thenReturn(changelogPartitions);
+ when(task.id()).thenReturn(taskId);
+ }
+}
\ No newline at end of file
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/MockChangelogReader.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/MockChangelogReader.java
index 6ea7fc3101..d86728891c 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/MockChangelogReader.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/MockChangelogReader.java
@@ -59,6 +59,11 @@ public class MockChangelogReader implements ChangelogReader {
return restoringPartitions;
}
+ @Override
+ public boolean allChangelogsCompleted() {
+ // This mock always reports restoration as not yet complete.
+ return false;
+ }
+
@Override
public void clear() {
restoringPartitions.clear();