You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@kafka.apache.org by mj...@apache.org on 2020/08/05 21:14:53 UTC

[kafka] branch trunk updated: KAFKA-9274: Remove `retries` for global task (#9047)

This is an automated email from the ASF dual-hosted git repository.

mjsax pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new b351493  KAFKA-9274: Remove `retries` for global task (#9047)
b351493 is described below

commit b351493543b7e26aa345df3b568d0dc08a8c8d91
Author: Matthias J. Sax <ma...@confluent.io>
AuthorDate: Wed Aug 5 14:14:18 2020 -0700

    KAFKA-9274: Remove `retries` for global task (#9047)
    
     - part of KIP-572
     - removed the usage of `retries` in `GlobalStateManger`
     - instead of retries the new `task.timeout.ms` config is used
    
    Reviewers: John Roesler <jo...@confluent.io>, Boyang Chen <bo...@confluent.io>, Guozhang Wang <gu...@confluent.io>
---
 checkstyle/suppressions.xml                        |   2 +-
 docs/streams/developer-guide/config-streams.html   |  28 +-
 docs/streams/upgrade-guide.html                    |  11 +-
 docs/upgrade.html                                  |   4 +-
 .../org/apache/kafka/streams/StreamsConfig.java    |  10 +-
 .../streams/internals/QuietStreamsConfig.java      |  33 --
 .../streams/processor/internals/ClientUtils.java   |  17 +-
 .../internals/GlobalStateManagerImpl.java          | 204 ++++---
 .../processor/internals/GlobalStreamThread.java    |   7 +-
 .../processor/internals/InternalTopicManager.java  |   3 +-
 .../assignment/AssignorConfiguration.java          |   4 +-
 .../internals/GlobalStateManagerImplTest.java      | 660 +++++++++++++++++++--
 .../kafka/test/InternalMockProcessorContext.java   |   8 +-
 .../org/apache/kafka/test/NoOpReadOnlyStore.java   |   1 +
 .../apache/kafka/streams/TopologyTestDriver.java   |  10 +-
 .../streams/processor/MockProcessorContext.java    |   4 +-
 16 files changed, 802 insertions(+), 204 deletions(-)

diff --git a/checkstyle/suppressions.xml b/checkstyle/suppressions.xml
index 57cf079..685d10f 100644
--- a/checkstyle/suppressions.xml
+++ b/checkstyle/suppressions.xml
@@ -162,7 +162,7 @@
               files="StreamsMetricsImpl.java"/>
 
     <suppress checks="NPathComplexity"
-              files="(KafkaStreams|StreamsPartitionAssignor|StreamThread|TaskManager).java"/>
+              files="(KafkaStreams|StreamsPartitionAssignor|StreamThread|TaskManager|GlobalStateManagerImpl).java"/>
 
     <suppress checks="(FinalLocalVariable|UnnecessaryParentheses|BooleanExpressionComplexity|CyclomaticComplexity|WhitespaceAfter|LocalVariableName)"
               files="Murmur3.java"/>
diff --git a/docs/streams/developer-guide/config-streams.html b/docs/streams/developer-guide/config-streams.html
index 8388476..46c46d1 100644
--- a/docs/streams/developer-guide/config-streams.html
+++ b/docs/streams/developer-guide/config-streams.html
@@ -203,7 +203,7 @@
           </tr>
           <tr class="row-even"><td>commit.interval.ms</td>
             <td>Low</td>
-            <td colspan="2">The frequency with which to save the position (offsets in source topics) of tasks.</td>
+            <td colspan="2">The frequency in milliseconds with which to save the position (offsets in source topics) of tasks.</td>
             <td>30000 milliseconds</td>
           </tr>
           <tr class="row-odd"><td>default.deserialization.exception.handler</td>
@@ -243,8 +243,8 @@
           </tr>
           <tr class="row-even"><td>max.task.idle.ms</td>
             <td>Medium</td>
-            <td colspan="2">Maximum amount of time a stream task will stay idle while waiting for all partitions to contain data and avoid potential out-of-order record
-              processing across multiple input streams.</td>
+            <td colspan="2">Maximum amount of time in milliseconds a stream task will stay idle while waiting for all partitions to contain data
+              and avoid potential out-of-order record processing across multiple input streams.</td>
             <td>0 milliseconds</td>
           </tr>
           <tr class="row-odd"><td>max.warmup.replicas</td>
@@ -269,8 +269,8 @@
           </tr>
           <tr class="row-odd"><td>metrics.sample.window.ms</td>
             <td>Low</td>
-            <td colspan="2">The window of time a metrics sample is computed over.</td>
-            <td>30000 milliseconds</td>
+            <td colspan="2">The window of time in milliseconds a metrics sample is computed over.</td>
+            <td>30000 milliseconds (30 seconds)</td>
           </tr>
           <tr class="row-even"><td>num.standby.replicas</td>
             <td>Medium</td>
@@ -289,7 +289,7 @@
           </tr>
           <tr class="row-odd"><td>probing.rebalance.interval.ms</td>
             <td>Low</td>
-            <td colspan="2">The maximum time to wait before triggering a rebalance to probe for warmup replicas that have sufficiently caught up.</td>
+            <td colspan="2">The maximum time in milliseconds to wait before triggering a rebalance to probe for warmup replicas that have sufficiently caught up.</td>
             <td>600000 milliseconds (10 minutes)</td>
           </tr>
           <tr class="row-even"><td>processing.guarantee</td>
@@ -308,15 +308,10 @@
             <td colspan="2">The replication factor for changelog topics and repartition topics created by the application.</td>
             <td>1</td>
           </tr>
-          <tr class="row-odd"><td>retries</td>
-              <td>Medium</td>
-              <td colspan="2">The number of retries for broker requests that return a retryable error. </td>
-              <td>0</td>
-          </tr>
           <tr class="row-even"><td>retry.backoff.ms</td>
               <td>Medium</td>
               <td colspan="2">The amount of time in milliseconds, before a request is retried. This applies if the <code class="docutils literal"><span class="pre">retries</span></code> parameter is configured to be greater than 0. </td>
-              <td>100</td>
+              <td>100 milliseconds</td>
           </tr>
           <tr class="row-odd"><td>rocksdb.config.setter</td>
             <td>Medium</td>
@@ -326,13 +321,18 @@
           <tr class="row-even"><td>state.cleanup.delay.ms</td>
             <td>Low</td>
             <td colspan="2">The amount of time in milliseconds to wait before deleting state when a partition has migrated.</td>
-            <td>600000 milliseconds</td>
+            <td>600000 milliseconds (10 minutes)</td>
           </tr>
           <tr class="row-odd"><td>state.dir</td>
             <td>High</td>
             <td colspan="2">Directory location for state stores.</td>
             <td><code class="docutils literal"><span class="pre">/tmp/kafka-streams</span></code></td>
           </tr>
+          <tr class="row-odd"><td>task.timeout.ms</td>
+            <td>Medium</td>
+            <td colspan="2">The maximum amount of time in milliseconds a task might stall due to internal errors and retries until an error is raised. For a timeout of <code>0 ms</code>, a task would raise an error for the first internal error. For any timeout larger than <code>0 ms</code>, a task will retry at least once before an error is raised.</td>
+            <td>300000 milliseconds (5 minutes)</td>
+          </tr>
           <tr class="row-even"><td>topology.optimization</td>
             <td>Medium</td>
             <td colspan="2">A configuration telling Kafka Streams if it should optimize the topology</td>
@@ -346,7 +346,7 @@
           <tr class="row-even"><td>windowstore.changelog.additional.retention.ms</td>
             <td>Low</td>
             <td colspan="2">Added to a windows maintainMs to ensure data is not deleted from the log prematurely. Allows for clock drift.</td>
-            <td>86400000 milliseconds = 1 day</td>
+            <td>86400000 milliseconds (1 day)</td>
           </tr>
           </tbody>
         </table>
diff --git a/docs/streams/upgrade-guide.html b/docs/streams/upgrade-guide.html
index db91d16..8480e5e 100644
--- a/docs/streams/upgrade-guide.html
+++ b/docs/streams/upgrade-guide.html
@@ -95,11 +95,12 @@
     </p>
 
     <p>
-        The configuration parameter <code>retries</code> is deprecated in favor of a the new parameter <code>task.timeout.ms</code>.
-        Kafka Streams runtime ignores <code>retries</code> if set, however, if would still forward the parameter
-        to it's internal clients. Note though, that <code>retries</code> is deprecated for the producer and admin client, too.
-        Thus, instead of setting <code>retries</code>, you should configure the corresponding client timeouts.
-
+        The configuration parameter <code>retries</code> is deprecated in favor of the new parameter <code>task.timeout.ms</code>.
+        Kafka Streams' runtime ignores <code>retries</code> if set, however, it would still forward the parameter
+        to its internal clients. Note though, that <code>retries</code> is deprecated for the producer and admin client, too.
+        Thus, instead of setting <code>retries</code>, you should configure the corresponding client timeouts, namely
+        <code>delivery.timeout.ms</code> and <code>max.block.ms</code> for the producer and
+        <code>default.api.timeout.ms</code> for the admin client.
     </p>
 
     <h3><a id="streams_api_changes_260" href="#streams_api_changes_260">Streams API changes in 2.6.0</a></h3>
diff --git a/docs/upgrade.html b/docs/upgrade.html
index 863705d..edb6b19 100644
--- a/docs/upgrade.html
+++ b/docs/upgrade.html
@@ -23,8 +23,8 @@
 <ul>
     <li>The configuration parameter <code>retries</code> is deprecated for the producer, admin, and Kafka Streams clients
         via <a href="https://cwiki.apache.org/confluence/display/KAFKA/KIP-572%3A+Improve+timeouts+and+retries+in+Kafka+Streams">KIP-572</a>.
-        You should use the producer's <code>delivery.timeout.ms</code>, admin's <code>default.api.timeout.ms</code>, and
-        Kafka Streams' new <code>task.timeout.ms</code> parameters instead.
+        You should use the producer's <code>delivery.timeout.ms</code> and <code>max.block.ms</code>, admin's
+        <code>default.api.timeout.ms</code>, and Kafka Streams' new <code>task.timeout.ms</code> parameters instead.
         Note that parameter <code>retry.backoff.ms</code> is not impacted by this change.
     </li>
 </ul>
diff --git a/streams/src/main/java/org/apache/kafka/streams/StreamsConfig.java b/streams/src/main/java/org/apache/kafka/streams/StreamsConfig.java
index 15d0a15..4ad6c82 100644
--- a/streams/src/main/java/org/apache/kafka/streams/StreamsConfig.java
+++ b/streams/src/main/java/org/apache/kafka/streams/StreamsConfig.java
@@ -357,7 +357,7 @@ public class StreamsConfig extends AbstractConfig {
     /** {@code commit.interval.ms} */
     @SuppressWarnings("WeakerAccess")
     public static final String COMMIT_INTERVAL_MS_CONFIG = "commit.interval.ms";
-    private static final String COMMIT_INTERVAL_MS_DOC = "The frequency with which to save the position of the processor." +
+    private static final String COMMIT_INTERVAL_MS_DOC = "The frequency in milliseconds with which to save the position of the processor." +
         " (Note, if <code>processing.guarantee</code> is set to <code>" + EXACTLY_ONCE + "</code>, the default value is <code>" + EOS_DEFAULT_COMMIT_INTERVAL_MS + "</code>," +
         " otherwise the default value is <code>" + DEFAULT_COMMIT_INTERVAL_MS + "</code>.";
 
@@ -408,7 +408,7 @@ public class StreamsConfig extends AbstractConfig {
 
     /** {@code max.task.idle.ms} */
     public static final String MAX_TASK_IDLE_MS_CONFIG = "max.task.idle.ms";
-    private static final String MAX_TASK_IDLE_MS_DOC = "Maximum amount of time a stream task will stay idle when not all of its partition buffers contain records," +
+    private static final String MAX_TASK_IDLE_MS_DOC = "Maximum amount of time in milliseconds a stream task will stay idle when not all of its partition buffers contain records," +
         " to avoid potential out-of-order record processing across multiple input streams.";
 
     /** {@code max.warmup.replicas} */
@@ -454,8 +454,8 @@ public class StreamsConfig extends AbstractConfig {
 
     /** {@code probing.rebalance.interval.ms} */
     public static final String PROBING_REBALANCE_INTERVAL_MS_CONFIG = "probing.rebalance.interval.ms";
-    private static final String PROBING_REBALANCE_INTERVAL_MS_DOC = "The maximum time to wait before triggering a rebalance to probe for warmup replicas that have finished warming up and are ready to become active. Probing rebalances " +
-                                                                        "will continue to be triggered until the assignment is balanced. Must be at least 1 minute.";
+    private static final String PROBING_REBALANCE_INTERVAL_MS_DOC = "The maximum time in milliseconds to wait before triggering a rebalance to probe for warmup replicas that have finished warming up and are ready to become active." +
+        " Probing rebalances will continue to be triggered until the assignment is balanced. Must be at least 1 minute.";
 
     /** {@code processing.guarantee} */
     @SuppressWarnings("WeakerAccess")
@@ -529,7 +529,7 @@ public class StreamsConfig extends AbstractConfig {
 
     /** {@code task.timeout.ms} */
     public static final String TASK_TIMEOUT_MS_CONFIG = "task.timeout.ms";
-    public static final String TASK_TIMEOUT_MS_DOC = "Max amount of time a task might stall due to internal errors and retries until an error is raised. " +
+    public static final String TASK_TIMEOUT_MS_DOC = "The maximum amount of time in milliseconds a task might stall due to internal errors and retries until an error is raised. " +
         "For a timeout of 0ms, a task would raise an error for the first internal error. " +
         "For any timeout larger than 0ms, a task will retry at least once before an error is raised.";
 
diff --git a/streams/src/main/java/org/apache/kafka/streams/internals/QuietStreamsConfig.java b/streams/src/main/java/org/apache/kafka/streams/internals/QuietStreamsConfig.java
deleted file mode 100644
index 6132668..0000000
--- a/streams/src/main/java/org/apache/kafka/streams/internals/QuietStreamsConfig.java
+++ /dev/null
@@ -1,33 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.kafka.streams.internals;
-
-import org.apache.kafka.streams.StreamsConfig;
-
-import java.util.Map;
-
-/**
- * A {@link StreamsConfig} that does not log its configuration on construction.
- *
- * This producer cleaner output for unit tests using the {@code test-utils},
- * since logging the config is not really valuable in this context.
- */
-public class QuietStreamsConfig extends StreamsConfig {
-    public QuietStreamsConfig(final Map<?, ?> props) {
-        super(props, false);
-    }
-}
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ClientUtils.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ClientUtils.java
index 44d3484..58cce5c 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ClientUtils.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ClientUtils.java
@@ -21,6 +21,7 @@ import org.apache.kafka.clients.admin.AdminClientConfig;
 import org.apache.kafka.clients.admin.ListOffsetsResult.ListOffsetsResultInfo;
 import org.apache.kafka.clients.admin.OffsetSpec;
 import org.apache.kafka.clients.consumer.Consumer;
+import org.apache.kafka.clients.consumer.ConsumerConfig;
 import org.apache.kafka.common.KafkaException;
 import org.apache.kafka.common.KafkaFuture;
 import org.apache.kafka.common.Metric;
@@ -30,6 +31,8 @@ import org.apache.kafka.common.errors.TimeoutException;
 import org.apache.kafka.streams.StreamsConfig;
 import org.apache.kafka.streams.errors.StreamsException;
 import org.apache.kafka.streams.processor.TaskId;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 import java.util.Collection;
 import java.util.Collections;
@@ -39,12 +42,22 @@ import java.util.Set;
 import java.util.concurrent.ExecutionException;
 import java.util.function.Function;
 import java.util.stream.Collectors;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
 
 public class ClientUtils {
     private static final Logger LOG = LoggerFactory.getLogger(ClientUtils.class);
 
+    public static class QuietStreamsConfig extends StreamsConfig {
+        public QuietStreamsConfig(final Map<?, ?> props) {
+            super(props, false);
+        }
+    }
+
+    public static class QuietConsumerConfig extends ConsumerConfig {
+        public QuietConsumerConfig(final Map<String, Object> props) {
+            super(props, false);
+        }
+    }
+
     public static final class QuietAdminClientConfig extends AdminClientConfig {
         QuietAdminClientConfig(final StreamsConfig streamsConfig) {
             // If you just want to look up admin configs, you don't care about the clientId
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImpl.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImpl.java
index 942308b..0faf6f0 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImpl.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImpl.java
@@ -17,14 +17,16 @@
 package org.apache.kafka.streams.processor.internals;
 
 import org.apache.kafka.clients.consumer.Consumer;
+import org.apache.kafka.clients.consumer.ConsumerConfig;
 import org.apache.kafka.clients.consumer.ConsumerRecord;
 import org.apache.kafka.clients.consumer.ConsumerRecords;
 import org.apache.kafka.common.PartitionInfo;
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.errors.TimeoutException;
+import org.apache.kafka.common.serialization.ByteArrayDeserializer;
 import org.apache.kafka.common.utils.FixedOrderMap;
 import org.apache.kafka.common.utils.LogContext;
-import org.apache.kafka.common.utils.Utils;
+import org.apache.kafka.common.utils.Time;
 import org.apache.kafka.streams.StreamsConfig;
 import org.apache.kafka.streams.errors.LockException;
 import org.apache.kafka.streams.errors.ProcessorStateException;
@@ -48,6 +50,7 @@ import java.util.List;
 import java.util.Map;
 import java.util.Optional;
 import java.util.Set;
+import java.util.function.Supplier;
 
 import static org.apache.kafka.streams.processor.internals.StateManagerUtil.CHECKPOINT_FILE_NAME;
 import static org.apache.kafka.streams.processor.internals.StateManagerUtil.converterForStore;
@@ -57,7 +60,10 @@ import static org.apache.kafka.streams.processor.internals.StateManagerUtil.conv
  * of Global State Stores. There is only ever 1 instance of this class per Application Instance.
  */
 public class GlobalStateManagerImpl implements GlobalStateManager {
+    private final static long NO_DEADLINE = -1L;
+
     private final Logger log;
+    private final Time time;
     private final Consumer<byte[], byte[]> globalConsumer;
     private final File baseDir;
     private final StateDirectory stateDirectory;
@@ -65,22 +71,22 @@ public class GlobalStateManagerImpl implements GlobalStateManager {
     private final FixedOrderMap<String, Optional<StateStore>> globalStores = new FixedOrderMap<>();
     private final StateRestoreListener stateRestoreListener;
     private InternalProcessorContext globalProcessorContext;
-    private final int retries;
-    private final long retryBackoffMs;
-    private final Duration pollTime;
+    private final Duration requestTimeoutPlusTaskTimeout;
+    private final long taskTimeoutMs;
     private final Set<String> globalNonPersistentStoresTopics = new HashSet<>();
     private final OffsetCheckpoint checkpointFile;
     private final Map<TopicPartition, Long> checkpointFileCache;
     private final Map<String, String> storeToChangelogTopic;
     private final List<StateStore> globalStateStores;
 
-    @SuppressWarnings("deprecation") // TODO: remove in follow up PR when `RETRIES` is removed
     public GlobalStateManagerImpl(final LogContext logContext,
+                                  final Time time,
                                   final ProcessorTopology topology,
                                   final Consumer<byte[], byte[]> globalConsumer,
                                   final StateDirectory stateDirectory,
                                   final StateRestoreListener stateRestoreListener,
                                   final StreamsConfig config) {
+        this.time = time;
         storeToChangelogTopic = topology.storeToChangelogTopic();
         globalStateStores = topology.globalStateStores();
         baseDir = stateDirectory.globalStateDir();
@@ -98,9 +104,16 @@ public class GlobalStateManagerImpl implements GlobalStateManager {
         this.globalConsumer = globalConsumer;
         this.stateDirectory = stateDirectory;
         this.stateRestoreListener = stateRestoreListener;
-        retries = config.getInt(StreamsConfig.RETRIES_CONFIG);
-        retryBackoffMs = config.getLong(StreamsConfig.RETRY_BACKOFF_MS_CONFIG);
-        pollTime = Duration.ofMillis(config.getLong(StreamsConfig.POLL_MS_CONFIG));
+
+        final Map<String, Object> consumerProps = config.getGlobalConsumerConfigs("dummy");
+        // need to add mandatory configs; otherwise `QuietConsumerConfig` throws
+        consumerProps.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, ByteArrayDeserializer.class);
+        consumerProps.put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, ByteArrayDeserializer.class);
+        final int requestTimeoutMs = new ClientUtils.QuietConsumerConfig(consumerProps)
+            .getInt(ConsumerConfig.REQUEST_TIMEOUT_MS_CONFIG);
+        taskTimeoutMs = config.getLong(StreamsConfig.TASK_TIMEOUT_MS_CONFIG);
+        requestTimeoutPlusTaskTimeout =
+            Duration.ofMillis(requestTimeoutMs + taskTimeoutMs);
     }
 
     @Override
@@ -140,10 +153,13 @@ public class GlobalStateManagerImpl implements GlobalStateManager {
         // make sure each topic-partition from checkpointFileCache is associated with a global state store
         checkpointFileCache.keySet().forEach(tp -> {
             if (!changelogTopics.contains(tp.topic())) {
-                log.error("Encountered a topic-partition in the global checkpoint file not associated with any global" +
-                    " state store, topic-partition: {}, checkpoint file: {}. If this topic-partition is no longer valid," +
-                    " an application reset and state store directory cleanup will be required.",
-                    tp.topic(), checkpointFile.toString());
+                log.error(
+                    "Encountered a topic-partition in the global checkpoint file not associated with any global" +
+                        " state store, topic-partition: {}, checkpoint file: {}. If this topic-partition is no longer valid," +
+                        " an application reset and state store directory cleanup will be required.",
+                    tp.topic(),
+                    checkpointFile.toString()
+                );
                 try {
                     stateDirectory.unlockGlobalState();
                 } catch (final IOException e) {
@@ -184,32 +200,15 @@ public class GlobalStateManagerImpl implements GlobalStateManager {
 
         log.info("Restoring state for global store {}", store.name());
         final List<TopicPartition> topicPartitions = topicPartitionsForStore(store);
-        Map<TopicPartition, Long> highWatermarks = null;
 
-        int attempts = 0;
-        while (highWatermarks == null) {
-            try {
-                highWatermarks = globalConsumer.endOffsets(topicPartitions);
-            } catch (final TimeoutException retryableException) {
-                if (++attempts > retries) {
-                    log.error("Failed to get end offsets for topic partitions of global store {} after {} retry attempts. " +
-                        "You can increase the number of retries via configuration parameter `retries`.",
-                        store.name(),
-                        retries,
-                        retryableException);
-                    throw new StreamsException(String.format("Failed to get end offsets for topic partitions of global store %s after %d retry attempts. " +
-                            "You can increase the number of retries via configuration parameter `retries`.", store.name(), retries),
-                        retryableException);
-                }
-                log.debug("Failed to get end offsets for partitions {}, backing off for {} ms to retry (attempt {} of {})",
-                    topicPartitions,
-                    retryBackoffMs,
-                    attempts,
-                    retries,
-                    retryableException);
-                Utils.sleep(retryBackoffMs);
-            }
-        }
+        final Map<TopicPartition, Long> highWatermarks = retryUntilSuccessOrThrowOnTaskTimeout(
+            () -> globalConsumer.endOffsets(topicPartitions),
+            String.format(
+                "Failed to get offsets for partitions %s. The broker may be transiently unavailable at the moment.",
+                topicPartitions
+            )
+        );
+
         try {
             restoreState(
                 stateRestoreCallback,
@@ -226,35 +225,14 @@ public class GlobalStateManagerImpl implements GlobalStateManager {
 
     private List<TopicPartition> topicPartitionsForStore(final StateStore store) {
         final String sourceTopic = storeToChangelogTopic.get(store.name());
-        List<PartitionInfo> partitionInfos;
-        int attempts = 0;
-        while (true) {
-            try {
-                partitionInfos = globalConsumer.partitionsFor(sourceTopic);
-                break;
-            } catch (final TimeoutException retryableException) {
-                if (++attempts > retries) {
-                    log.error("Failed to get partitions for topic {} after {} retry attempts due to timeout. " +
-                            "The broker may be transiently unavailable at the moment. " +
-                            "You can increase the number of retries via configuration parameter `retries`.",
-                        sourceTopic,
-                        retries,
-                        retryableException);
-                    throw new StreamsException(String.format("Failed to get partitions for topic %s after %d retry attempts due to timeout. " +
-                        "The broker may be transiently unavailable at the moment. " +
-                        "You can increase the number of retries via configuration parameter `retries`.", sourceTopic, retries),
-                        retryableException);
-                }
-                log.debug("Failed to get partitions for topic {} due to timeout. The broker may be transiently unavailable at the moment. " +
-                        "Backing off for {} ms to retry (attempt {} of {})",
-                    sourceTopic,
-                    retryBackoffMs,
-                    attempts,
-                    retries,
-                    retryableException);
-                Utils.sleep(retryBackoffMs);
-            }
-        }
+
+        final List<PartitionInfo> partitionInfos = retryUntilSuccessOrThrowOnTaskTimeout(
+            () -> globalConsumer.partitionsFor(sourceTopic),
+            String.format(
+                "Failed to get partitions for topic %s. The broker may be transiently unavailable at the moment.",
+                sourceTopic
+            )
+        );
 
         if (partitionInfos == null || partitionInfos.isEmpty()) {
             throw new StreamsException(String.format("There are no partitions available for topic %s when initializing global store %s", sourceTopic, store.name()));
@@ -274,14 +252,22 @@ public class GlobalStateManagerImpl implements GlobalStateManager {
                               final RecordConverter recordConverter) {
         for (final TopicPartition topicPartition : topicPartitions) {
             globalConsumer.assign(Collections.singletonList(topicPartition));
+            long offset;
             final Long checkpoint = checkpointFileCache.get(topicPartition);
             if (checkpoint != null) {
                 globalConsumer.seek(topicPartition, checkpoint);
+                offset = checkpoint;
             } else {
                 globalConsumer.seekToBeginning(Collections.singletonList(topicPartition));
+                offset = retryUntilSuccessOrThrowOnTaskTimeout(
+                    () -> globalConsumer.position(topicPartition),
+                    String.format(
+                        "Failed to get position for partition %s. The broker may be transiently unavailable at the moment.",
+                        topicPartition
+                    )
+                );
             }
 
-            long offset = globalConsumer.position(topicPartition);
             final Long highWatermark = highWatermarks.get(topicPartition);
             final RecordBatchingStateRestoreCallback stateRestoreAdapter =
                 StateRestoreCallbackAdapter.adapt(stateRestoreCallback);
@@ -289,15 +275,51 @@ public class GlobalStateManagerImpl implements GlobalStateManager {
             stateRestoreListener.onRestoreStart(topicPartition, storeName, offset, highWatermark);
             long restoreCount = 0L;
 
-            while (offset < highWatermark) {
-                final ConsumerRecords<byte[], byte[]> records = globalConsumer.poll(pollTime);
+            while (offset < highWatermark) { // when we "fix" this loop (KAFKA-7380 / KAFKA-10317)
+                                             // we should update the `poll()` timeout below
+
+                // we ignore `poll.ms` config during bootstrapping phase and
+                // apply `request.timeout.ms` plus `task.timeout.ms` instead
+                //
+                // the reason is, that `poll.ms` might be too short to give a fetch request a fair chance
+                // to actually complete and we don't want to start `task.timeout.ms` too early
+                //
+                // we also pass `task.timeout.ms` into `poll()` directly right now as it simplifies our own code:
+                // if we don't pass it in, we would just track the timeout ourselves and call `poll()` again
+                // in our own retry loop; by passing the timeout we can reuse the consumer's internal retry loop instead
+                //
+                // note that using `request.timeout.ms` provides a conservative upper bound for the timeout;
+                // this implies that we might start `task.timeout.ms` "delayed" -- however, starting the timeout
+                // delayed is preferable (as it's more robust) than starting it too early
+                //
+                // TODO https://issues.apache.org/jira/browse/KAFKA-10315
+                //   -> do a more precise timeout handling if `poll` would throw an exception if a fetch request fails
+                //      (instead of letting the consumer retry fetch requests silently)
+                //
+                // TODO https://issues.apache.org/jira/browse/KAFKA-10317 and
+                //      https://issues.apache.org/jira/browse/KAFKA-7380
+                //  -> don't pass in `task.timeout.ms` to stay responsive if `KafkaStreams#close` gets called
+                final ConsumerRecords<byte[], byte[]> records = globalConsumer.poll(requestTimeoutPlusTaskTimeout);
+                if (records.isEmpty()) {
+                    // this will always throw
+                    maybeUpdateDeadlineOrThrow(time.milliseconds());
+                }
+
                 final List<ConsumerRecord<byte[], byte[]>> restoreRecords = new ArrayList<>();
                 for (final ConsumerRecord<byte[], byte[]> record : records.records(topicPartition)) {
                     if (record.key() != null) {
                         restoreRecords.add(recordConverter.convert(record));
                     }
                 }
-                offset = globalConsumer.position(topicPartition);
+
+                offset = retryUntilSuccessOrThrowOnTaskTimeout(
+                    () -> globalConsumer.position(topicPartition),
+                    String.format(
+                        "Failed to get position for partition %s. The broker may be transiently unavailable at the moment.",
+                        topicPartition
+                    )
+                );
+
                 stateRestoreAdapter.restoreBatch(restoreRecords);
                 stateRestoreListener.onBatchRestored(topicPartition, storeName, offset, restoreRecords.size());
                 restoreCount += restoreRecords.size();
@@ -307,6 +329,48 @@ public class GlobalStateManagerImpl implements GlobalStateManager {
         }
     }
 
+    private <R> R retryUntilSuccessOrThrowOnTaskTimeout(final Supplier<R> supplier,
+                                                        final String errorMessage) {
+        long deadlineMs = NO_DEADLINE;
+
+        do {
+            try {
+                return supplier.get();
+            } catch (final TimeoutException retriableException) {
+                if (taskTimeoutMs == 0L) {
+                    throw new StreamsException(
+                        String.format(
+                            "Retrying is disabled. You can enable it by setting `%s` to a value larger than zero.",
+                            StreamsConfig.TASK_TIMEOUT_MS_CONFIG
+                        ),
+                        retriableException
+                    );
+                }
+
+                deadlineMs = maybeUpdateDeadlineOrThrow(deadlineMs);
+
+                log.warn(errorMessage, retriableException);
+            }
+        } while (true);
+    }
+
+    private long maybeUpdateDeadlineOrThrow(final long currentDeadlineMs) {
+        final long currentWallClockMs = time.milliseconds();
+
+        if (currentDeadlineMs == NO_DEADLINE) {
+            final long newDeadlineMs = currentWallClockMs + taskTimeoutMs;
+            return newDeadlineMs < 0L ? Long.MAX_VALUE : newDeadlineMs;
+        } else if (currentWallClockMs >= currentDeadlineMs) {
+            throw new TimeoutException(String.format(
+                "Global task did not make progress to restore state within %d ms. Adjust `%s` if needed.",
+                currentWallClockMs - currentDeadlineMs + taskTimeoutMs,
+                StreamsConfig.TASK_TIMEOUT_MS_CONFIG
+            ));
+        }
+
+        return currentDeadlineMs;
+    }
+
     @Override
     public void flush() {
         log.debug("Flushing all global globalStores registered in the state manager");
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStreamThread.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStreamThread.java
index 236940ca..ab4b57a 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStreamThread.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStreamThread.java
@@ -337,17 +337,20 @@ public class GlobalStreamThread extends Thread {
         try {
             final GlobalStateManager stateMgr = new GlobalStateManagerImpl(
                 logContext,
+                time,
                 topology,
                 globalConsumer,
                 stateDirectory,
                 stateRestoreListener,
-                config);
+                config
+            );
 
             final GlobalProcessorContextImpl globalProcessorContext = new GlobalProcessorContextImpl(
                 config,
                 stateMgr,
                 streamsMetrics,
-                cache);
+                cache
+            );
             stateMgr.setGlobalProcessorContext(globalProcessorContext);
 
             final StateConsumer stateConsumer = new StateConsumer(
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/InternalTopicManager.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/InternalTopicManager.java
index b89d4a27..42a0016 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/InternalTopicManager.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/InternalTopicManager.java
@@ -30,7 +30,6 @@ import org.apache.kafka.common.utils.LogContext;
 import org.apache.kafka.common.utils.Utils;
 import org.apache.kafka.streams.StreamsConfig;
 import org.apache.kafka.streams.errors.StreamsException;
-import org.apache.kafka.streams.processor.internals.ClientUtils.QuietAdminClientConfig;
 import org.slf4j.Logger;
 
 import java.util.HashMap;
@@ -64,7 +63,7 @@ public class InternalTopicManager {
 
         replicationFactor = streamsConfig.getInt(StreamsConfig.REPLICATION_FACTOR_CONFIG).shortValue();
         windowChangeLogAdditionalRetention = streamsConfig.getLong(StreamsConfig.WINDOW_STORE_CHANGE_LOG_ADDITIONAL_RETENTION_MS_CONFIG);
-        final QuietAdminClientConfig adminConfigs = new QuietAdminClientConfig(streamsConfig);
+        final AdminClientConfig adminConfigs = new ClientUtils.QuietAdminClientConfig(streamsConfig);
         retries = adminConfigs.getInt(AdminClientConfig.RETRIES_CONFIG);
         retryBackOffMs = adminConfigs.getLong(AdminClientConfig.RETRY_BACKOFF_MS_CONFIG);
 
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/AssignorConfiguration.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/AssignorConfiguration.java
index 2d510a6..1fd83ba 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/AssignorConfiguration.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/AssignorConfiguration.java
@@ -27,7 +27,7 @@ import org.apache.kafka.common.utils.Time;
 import org.apache.kafka.common.utils.Utils;
 import org.apache.kafka.streams.StreamsConfig;
 import org.apache.kafka.streams.StreamsConfig.InternalConfig;
-import org.apache.kafka.streams.internals.QuietStreamsConfig;
+import org.apache.kafka.streams.processor.internals.ClientUtils;
 import org.apache.kafka.streams.processor.internals.InternalTopicManager;
 import org.apache.kafka.streams.processor.internals.StreamsMetadataState;
 import org.apache.kafka.streams.processor.internals.TaskManager;
@@ -57,7 +57,7 @@ public final class AssignorConfiguration {
         // NOTE: If you add a new config to pass through to here, be sure to test it in a real
         // application. Since we filter out some configurations, we may have to explicitly copy
         // them over when we construct the Consumer.
-        streamsConfig = new QuietStreamsConfig(configs);
+        streamsConfig = new ClientUtils.QuietStreamsConfig(configs);
         internalConfigs = configs;
 
         // Setting the logger with the passed in client thread name
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImplTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImplTest.java
index d937473..2b630a1 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImplTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImplTest.java
@@ -16,7 +16,9 @@
  */
 package org.apache.kafka.streams.processor.internals;
 
+import org.apache.kafka.clients.consumer.ConsumerConfig;
 import org.apache.kafka.clients.consumer.ConsumerRecord;
+import org.apache.kafka.clients.consumer.ConsumerRecords;
 import org.apache.kafka.clients.consumer.MockConsumer;
 import org.apache.kafka.clients.consumer.OffsetResetStrategy;
 import org.apache.kafka.common.PartitionInfo;
@@ -47,7 +49,9 @@ import java.io.File;
 import java.io.IOException;
 import java.io.OutputStream;
 import java.nio.file.Files;
+import java.time.Duration;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
@@ -58,10 +62,13 @@ import java.util.Set;
 import java.util.concurrent.atomic.AtomicInteger;
 
 import static java.util.Arrays.asList;
+import static org.apache.kafka.common.utils.Utils.mkEntry;
+import static org.apache.kafka.common.utils.Utils.mkMap;
 import static org.apache.kafka.test.MockStateRestoreListener.RESTORE_BATCH;
 import static org.apache.kafka.test.MockStateRestoreListener.RESTORE_END;
 import static org.apache.kafka.test.MockStateRestoreListener.RESTORE_START;
 import static org.hamcrest.CoreMatchers.equalTo;
+import static org.hamcrest.CoreMatchers.instanceOf;
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
@@ -130,11 +137,13 @@ public class GlobalStateManagerImplTest {
         consumer = new MockConsumer<>(OffsetResetStrategy.NONE);
         stateManager = new GlobalStateManagerImpl(
             new LogContext("test"),
+            time,
             topology,
             consumer,
             stateDirectory,
             stateRestoreListener,
-            streamsConfig);
+            streamsConfig
+        );
         processorContext = new InternalMockProcessorContext(stateDirectory.globalStateDir(), streamsConfig);
         stateManager.setGlobalProcessorContext(processorContext);
         checkpointFile = new File(stateManager.baseDir(), StateManagerUtil.CHECKPOINT_FILE_NAME);
@@ -340,7 +349,7 @@ public class GlobalStateManagerImplTest {
     }
 
     @Test
-    public void shouldRestoreRecordsFromCheckpointToHighwatermark() throws IOException {
+    public void shouldRestoreRecordsFromCheckpointToHighWatermark() throws IOException {
         initializeConsumer(5, 5, t1);
 
         final OffsetCheckpoint offsetCheckpoint = new OffsetCheckpoint(new File(stateManager.baseDir(),
@@ -576,6 +585,7 @@ public class GlobalStateManagerImplTest {
     public void shouldThrowLockExceptionIfIOExceptionCaughtWhenTryingToLockStateDir() {
         stateManager = new GlobalStateManagerImpl(
             new LogContext("mock"),
+            time,
             topology,
             consumer,
             new StateDirectory(streamsConfig, time, true) {
@@ -596,72 +606,602 @@ public class GlobalStateManagerImplTest {
         }
     }
 
-    @SuppressWarnings("deprecation") // TODO revisit in follow up PR
     @Test
-    public void shouldRetryWhenEndOffsetsThrowsTimeoutException() {
-        final int retries = 2;
+    public void shouldNotRetryWhenEndOffsetsThrowsTimeoutExceptionAndTaskTimeoutIsZero() {
         final AtomicInteger numberOfCalls = new AtomicInteger(0);
         consumer = new MockConsumer<byte[], byte[]>(OffsetResetStrategy.EARLIEST) {
             @Override
-            public synchronized Map<TopicPartition, Long> endOffsets(final Collection<org.apache.kafka.common.TopicPartition> partitions) {
+            public synchronized Map<TopicPartition, Long> endOffsets(final Collection<TopicPartition> partitions) {
                 numberOfCalls.incrementAndGet();
-                throw new TimeoutException();
+                throw new TimeoutException("KABOOM!");
             }
         };
-        streamsConfig = new StreamsConfig(new Properties() {
-            {
-                put(StreamsConfig.APPLICATION_ID_CONFIG, "appId");
-                put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "dummy:1234");
-                put(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath());
-                put(StreamsConfig.RETRIES_CONFIG, retries);
+        initializeConsumer(0, 0, t1, t2, t3, t4);
+
+        streamsConfig = new StreamsConfig(mkMap(
+            mkEntry(StreamsConfig.APPLICATION_ID_CONFIG, "appId"),
+            mkEntry(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "dummy:1234"),
+            mkEntry(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath()),
+            mkEntry(StreamsConfig.TASK_TIMEOUT_MS_CONFIG, 0L)
+        ));
+
+        stateManager = new GlobalStateManagerImpl(
+            new LogContext("mock"),
+            time,
+            topology,
+            consumer,
+            stateDirectory,
+            stateRestoreListener,
+            streamsConfig
+        );
+        processorContext.setStateManger(stateManager);
+        stateManager.setGlobalProcessorContext(processorContext);
+
+        final StreamsException expected = assertThrows(
+            StreamsException.class,
+            () -> stateManager.initialize()
+        );
+        final Throwable cause = expected.getCause();
+        assertThat(cause, instanceOf(TimeoutException.class));
+        assertThat(cause.getMessage(), equalTo("KABOOM!"));
+
+        assertEquals(numberOfCalls.get(), 1);
+    }
+
+    @Test
+    public void shouldRetryAtLeastOnceWhenEndOffsetsThrowsTimeoutException() {
+        final AtomicInteger numberOfCalls = new AtomicInteger(0);
+        consumer = new MockConsumer<byte[], byte[]>(OffsetResetStrategy.EARLIEST) {
+            @Override
+            public synchronized Map<TopicPartition, Long> endOffsets(final Collection<TopicPartition> partitions) {
+                time.sleep(100L);
+                numberOfCalls.incrementAndGet();
+                throw new TimeoutException("KABOOM!");
             }
-        });
+        };
+        initializeConsumer(0, 0, t1, t2, t3, t4);
 
-        try {
-            new GlobalStateManagerImpl(
-                new LogContext("mock"),
-                topology,
-                consumer,
-                stateDirectory,
-                stateRestoreListener,
-                streamsConfig);
-        } catch (final StreamsException expected) {
-            assertEquals(numberOfCalls.get(), retries);
-        }
+        streamsConfig = new StreamsConfig(mkMap(
+            mkEntry(StreamsConfig.APPLICATION_ID_CONFIG, "appId"),
+            mkEntry(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "dummy:1234"),
+            mkEntry(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath()),
+            mkEntry(StreamsConfig.TASK_TIMEOUT_MS_CONFIG, 1L)
+        ));
+
+        stateManager = new GlobalStateManagerImpl(
+            new LogContext("mock"),
+            time,
+            topology,
+            consumer,
+            stateDirectory,
+            stateRestoreListener,
+            streamsConfig
+        );
+        processorContext.setStateManger(stateManager);
+        stateManager.setGlobalProcessorContext(processorContext);
+
+        final TimeoutException expected = assertThrows(
+            TimeoutException.class,
+            () -> stateManager.initialize()
+        );
+        assertThat(expected.getMessage(), equalTo("Global task did not make progress to restore state within 100 ms. Adjust `task.timeout.ms` if needed."));
+
+        assertEquals(numberOfCalls.get(), 2);
     }
 
-    @SuppressWarnings("deprecation") // TODO revisit in follow up PR
     @Test
-    public void shouldRetryWhenPartitionsForThrowsTimeoutException() {
-        final int retries = 2;
+    public void shouldRetryWhenEndOffsetsThrowsTimeoutExceptionUntilTaskTimeoutExpired() {
         final AtomicInteger numberOfCalls = new AtomicInteger(0);
         consumer = new MockConsumer<byte[], byte[]>(OffsetResetStrategy.EARLIEST) {
             @Override
-            public synchronized List<PartitionInfo> partitionsFor(final String topic) {
+            public synchronized Map<TopicPartition, Long> endOffsets(final Collection<TopicPartition> partitions) {
+                time.sleep(100L);
                 numberOfCalls.incrementAndGet();
-                throw new TimeoutException();
+                throw new TimeoutException("KABOOM!");
             }
         };
-        streamsConfig = new StreamsConfig(new Properties() {
-            {
-                put(StreamsConfig.APPLICATION_ID_CONFIG, "appId");
-                put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "dummy:1234");
-                put(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath());
-                put(StreamsConfig.RETRIES_CONFIG, retries);
+        initializeConsumer(0, 0, t1, t2, t3, t4);
+
+        streamsConfig = new StreamsConfig(mkMap(
+            mkEntry(StreamsConfig.APPLICATION_ID_CONFIG, "appId"),
+            mkEntry(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "dummy:1234"),
+            mkEntry(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath()),
+            mkEntry(StreamsConfig.TASK_TIMEOUT_MS_CONFIG, 1000L)
+        ));
+
+        stateManager = new GlobalStateManagerImpl(
+            new LogContext("mock"),
+            time,
+            topology,
+            consumer,
+            stateDirectory,
+            stateRestoreListener,
+            streamsConfig
+        );
+        processorContext.setStateManger(stateManager);
+        stateManager.setGlobalProcessorContext(processorContext);
+
+        final TimeoutException expected = assertThrows(
+            TimeoutException.class,
+            () -> stateManager.initialize()
+        );
+        assertThat(expected.getMessage(), equalTo("Global task did not make progress to restore state within 1000 ms. Adjust `task.timeout.ms` if needed."));
+
+        assertEquals(numberOfCalls.get(), 11);
+    }
+
+    @Test
+    public void shouldNotFailOnSlowProgressWhenEndOffsetsThrowsTimeoutException() {
+        final AtomicInteger numberOfCalls = new AtomicInteger(0);
+        consumer = new MockConsumer<byte[], byte[]>(OffsetResetStrategy.EARLIEST) {
+            @Override
+            public synchronized Map<TopicPartition, Long> endOffsets(final Collection<TopicPartition> partitions) {
+                time.sleep(1L);
+                if (numberOfCalls.incrementAndGet() % 3 == 0) {
+                    return super.endOffsets(partitions);
+                }
+                throw new TimeoutException("KABOOM!");
             }
-        });
 
-        try {
-            new GlobalStateManagerImpl(
-                new LogContext("mock"),
-                topology,
-                consumer,
-                stateDirectory,
-                stateRestoreListener,
-                streamsConfig);
-        } catch (final StreamsException expected) {
-            assertEquals(numberOfCalls.get(), retries);
-        }
+            @Override
+            public synchronized long position(final TopicPartition partition) {
+                return numberOfCalls.incrementAndGet();
+            }
+        };
+        initializeConsumer(0, 0, t1, t2, t3, t4);
+
+        streamsConfig = new StreamsConfig(mkMap(
+            mkEntry(StreamsConfig.APPLICATION_ID_CONFIG, "appId"),
+            mkEntry(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "dummy:1234"),
+            mkEntry(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath()),
+            mkEntry(StreamsConfig.TASK_TIMEOUT_MS_CONFIG, 10L)
+        ));
+
+        stateManager = new GlobalStateManagerImpl(
+            new LogContext("mock"),
+            time,
+            topology,
+            consumer,
+            stateDirectory,
+            stateRestoreListener,
+            streamsConfig
+        );
+        processorContext.setStateManger(stateManager);
+        stateManager.setGlobalProcessorContext(processorContext);
+
+        stateManager.initialize();
+    }
+
+    @Test
+    public void shouldNotRetryWhenPartitionsForThrowsTimeoutExceptionAndTaskTimeoutIsZero() {
+        final AtomicInteger numberOfCalls = new AtomicInteger(0);
+        consumer = new MockConsumer<byte[], byte[]>(OffsetResetStrategy.EARLIEST) {
+            @Override
+            public List<PartitionInfo> partitionsFor(final String topic) {
+                numberOfCalls.incrementAndGet();
+                throw new TimeoutException("KABOOM!");
+            }
+        };
+        initializeConsumer(0, 0, t1, t2, t3, t4);
+
+        streamsConfig = new StreamsConfig(mkMap(
+            mkEntry(StreamsConfig.APPLICATION_ID_CONFIG, "appId"),
+            mkEntry(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "dummy:1234"),
+            mkEntry(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath()),
+            mkEntry(StreamsConfig.TASK_TIMEOUT_MS_CONFIG, 0L)
+        ));
+
+        stateManager = new GlobalStateManagerImpl(
+            new LogContext("mock"),
+            time,
+            topology,
+            consumer,
+            stateDirectory,
+            stateRestoreListener,
+            streamsConfig
+        );
+        processorContext.setStateManger(stateManager);
+        stateManager.setGlobalProcessorContext(processorContext);
+
+        final StreamsException expected = assertThrows(
+            StreamsException.class,
+            () -> stateManager.initialize()
+        );
+        final Throwable cause = expected.getCause();
+        assertThat(cause, instanceOf(TimeoutException.class));
+        assertThat(cause.getMessage(), equalTo("KABOOM!"));
+
+        assertEquals(numberOfCalls.get(), 1);
+    }
+
+    @Test
+    public void shouldRetryAtLeastOnceWhenPartitionsForThrowsTimeoutException() {
+        final AtomicInteger numberOfCalls = new AtomicInteger(0);
+        consumer = new MockConsumer<byte[], byte[]>(OffsetResetStrategy.EARLIEST) {
+            @Override
+            public List<PartitionInfo> partitionsFor(final String topic) {
+                time.sleep(100L);
+                numberOfCalls.incrementAndGet();
+                throw new TimeoutException("KABOOM!");
+            }
+        };
+        initializeConsumer(0, 0, t1, t2, t3, t4);
+
+        streamsConfig = new StreamsConfig(mkMap(
+            mkEntry(StreamsConfig.APPLICATION_ID_CONFIG, "appId"),
+            mkEntry(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "dummy:1234"),
+            mkEntry(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath()),
+            mkEntry(StreamsConfig.TASK_TIMEOUT_MS_CONFIG, 1L)
+        ));
+
+        stateManager = new GlobalStateManagerImpl(
+            new LogContext("mock"),
+            time,
+            topology,
+            consumer,
+            stateDirectory,
+            stateRestoreListener,
+            streamsConfig
+        );
+        processorContext.setStateManger(stateManager);
+        stateManager.setGlobalProcessorContext(processorContext);
+
+        final TimeoutException expected = assertThrows(
+            TimeoutException.class,
+            () -> stateManager.initialize()
+        );
+        assertThat(expected.getMessage(), equalTo("Global task did not make progress to restore state within 100 ms. Adjust `task.timeout.ms` if needed."));
+
+        assertEquals(numberOfCalls.get(), 2);
+    }
+
+    @Test
+    public void shouldRetryWhenPartitionsForThrowsTimeoutExceptionUntilTaskTimeoutExpires() {
+        final AtomicInteger numberOfCalls = new AtomicInteger(0);
+        consumer = new MockConsumer<byte[], byte[]>(OffsetResetStrategy.EARLIEST) {
+            @Override
+            public List<PartitionInfo> partitionsFor(final String topic) {
+                time.sleep(100L);
+                numberOfCalls.incrementAndGet();
+                throw new TimeoutException("KABOOM!");
+            }
+        };
+        initializeConsumer(0, 0, t1, t2, t3, t4);
+
+        streamsConfig = new StreamsConfig(mkMap(
+            mkEntry(StreamsConfig.APPLICATION_ID_CONFIG, "appId"),
+            mkEntry(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "dummy:1234"),
+            mkEntry(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath()),
+            mkEntry(StreamsConfig.TASK_TIMEOUT_MS_CONFIG, 1000L)
+        ));
+
+        stateManager = new GlobalStateManagerImpl(
+            new LogContext("mock"),
+            time,
+            topology,
+            consumer,
+            stateDirectory,
+            stateRestoreListener,
+            streamsConfig
+        );
+        processorContext.setStateManger(stateManager);
+        stateManager.setGlobalProcessorContext(processorContext);
+
+        final TimeoutException expected = assertThrows(
+            TimeoutException.class,
+            () -> stateManager.initialize()
+        );
+        assertThat(expected.getMessage(), equalTo("Global task did not make progress to restore state within 1000 ms. Adjust `task.timeout.ms` if needed."));
+
+        assertEquals(numberOfCalls.get(), 11);
+    }
+
+    @Test
+    public void shouldNotFailOnSlowProgressWhenPartitionForThrowsTimeoutException() {
+        final AtomicInteger numberOfCalls = new AtomicInteger(0);
+        consumer = new MockConsumer<byte[], byte[]>(OffsetResetStrategy.EARLIEST) {
+            @Override
+            public List<PartitionInfo> partitionsFor(final String topic) {
+                time.sleep(1L);
+                if (numberOfCalls.incrementAndGet() % 3 == 0) {
+                    return super.partitionsFor(topic);
+                }
+                throw new TimeoutException("KABOOM!");
+            }
+
+            @Override
+            public synchronized long position(final TopicPartition partition) {
+                return numberOfCalls.incrementAndGet();
+            }
+        };
+        initializeConsumer(0, 0, t1, t2, t3, t4);
+
+        streamsConfig = new StreamsConfig(mkMap(
+            mkEntry(StreamsConfig.APPLICATION_ID_CONFIG, "appId"),
+            mkEntry(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "dummy:1234"),
+            mkEntry(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath()),
+            mkEntry(StreamsConfig.TASK_TIMEOUT_MS_CONFIG, 10L)
+        ));
+
+        stateManager = new GlobalStateManagerImpl(
+            new LogContext("mock"),
+            time,
+            topology,
+            consumer,
+            stateDirectory,
+            stateRestoreListener,
+            streamsConfig
+        );
+        processorContext.setStateManger(stateManager);
+        stateManager.setGlobalProcessorContext(processorContext);
+
+        stateManager.initialize();
+    }
+
+    @Test
+    public void shouldNotRetryWhenPositionThrowsTimeoutExceptionAndTaskTimeoutIsZero() {
+        final AtomicInteger numberOfCalls = new AtomicInteger(0);
+        consumer = new MockConsumer<byte[], byte[]>(OffsetResetStrategy.EARLIEST) {
+            @Override
+            public synchronized long position(final TopicPartition partition) {
+                numberOfCalls.incrementAndGet();
+                throw new TimeoutException("KABOOM!");
+            }
+        };
+        initializeConsumer(0, 0, t1, t2, t3, t4);
+
+        streamsConfig = new StreamsConfig(mkMap(
+            mkEntry(StreamsConfig.APPLICATION_ID_CONFIG, "appId"),
+            mkEntry(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "dummy:1234"),
+            mkEntry(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath()),
+            mkEntry(StreamsConfig.TASK_TIMEOUT_MS_CONFIG, 0L)
+        ));
+
+        stateManager = new GlobalStateManagerImpl(
+            new LogContext("mock"),
+            time,
+            topology,
+            consumer,
+            stateDirectory,
+            stateRestoreListener,
+            streamsConfig
+        );
+        processorContext.setStateManger(stateManager);
+        stateManager.setGlobalProcessorContext(processorContext);
+
+        final StreamsException expected = assertThrows(
+            StreamsException.class,
+            () -> stateManager.initialize()
+        );
+        final Throwable cause = expected.getCause();
+        assertThat(cause, instanceOf(TimeoutException.class));
+        assertThat(cause.getMessage(), equalTo("KABOOM!"));
+
+        assertEquals(numberOfCalls.get(), 1);
+    }
+
+    @Test
+    public void shouldRetryAtLeastOnceWhenPositionThrowsTimeoutException() {
+        final AtomicInteger numberOfCalls = new AtomicInteger(0);
+        consumer = new MockConsumer<byte[], byte[]>(OffsetResetStrategy.EARLIEST) {
+            @Override
+            public synchronized long position(final TopicPartition partition) {
+                time.sleep(100L);
+                numberOfCalls.incrementAndGet();
+                throw new TimeoutException("KABOOM!");
+            }
+        };
+        initializeConsumer(0, 0, t1, t2, t3, t4);
+
+        streamsConfig = new StreamsConfig(mkMap(
+            mkEntry(StreamsConfig.APPLICATION_ID_CONFIG, "appId"),
+            mkEntry(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "dummy:1234"),
+            mkEntry(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath()),
+            mkEntry(StreamsConfig.TASK_TIMEOUT_MS_CONFIG, 1L)
+        ));
+
+        stateManager = new GlobalStateManagerImpl(
+            new LogContext("mock"),
+            time,
+            topology,
+            consumer,
+            stateDirectory,
+            stateRestoreListener,
+            streamsConfig
+        );
+        processorContext.setStateManger(stateManager);
+        stateManager.setGlobalProcessorContext(processorContext);
+
+        final TimeoutException expected = assertThrows(
+            TimeoutException.class,
+            () -> stateManager.initialize()
+        );
+        assertThat(expected.getMessage(), equalTo("Global task did not make progress to restore state within 100 ms. Adjust `task.timeout.ms` if needed."));
+
+        assertEquals(numberOfCalls.get(), 2);
+    }
+
+    @Test
+    public void shouldRetryWhenPositionThrowsTimeoutExceptionUntilTaskTimeoutExpired() {
+        final AtomicInteger numberOfCalls = new AtomicInteger(0);
+        consumer = new MockConsumer<byte[], byte[]>(OffsetResetStrategy.EARLIEST) {
+            @Override
+            public synchronized long position(final TopicPartition partition) {
+                time.sleep(100L);
+                numberOfCalls.incrementAndGet();
+                throw new TimeoutException("KABOOM!");
+            }
+        };
+        initializeConsumer(0, 0, t1, t2, t3, t4);
+
+        streamsConfig = new StreamsConfig(mkMap(
+            mkEntry(StreamsConfig.APPLICATION_ID_CONFIG, "appId"),
+            mkEntry(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "dummy:1234"),
+            mkEntry(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath()),
+            mkEntry(StreamsConfig.TASK_TIMEOUT_MS_CONFIG, 1000L)
+        ));
+
+        stateManager = new GlobalStateManagerImpl(
+            new LogContext("mock"),
+            time,
+            topology,
+            consumer,
+            stateDirectory,
+            stateRestoreListener,
+            streamsConfig
+        );
+        processorContext.setStateManger(stateManager);
+        stateManager.setGlobalProcessorContext(processorContext);
+
+        final TimeoutException expected = assertThrows(
+            TimeoutException.class,
+            () -> stateManager.initialize()
+        );
+        assertThat(expected.getMessage(), equalTo("Global task did not make progress to restore state within 1000 ms. Adjust `task.timeout.ms` if needed."));
+
+        assertEquals(numberOfCalls.get(), 11);
+    }
+
+    @Test
+    public void shouldNotFailOnSlowProgressWhenPositionThrowsTimeoutException() {
+        final AtomicInteger numberOfCalls = new AtomicInteger(0);
+        consumer = new MockConsumer<byte[], byte[]>(OffsetResetStrategy.EARLIEST) {
+            @Override
+            public synchronized long position(final TopicPartition partition) {
+                time.sleep(1L);
+                if (numberOfCalls.incrementAndGet() % 3 == 0) {
+                    return numberOfCalls.incrementAndGet();
+                }
+                throw new TimeoutException("KABOOM!");
+            }
+        };
+        initializeConsumer(0, 0, t1, t2, t3, t4);
+
+        streamsConfig = new StreamsConfig(mkMap(
+            mkEntry(StreamsConfig.APPLICATION_ID_CONFIG, "appId"),
+            mkEntry(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "dummy:1234"),
+            mkEntry(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath()),
+            mkEntry(StreamsConfig.TASK_TIMEOUT_MS_CONFIG, 10L)
+        ));
+
+        stateManager = new GlobalStateManagerImpl(
+            new LogContext("mock"),
+            time,
+            topology,
+            consumer,
+            stateDirectory,
+            stateRestoreListener,
+            streamsConfig
+        );
+        processorContext.setStateManger(stateManager);
+        stateManager.setGlobalProcessorContext(processorContext);
+
+        stateManager.initialize();
+    }
+
+    @Test
+    public void shouldUseRequestTimeoutPlusTaskTimeoutInPollDuringRestoreAndFailIfNoDataReturned() {
+        consumer = new MockConsumer<byte[], byte[]>(OffsetResetStrategy.EARLIEST) {
+            @Override
+            public synchronized ConsumerRecords<byte[], byte[]> poll(final Duration timeout) {
+                time.sleep(timeout.toMillis());
+                return super.poll(timeout);
+            }
+        };
+
+        final HashMap<TopicPartition, Long> startOffsets = new HashMap<>();
+        startOffsets.put(t1, 1L);
+        final HashMap<TopicPartition, Long> endOffsets = new HashMap<>();
+        endOffsets.put(t1, 3L);
+        consumer.updatePartitions(t1.topic(), Collections.singletonList(new PartitionInfo(t1.topic(), t1.partition(), null, null, null)));
+        consumer.assign(Collections.singletonList(t1));
+        consumer.updateBeginningOffsets(startOffsets);
+        consumer.updateEndOffsets(endOffsets);
+
+        streamsConfig = new StreamsConfig(mkMap(
+            mkEntry(StreamsConfig.APPLICATION_ID_CONFIG, "appId"),
+            mkEntry(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "dummy:1234"),
+            mkEntry(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath()),
+            mkEntry(StreamsConfig.POLL_MS_CONFIG, 5L),
+            mkEntry(StreamsConfig.TASK_TIMEOUT_MS_CONFIG, 10L),
+            mkEntry(ConsumerConfig.REQUEST_TIMEOUT_MS_CONFIG, 100)
+        ));
+
+        stateManager = new GlobalStateManagerImpl(
+            new LogContext("mock"),
+            time,
+            topology,
+            consumer,
+            stateDirectory,
+            stateRestoreListener,
+            streamsConfig
+        );
+        processorContext.setStateManger(stateManager);
+        stateManager.setGlobalProcessorContext(processorContext);
+
+        final long startTime = time.milliseconds();
+        final TimeoutException exception = assertThrows(
+            TimeoutException.class,
+            () -> stateManager.initialize()
+        );
+        assertThat(
+            exception.getMessage(),
+            equalTo("Global task did not make progress to restore state within 10 ms. Adjust `task.timeout.ms` if needed.")
+        );
+
+        assertThat(time.milliseconds() - startTime, equalTo(110L));
+
+    }
+
+    @Test
+    public void shouldTimeoutWhenNoProgressDuringRestore() {
+        consumer = new MockConsumer<byte[], byte[]>(OffsetResetStrategy.EARLIEST) {
+            @Override
+            public synchronized ConsumerRecords<byte[], byte[]> poll(final Duration timeout) {
+                time.sleep(1L);
+                return super.poll(timeout);
+            }
+        };
+
+        final HashMap<TopicPartition, Long> startOffsets = new HashMap<>();
+        startOffsets.put(t1, 1L);
+        final HashMap<TopicPartition, Long> endOffsets = new HashMap<>();
+        endOffsets.put(t1, 3L);
+        consumer.updatePartitions(t1.topic(), Collections.singletonList(new PartitionInfo(t1.topic(), t1.partition(), null, null, null)));
+        consumer.assign(Collections.singletonList(t1));
+        consumer.updateBeginningOffsets(startOffsets);
+        consumer.updateEndOffsets(endOffsets);
+
+        streamsConfig = new StreamsConfig(mkMap(
+            mkEntry(StreamsConfig.APPLICATION_ID_CONFIG, "appId"),
+            mkEntry(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "dummy:1234"),
+            mkEntry(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath()),
+            mkEntry(StreamsConfig.TASK_TIMEOUT_MS_CONFIG, 5L)
+        ));
+
+        stateManager = new GlobalStateManagerImpl(
+            new LogContext("mock"),
+            time,
+            topology,
+            consumer,
+            stateDirectory,
+            stateRestoreListener,
+            streamsConfig
+        );
+        processorContext.setStateManger(stateManager);
+        stateManager.setGlobalProcessorContext(processorContext);
+
+        final long startTime = time.milliseconds();
+
+        final TimeoutException exception = assertThrows(
+            TimeoutException.class,
+            () -> stateManager.initialize()
+        );
+        assertThat(
+            exception.getMessage(),
+            equalTo("Global task did not make progress to restore state within 5 ms. Adjust `task.timeout.ms` if needed.")
+        );
+        assertThat(time.milliseconds() - startTime, equalTo(1L));
     }
 
     private void writeCorruptCheckpoint() throws IOException {
@@ -671,19 +1211,21 @@ public class GlobalStateManagerImplTest {
         }
     }
 
-    private void initializeConsumer(final long numRecords, final long startOffset, final TopicPartition topicPartition) {
-        final HashMap<TopicPartition, Long> startOffsets = new HashMap<>();
-        startOffsets.put(topicPartition, startOffset);
-        final HashMap<TopicPartition, Long> endOffsets = new HashMap<>();
-        endOffsets.put(topicPartition, startOffset + numRecords);
-        consumer.updatePartitions(topicPartition.topic(), Collections.singletonList(new PartitionInfo(topicPartition.topic(), topicPartition.partition(), null, null, null)));
-        consumer.assign(Collections.singletonList(topicPartition));
+    private void initializeConsumer(final long numRecords, final long startOffset, final TopicPartition... topicPartitions) {
+        consumer.assign(Arrays.asList(topicPartitions));
+
+        final Map<TopicPartition, Long> startOffsets = new HashMap<>();
+        final Map<TopicPartition, Long> endOffsets = new HashMap<>();
+        for (final TopicPartition topicPartition : topicPartitions) {
+            startOffsets.put(topicPartition, startOffset);
+            endOffsets.put(topicPartition, startOffset + numRecords);
+            consumer.updatePartitions(topicPartition.topic(), Collections.singletonList(new PartitionInfo(topicPartition.topic(), topicPartition.partition(), null, null, null)));
+            for (int i = 0; i < numRecords; i++) {
+                consumer.addRecord(new ConsumerRecord<>(topicPartition.topic(), topicPartition.partition(), startOffset + i, "key".getBytes(), "value".getBytes()));
+            }
+        }
         consumer.updateEndOffsets(endOffsets);
         consumer.updateBeginningOffsets(startOffsets);
-
-        for (int i = 0; i < numRecords; i++) {
-            consumer.addRecord(new ConsumerRecord<>(topicPartition.topic(), topicPartition.partition(), startOffset + i, "key".getBytes(), "value".getBytes()));
-        }
     }
 
     private Map<TopicPartition, Long> writeCheckpoint() throws IOException {
diff --git a/streams/src/test/java/org/apache/kafka/test/InternalMockProcessorContext.java b/streams/src/test/java/org/apache/kafka/test/InternalMockProcessorContext.java
index bf2bcf0..4c782a4 100644
--- a/streams/src/test/java/org/apache/kafka/test/InternalMockProcessorContext.java
+++ b/streams/src/test/java/org/apache/kafka/test/InternalMockProcessorContext.java
@@ -62,6 +62,7 @@ public class InternalMockProcessorContext
     extends AbstractProcessorContext
     implements RecordCollector.Supplier {
 
+    private StateManager stateManager = new StateManagerStub();
     private final File stateDir;
     private final RecordCollector.Supplier recordCollectorSupplier;
     private final Map<String, StateStore> storeMap = new LinkedHashMap<>();
@@ -197,7 +198,11 @@ public class InternalMockProcessorContext
 
     @Override
     protected StateManager stateManager() {
-        return new StateManagerStub();
+        return stateManager;
+    }
+
+    public void setStateManger(final StateManager stateManger) {
+        this.stateManager = stateManger;
     }
 
     @Override
@@ -245,6 +250,7 @@ public class InternalMockProcessorContext
                          final StateRestoreCallback func) {
         storeMap.put(store.name(), store);
         restoreFuncs.put(store.name(), func);
+        stateManager().registerStore(store, func);
     }
 
     @Override
diff --git a/streams/src/test/java/org/apache/kafka/test/NoOpReadOnlyStore.java b/streams/src/test/java/org/apache/kafka/test/NoOpReadOnlyStore.java
index dbdd0b4..dd78c52 100644
--- a/streams/src/test/java/org/apache/kafka/test/NoOpReadOnlyStore.java
+++ b/streams/src/test/java/org/apache/kafka/test/NoOpReadOnlyStore.java
@@ -78,6 +78,7 @@ public class NoOpReadOnlyStore<K, V> implements ReadOnlyKeyValueStore<K, V>, Sta
             new File(context.stateDir() + File.separator + name).mkdir();
         }
         this.initialized = true;
+        context.register(root, (k, v) -> { });
     }
 
     @Override
diff --git a/streams/test-utils/src/main/java/org/apache/kafka/streams/TopologyTestDriver.java b/streams/test-utils/src/main/java/org/apache/kafka/streams/TopologyTestDriver.java
index 225fd74..1c52d9e 100644
--- a/streams/test-utils/src/main/java/org/apache/kafka/streams/TopologyTestDriver.java
+++ b/streams/test-utils/src/main/java/org/apache/kafka/streams/TopologyTestDriver.java
@@ -43,7 +43,6 @@ import org.apache.kafka.common.utils.Time;
 import org.apache.kafka.streams.errors.LogAndContinueExceptionHandler;
 import org.apache.kafka.streams.errors.TopologyException;
 import org.apache.kafka.streams.internals.KeyValueStoreFacade;
-import org.apache.kafka.streams.internals.QuietStreamsConfig;
 import org.apache.kafka.streams.internals.WindowStoreFacade;
 import org.apache.kafka.streams.processor.ProcessorContext;
 import org.apache.kafka.streams.processor.PunctuationType;
@@ -52,6 +51,7 @@ import org.apache.kafka.streams.processor.StateRestoreListener;
 import org.apache.kafka.streams.processor.StateStore;
 import org.apache.kafka.streams.processor.TaskId;
 import org.apache.kafka.streams.processor.internals.ChangelogRegister;
+import org.apache.kafka.streams.processor.internals.ClientUtils;
 import org.apache.kafka.streams.processor.internals.GlobalProcessorContextImpl;
 import org.apache.kafka.streams.processor.internals.GlobalStateManager;
 import org.apache.kafka.streams.processor.internals.GlobalStateManagerImpl;
@@ -300,7 +300,7 @@ public class TopologyTestDriver implements Closeable {
     private TopologyTestDriver(final InternalTopologyBuilder builder,
                                final Properties config,
                                final long initialWallClockTimeMs) {
-        final StreamsConfig streamsConfig = new QuietStreamsConfig(config);
+        final StreamsConfig streamsConfig = new ClientUtils.QuietStreamsConfig(config);
         logIfTaskIdleEnabled(streamsConfig);
 
         logContext = new LogContext("topology-test-driver ");
@@ -350,7 +350,7 @@ public class TopologyTestDriver implements Closeable {
             logContext
         );
 
-        setupGlobalTask(streamsConfig, streamsMetrics, cache);
+        setupGlobalTask(mockWallClockTime, streamsConfig, streamsMetrics, cache);
         setupTask(streamsConfig, streamsMetrics, cache);
     }
 
@@ -407,7 +407,8 @@ public class TopologyTestDriver implements Closeable {
         stateDirectory = new StateDirectory(streamsConfig, mockWallClockTime, createStateDirectory);
     }
 
-    private void setupGlobalTask(final StreamsConfig streamsConfig,
+    private void setupGlobalTask(final Time mockWallClockTime,
+                                 final StreamsConfig streamsConfig,
                                  final StreamsMetricsImpl streamsMetrics,
                                  final ThreadCache cache) {
         if (globalTopology != null) {
@@ -424,6 +425,7 @@ public class TopologyTestDriver implements Closeable {
 
             globalStateManager = new GlobalStateManagerImpl(
                 logContext,
+                mockWallClockTime,
                 globalTopology,
                 globalConsumer,
                 stateDirectory,
diff --git a/streams/test-utils/src/main/java/org/apache/kafka/streams/processor/MockProcessorContext.java b/streams/test-utils/src/main/java/org/apache/kafka/streams/processor/MockProcessorContext.java
index b16eb32..fbd5ca8 100644
--- a/streams/test-utils/src/main/java/org/apache/kafka/streams/processor/MockProcessorContext.java
+++ b/streams/test-utils/src/main/java/org/apache/kafka/streams/processor/MockProcessorContext.java
@@ -27,9 +27,9 @@ import org.apache.kafka.streams.StreamsMetrics;
 import org.apache.kafka.streams.Topology;
 import org.apache.kafka.streams.TopologyTestDriver;
 import org.apache.kafka.streams.internals.ApiUtils;
-import org.apache.kafka.streams.internals.QuietStreamsConfig;
 import org.apache.kafka.streams.kstream.Transformer;
 import org.apache.kafka.streams.kstream.ValueTransformer;
+import org.apache.kafka.streams.processor.internals.ClientUtils;
 import org.apache.kafka.streams.processor.internals.metrics.TaskMetrics;
 import org.apache.kafka.streams.processor.internals.RecordCollector;
 import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl;
@@ -217,7 +217,7 @@ public class MockProcessorContext implements ProcessorContext, RecordCollector.S
      */
     @SuppressWarnings({"WeakerAccess", "unused"})
     public MockProcessorContext(final Properties config, final TaskId taskId, final File stateDir) {
-        final StreamsConfig streamsConfig = new QuietStreamsConfig(config);
+        final StreamsConfig streamsConfig = new ClientUtils.QuietStreamsConfig(config);
         this.taskId = taskId;
         this.config = streamsConfig;
         this.stateDir = stateDir;