Posted to commits@kafka.apache.org by ju...@apache.org on 2023/02/24 23:29:43 UTC

[kafka] branch trunk updated: [KAFKA-14685] Refactor logic to handle OFFSET_MOVED_TO_TIERED_STORAGE error (#13206)

This is an automated email from the ASF dual-hosted git repository.

junrao pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 8d32a0f2463 [KAFKA-14685] Refactor logic to handle OFFSET_MOVED_TO_TIERED_STORAGE error (#13206)
8d32a0f2463 is described below

commit 8d32a0f2463eaa4c0b669b8a957a7abc066c066f
Author: Matthew Wong <ma...@gmail.com>
AuthorDate: Fri Feb 24 15:29:35 2023 -0800

    [KAFKA-14685] Refactor logic to handle OFFSET_MOVED_TO_TIERED_STORAGE error (#13206)
    
    Reviewers: Rittika Adhikari <ri...@gmail.com>, Luke Chen <sh...@gmail.com>, Satish Duggana <sa...@apache.org>, Alexandre Dupriez <al...@gmail.com>, Jun Rao <ju...@gmail.com>
---
 checkstyle/import-control-core.xml                 |   6 +-
 .../ReplicaAlterLogDirsTierStateMachine.java       |  41 ++++
 .../server/ReplicaFetcherTierStateMachine.java     | 266 +++++++++++++++++++++
 .../main/java/kafka/server/TierStateMachine.java   |  58 +++++
 .../scala/kafka/server/AbstractFetcherThread.scala | 111 +++------
 .../kafka/server/ReplicaAlterLogDirsThread.scala   |  11 +-
 .../scala/kafka/server/ReplicaFetcherThread.scala  | 154 +-----------
 .../kafka/server/AbstractFetcherManagerTest.scala  |  17 +-
 .../kafka/server/AbstractFetcherThreadTest.scala   | 210 +++++++++++-----
 9 files changed, 558 insertions(+), 316 deletions(-)
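
In short, this patch moves the OFFSET_MOVED_TO_TIERED_STORAGE handling out of the fetcher threads and behind a new TierStateMachine interface: AbstractFetcherThread now only delegates to fetchTierStateMachine.start(...), ReplicaFetcherTierStateMachine owns the logic for building the remote log aux state, and ReplicaAlterLogDirsTierStateMachine rejects it because JBOD does not support tiered storage. The sketch below illustrates that delegation; TieredStorageErrorHandlerSketch and its handle method are hypothetical names used only for illustration, while TierStateMachine, PartitionFetchState, PartitionData and the error code come from the patch.

package kafka.server; // placed here only so the Kafka types resolve; illustrative, not part of the patch

import java.util.Optional;

import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.common.message.FetchResponseData.PartitionData;
import org.apache.kafka.common.protocol.Errors;

// Hypothetical helper showing how a fetcher can delegate the tiered-storage error
// to a pluggable TierStateMachine instead of building the remote log aux state itself.
public class TieredStorageErrorHandlerSketch {
    private final TierStateMachine tierStateMachine;

    public TieredStorageErrorHandlerSketch(TierStateMachine tierStateMachine) {
        this.tierStateMachine = tierStateMachine;
    }

    // Returns the next fetch state, or Optional.empty() if the partition should be
    // treated as having a retriable error (mirroring partitionsWithError in the diff).
    public Optional<PartitionFetchState> handle(TopicPartition tp,
                                                PartitionFetchState currentFetchState,
                                                PartitionData partitionData) {
        if (partitionData.errorCode() != Errors.OFFSET_MOVED_TO_TIERED_STORAGE.code())
            return Optional.of(currentFetchState);
        try {
            // Synchronous for now; KAFKA-13560 will make use of maybeAdvanceState for async tiering.
            return Optional.of(tierStateMachine.start(tp, currentFetchState, partitionData));
        } catch (Exception e) {
            return Optional.empty();
        }
    }
}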

diff --git a/checkstyle/import-control-core.xml b/checkstyle/import-control-core.xml
index 1b052b62d78..0d5935f9f2b 100644
--- a/checkstyle/import-control-core.xml
+++ b/checkstyle/import-control-core.xml
@@ -75,10 +75,8 @@
   </subpackage>
 
   <subpackage name="server">
-    <subpackage name="builders">
-      <allow pkg="kafka" />
-      <allow pkg="org.apache.kafka" />
-    </subpackage>
+    <allow pkg="kafka" />
+    <allow pkg="org.apache.kafka" />
   </subpackage>
 
   <subpackage name="test">
diff --git a/core/src/main/java/kafka/server/ReplicaAlterLogDirsTierStateMachine.java b/core/src/main/java/kafka/server/ReplicaAlterLogDirsTierStateMachine.java
new file mode 100644
index 00000000000..8561fae0199
--- /dev/null
+++ b/core/src/main/java/kafka/server/ReplicaAlterLogDirsTierStateMachine.java
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package kafka.server;
+
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.message.FetchResponseData.PartitionData;
+
+import java.util.Optional;
+
+/**
+ The tier state machine used by ReplicaAlterLogDirsThread. Tiered storage is not supported for alter-log-dirs (JBOD) replicas, so starting this state machine always throws.
+ */
+public class ReplicaAlterLogDirsTierStateMachine implements TierStateMachine {
+
+    public PartitionFetchState start(TopicPartition topicPartition,
+                                     PartitionFetchState currentFetchState,
+                                     PartitionData fetchPartitionData) throws Exception {
+        // JBOD is not supported with tiered storage.
+        throw new UnsupportedOperationException("Building remote log aux state is not supported in ReplicaAlterLogDirsThread.");
+    }
+
+    public Optional<PartitionFetchState> maybeAdvanceState(TopicPartition topicPartition,
+                                                           PartitionFetchState currentFetchState) {
+        return Optional.empty();
+    }
+}
diff --git a/core/src/main/java/kafka/server/ReplicaFetcherTierStateMachine.java b/core/src/main/java/kafka/server/ReplicaFetcherTierStateMachine.java
new file mode 100644
index 00000000000..7cebaae8fe6
--- /dev/null
+++ b/core/src/main/java/kafka/server/ReplicaFetcherTierStateMachine.java
@@ -0,0 +1,266 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package kafka.server;
+
+import java.io.BufferedReader;
+import java.io.File;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+import java.nio.charset.StandardCharsets;
+import java.nio.file.Files;
+import java.nio.file.StandardCopyOption;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+
+import kafka.cluster.Partition;
+import kafka.log.LeaderOffsetIncremented$;
+import kafka.log.UnifiedLog;
+import kafka.log.remote.RemoteLogManager;
+import org.apache.kafka.common.KafkaException;
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.message.FetchResponseData.PartitionData;
+import org.apache.kafka.common.message.OffsetForLeaderEpochResponseData.EpochEndOffset;
+import org.apache.kafka.common.message.OffsetForLeaderEpochRequestData.OffsetForLeaderPartition;
+import org.apache.kafka.common.protocol.Errors;
+import org.apache.kafka.common.utils.Utils;
+import org.apache.kafka.server.common.CheckpointFile;
+import org.apache.kafka.server.common.OffsetAndEpoch;
+import org.apache.kafka.server.log.remote.storage.RemoteLogSegmentMetadata;
+import org.apache.kafka.server.log.remote.storage.RemoteStorageException;
+import org.apache.kafka.server.log.remote.storage.RemoteStorageManager;
+import org.apache.kafka.storage.internals.checkpoint.LeaderEpochCheckpointFile;
+import org.apache.kafka.storage.internals.log.EpochEntry;
+import org.apache.kafka.storage.internals.log.LogFileUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import scala.Option;
+import scala.collection.JavaConverters;
+
+/**
+ The tier state machine for the replica fetcher thread.
+
+ Currently, the tier state machine runs synchronously; we only need to start the machine,
+ and there is no need to advance its state afterwards.
+
+ When started, the tier state machine will fetch the local log start offset of the
+ leader and then build the follower's remote log aux state until the leader's
+ local log start offset.
+ */
+public class ReplicaFetcherTierStateMachine implements TierStateMachine {
+    private static final Logger log = LoggerFactory.getLogger(ReplicaFetcherTierStateMachine.class);
+
+    private LeaderEndPoint leader;
+    private ReplicaManager replicaMgr;
+
+    public ReplicaFetcherTierStateMachine(LeaderEndPoint leader,
+                                          ReplicaManager replicaMgr) {
+        this.leader = leader;
+        this.replicaMgr = replicaMgr;
+    }
+
+
+    /**
+     * Start the tier state machine for the provided topic partition. Currently, this start method will build the
+     * entire remote log aux state synchronously.
+     *
+     * @param topicPartition the topic partition
+     * @param currentFetchState the current PartitionFetchState which will
+     *                          be used to derive the return value
+     * @param fetchPartitionData the data from the fetch response that returned the offset moved to tiered storage error
+     *
+     * @return the new PartitionFetchState after the successful start of the
+     *         tier state machine
+     */
+    public PartitionFetchState start(TopicPartition topicPartition,
+                                     PartitionFetchState currentFetchState,
+                                     PartitionData fetchPartitionData) throws Exception {
+
+        OffsetAndEpoch epochAndLeaderLocalStartOffset = leader.fetchEarliestLocalOffset(topicPartition, currentFetchState.currentLeaderEpoch());
+        int epoch = epochAndLeaderLocalStartOffset.leaderEpoch();
+        long leaderLocalStartOffset = epochAndLeaderLocalStartOffset.offset();
+
+        long offsetToFetch = buildRemoteLogAuxState(topicPartition, currentFetchState.currentLeaderEpoch(), leaderLocalStartOffset, epoch, fetchPartitionData.logStartOffset());
+
+        OffsetAndEpoch fetchLatestOffsetResult = leader.fetchLatestOffset(topicPartition, currentFetchState.currentLeaderEpoch());
+        long leaderEndOffset = fetchLatestOffsetResult.offset();
+
+        long initialLag = leaderEndOffset - offsetToFetch;
+
+        return PartitionFetchState.apply(currentFetchState.topicId(), offsetToFetch, Option.apply(initialLag), currentFetchState.currentLeaderEpoch(),
+                Fetching$.MODULE$, replicaMgr.localLogOrException(topicPartition).latestEpoch());
+    }
+
+    /**
+     * This is currently a no-op but will be used for implementing async tiering logic in KAFKA-13560.
+     *
+     * @param topicPartition the topic partition
+     * @param currentFetchState the current PartitionFetchState which will
+     *                          be used to derive the return value
+     *
+     * @return the original PartitionFetchState
+     */
+    public Optional<PartitionFetchState> maybeAdvanceState(TopicPartition topicPartition,
+                                                           PartitionFetchState currentFetchState) {
+        // No-op for now
+        return Optional.of(currentFetchState);
+    }
+
+    private EpochEndOffset fetchEarlierEpochEndOffset(Integer epoch,
+                                                      TopicPartition partition,
+                                                      Integer currentLeaderEpoch) {
+        int previousEpoch = epoch - 1;
+
+        // Find the end-offset for the epoch earlier to the given epoch from the leader
+        Map<TopicPartition, OffsetForLeaderPartition> partitionsWithEpochs = new HashMap<>();
+        partitionsWithEpochs.put(partition, new OffsetForLeaderPartition().setPartition(partition.partition()).setCurrentLeaderEpoch(currentLeaderEpoch).setLeaderEpoch(previousEpoch));
+        Option<EpochEndOffset> maybeEpochEndOffset = leader.fetchEpochEndOffsets(JavaConverters.mapAsScalaMap(partitionsWithEpochs)).get(partition);
+        if (maybeEpochEndOffset.isEmpty()) {
+            throw new KafkaException("No response received for partition: " + partition);
+        }
+
+        EpochEndOffset epochEndOffset = maybeEpochEndOffset.get();
+        if (epochEndOffset.errorCode() != Errors.NONE.code()) {
+            throw Errors.forCode(epochEndOffset.errorCode()).exception();
+        }
+
+        return epochEndOffset;
+    }
+
+    private List<EpochEntry> readLeaderEpochCheckpoint(RemoteLogManager rlm,
+                                                       RemoteLogSegmentMetadata remoteLogSegmentMetadata) throws IOException, RemoteStorageException {
+        InputStream inputStream = rlm.storageManager().fetchIndex(remoteLogSegmentMetadata, RemoteStorageManager.IndexType.LEADER_EPOCH);
+        try (BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8))) {
+            CheckpointFile.CheckpointReadBuffer<EpochEntry> readBuffer = new CheckpointFile.CheckpointReadBuffer<>("", bufferedReader, 0, LeaderEpochCheckpointFile.FORMATTER);
+            return readBuffer.read();
+        }
+    }
+
+    private void buildProducerSnapshotFile(File snapshotFile,
+                                           RemoteLogSegmentMetadata remoteLogSegmentMetadata,
+                                           RemoteLogManager rlm) throws IOException, RemoteStorageException {
+        File tmpSnapshotFile = new File(snapshotFile.getAbsolutePath() + ".tmp");
+        // Copy it to snapshot file in atomic manner.
+        Files.copy(rlm.storageManager().fetchIndex(remoteLogSegmentMetadata, RemoteStorageManager.IndexType.PRODUCER_SNAPSHOT),
+                tmpSnapshotFile.toPath(), StandardCopyOption.REPLACE_EXISTING);
+        Utils.atomicMoveWithFallback(tmpSnapshotFile.toPath(), snapshotFile.toPath(), false);
+    }
+
+    /**
+     * It tries to build the required state for this partition from the leader and remote storage so that it can start
+     * fetching records from the leader. The return value is the next offset to fetch from the leader, which is the
+     * next offset following the end offset of the remote log portion.
+     */
+    private Long buildRemoteLogAuxState(TopicPartition topicPartition,
+                                        Integer currentLeaderEpoch,
+                                        Long leaderLocalLogStartOffset,
+                                        Integer epochForLeaderLocalLogStartOffset,
+                                        Long leaderLogStartOffset) throws IOException, RemoteStorageException {
+
+        UnifiedLog unifiedLog = replicaMgr.localLogOrException(topicPartition);
+
+        long nextOffset;
+
+        if (unifiedLog.remoteStorageSystemEnable() && unifiedLog.config().remoteLogConfig.remoteStorageEnable) {
+            if (replicaMgr.remoteLogManager().isEmpty()) throw new IllegalStateException("RemoteLogManager is not yet instantiated");
+
+            RemoteLogManager rlm = replicaMgr.remoteLogManager().get();
+
+            // Find the respective leader epoch for (leaderLocalLogStartOffset - 1). We need to build the leader epoch cache
+            // until that offset
+            long previousOffsetToLeaderLocalLogStartOffset = leaderLocalLogStartOffset - 1;
+            int targetEpoch;
+            // If the existing epoch is 0, there is no need to fetch from an earlier epoch, as the desired offset
+            // (leaderLocalLogStartOffset - 1) will have the same epoch.
+            if (epochForLeaderLocalLogStartOffset == 0) {
+                targetEpoch = epochForLeaderLocalLogStartOffset;
+            } else {
+                // Fetch the earlier epoch/end-offset(exclusive) from the leader.
+                EpochEndOffset earlierEpochEndOffset = fetchEarlierEpochEndOffset(epochForLeaderLocalLogStartOffset, topicPartition, currentLeaderEpoch);
+                // Check if the target offset lies within the range of earlier epoch. Here, epoch's end-offset is exclusive.
+                if (earlierEpochEndOffset.endOffset() > previousOffsetToLeaderLocalLogStartOffset) {
+                    // Always use the leader epoch from returned earlierEpochEndOffset.
+                    // This gives the respective leader epoch, that will handle any gaps in epochs.
+                    // For ex, leader epoch cache contains:
+                    // leader-epoch   start-offset
+                    //  0               20
+                    //  1               85
+                    //  <2> - gap: no messages were appended in this leader epoch.
+                    //  3               90
+                    //  4               98
+                    // There is a gap in leader epoch. For leaderLocalLogStartOffset as 90, leader-epoch is 3.
+                    // fetchEarlierEpochEndOffset(2) will return leader-epoch as 1, end-offset as 90.
+                    // So, for offset 89, we should return leader epoch as 1 like below.
+                    targetEpoch = earlierEpochEndOffset.leaderEpoch();
+                } else {
+                    targetEpoch = epochForLeaderLocalLogStartOffset;
+                }
+            }
+
+            Optional<RemoteLogSegmentMetadata> maybeRlsm = rlm.fetchRemoteLogSegmentMetadata(topicPartition, targetEpoch, previousOffsetToLeaderLocalLogStartOffset);
+
+            if (maybeRlsm.isPresent()) {
+                RemoteLogSegmentMetadata remoteLogSegmentMetadata = maybeRlsm.get();
+                // Build leader epoch cache, producer snapshots until remoteLogSegmentMetadata.endOffset() and start
+                // segments from (remoteLogSegmentMetadata.endOffset() + 1)
+                // Assign nextOffset with the offset from which next fetch should happen.
+                nextOffset = remoteLogSegmentMetadata.endOffset() + 1;
+
+                // Truncate the existing local log before restoring the leader epoch cache and producer snapshots.
+                Partition partition = replicaMgr.getPartitionOrException(topicPartition);
+                partition.truncateFullyAndStartAt(nextOffset, false);
+
+                // Build leader epoch cache.
+                unifiedLog.maybeIncrementLogStartOffset(leaderLogStartOffset, LeaderOffsetIncremented$.MODULE$);
+                List<EpochEntry> epochs = readLeaderEpochCheckpoint(rlm, remoteLogSegmentMetadata);
+                if (unifiedLog.leaderEpochCache().isDefined()) {
+                    unifiedLog.leaderEpochCache().get().assign(epochs);
+                }
+
+                log.debug("Updated the epoch cache from remote tier till offset: {} with size: {} for {}", leaderLocalLogStartOffset, epochs.size(), partition);
+
+                // Restore producer snapshot
+                File snapshotFile = LogFileUtils.producerSnapshotFile(unifiedLog.dir(), nextOffset);
+                buildProducerSnapshotFile(snapshotFile, remoteLogSegmentMetadata, rlm);
+
+                // Reload producer snapshots.
+                unifiedLog.producerStateManager().truncateFullyAndReloadSnapshots();
+                unifiedLog.loadProducerState(nextOffset);
+                log.debug("Built the leader epoch cache and producer snapshots from remote tier for {}, " +
+                                "with active producers size: {}, leaderLogStartOffset: {}, and logEndOffset: {}",
+                        partition, unifiedLog.producerStateManager().activeProducers().size(), leaderLogStartOffset, nextOffset);
+            } else {
+                throw new RemoteStorageException("Couldn't build the state from remote store for partition: " + topicPartition +
+                        ", currentLeaderEpoch: " + currentLeaderEpoch +
+                        ", leaderLocalLogStartOffset: " + leaderLocalLogStartOffset +
+                        ", leaderLogStartOffset: " + leaderLogStartOffset +
+                        ", epoch: " + targetEpoch +
+                        "as the previous remote log segment metadata was not found");
+            }
+        } else {
+            // If tiered storage is not enabled, throw an exception so that the fetcher will retry
+            // until tiered storage is enabled as expected.
+            throw new RemoteStorageException("Couldn't build the state from remote store for partition " + topicPartition + ", as remote log storage is not yet enabled");
+        }
+
+        return nextOffset;
+    }
+}
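
For orientation, the offset arithmetic performed by start() and buildRemoteLogAuxState() above boils down to the following. This is a minimal, self-contained sketch whose class name and offset values are invented; only the relationships between the values mirror the code above.

// Hypothetical offsets chosen for illustration; only the formulas follow
// ReplicaFetcherTierStateMachine.start and buildRemoteLogAuxState.
public class TierOffsetsSketch {
    public static void main(String[] args) {
        long leaderLogStartOffset      = 100L; // leader's overall log start offset (older data lives only in remote storage)
        long leaderLocalLogStartOffset = 500L; // earliest offset still on the leader's local disk
        long remoteSegmentEndOffset    = 499L; // endOffset() of the remote segment covering (leaderLocalLogStartOffset - 1)

        // buildRemoteLogAuxState: the follower truncates fully, restores the leader epoch cache
        // and producer snapshots up to the remote segment's end offset, advances its log start
        // offset to leaderLogStartOffset, and resumes fetching from the next offset.
        long offsetToFetch = remoteSegmentEndOffset + 1;    // 500, i.e. leaderLocalLogStartOffset here

        // start(): the initial lag is measured against the leader's log end offset.
        long leaderEndOffset = 700L;                        // from leader.fetchLatestOffset
        long initialLag = leaderEndOffset - offsetToFetch;  // 200

        System.out.println("logStartOffset=" + leaderLogStartOffset
                + ", offsetToFetch=" + offsetToFetch + ", initialLag=" + initialLag);
    }
}
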
diff --git a/core/src/main/java/kafka/server/TierStateMachine.java b/core/src/main/java/kafka/server/TierStateMachine.java
new file mode 100644
index 00000000000..58a44cc6472
--- /dev/null
+++ b/core/src/main/java/kafka/server/TierStateMachine.java
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package kafka.server;
+
+import java.util.Optional;
+
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.message.FetchResponseData.PartitionData;
+
+/**
+ * This interface defines the APIs needed to handle any state transitions related to tiering
+ */
+public interface TierStateMachine {
+
+    /**
+     * Start the tier state machine for the provided topic partition.
+     *
+     * @param topicPartition the topic partition
+     * @param currentFetchState the current PartitionFetchState which will
+     *                          be used to derive the return value
+     * @param fetchPartitionData the data from the fetch response that returned the offset moved to tiered storage error
+     *
+     * @return the new PartitionFetchState after the successful start of the
+     *         tier state machine
+     */
+    PartitionFetchState start(TopicPartition topicPartition,
+                              PartitionFetchState currentFetchState,
+                              PartitionData fetchPartitionData) throws Exception;
+
+    /**
+     * Optionally advance the state of the tier state machine, based on the
+     * current PartitionFetchState. The decision to advance the tier
+     * state machine is implementation specific.
+     *
+     * @param topicPartition the topic partition
+     * @param currentFetchState the current PartitionFetchState which will
+     *                          be used to derive the return value
+     *
+     * @return the new PartitionFetchState if the tier state machine was advanced, otherwise, return the currentFetchState
+     */
+    Optional<PartitionFetchState> maybeAdvanceState(TopicPartition topicPartition,
+                                                    PartitionFetchState currentFetchState);
+}
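
A minimal example of implementing the interface follows, assuming a hypothetical NoOpTierStateMachine used purely to show the contract; the real implementations added by this patch are ReplicaFetcherTierStateMachine and ReplicaAlterLogDirsTierStateMachine, and the tests further below plug in Scala mocks the same way.

package kafka.server; // illustrative only, not part of the patch

import java.util.Optional;

import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.common.message.FetchResponseData.PartitionData;

// Hypothetical no-op implementation: it never changes the fetch position.
public class NoOpTierStateMachine implements TierStateMachine {

    @Override
    public PartitionFetchState start(TopicPartition topicPartition,
                                     PartitionFetchState currentFetchState,
                                     PartitionData fetchPartitionData) throws Exception {
        // A real implementation would build the remote log aux state here and
        // return a new PartitionFetchState pointing at the next offset to fetch.
        return currentFetchState;
    }

    @Override
    public Optional<PartitionFetchState> maybeAdvanceState(TopicPartition topicPartition,
                                                           PartitionFetchState currentFetchState) {
        // Synchronous implementations have nothing to advance.
        return Optional.of(currentFetchState);
    }
}
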
diff --git a/core/src/main/scala/kafka/server/AbstractFetcherThread.scala b/core/src/main/scala/kafka/server/AbstractFetcherThread.scala
index 25f2e292cc5..2176ee3518b 100755
--- a/core/src/main/scala/kafka/server/AbstractFetcherThread.scala
+++ b/core/src/main/scala/kafka/server/AbstractFetcherThread.scala
@@ -26,6 +26,7 @@ import kafka.utils.{DelayedItem, Logging, Pool}
 import org.apache.kafka.common.errors._
 import org.apache.kafka.common.internals.PartitionStates
 import org.apache.kafka.common.message.OffsetForLeaderEpochResponseData.EpochEndOffset
+import org.apache.kafka.common.message.FetchResponseData.PartitionData
 import org.apache.kafka.common.message.{FetchResponseData, OffsetForLeaderEpochRequestData}
 import org.apache.kafka.common.protocol.Errors
 import org.apache.kafka.common.record.{FileRecords, MemoryRecords, Records}
@@ -54,6 +55,7 @@ abstract class AbstractFetcherThread(name: String,
                                      clientId: String,
                                      val leader: LeaderEndPoint,
                                      failedPartitions: FailedPartitions,
+                                     val fetchTierStateMachine: TierStateMachine,
                                      fetchBackOffMs: Int = 0,
                                      isInterruptible: Boolean = true,
                                      val brokerTopicStats: BrokerTopicStats) //BrokerTopicStats's lifecycle managed by ReplicaManager
@@ -93,22 +95,6 @@ abstract class AbstractFetcherThread(name: String,
 
   protected val isOffsetForLeaderEpochSupported: Boolean
 
-  /**
-   * Builds the required remote log auxiliary state for the given topic partition on this follower replica and returns
-   * the offset to be fetched from the leader replica.
-   *
-   * @param partition topic partition
-   * @param currentLeaderEpoch current leader epoch maintained by this follower replica.
-   * @param fetchOffset offset to be fetched from the leader.
-   * @param epochForFetchOffset respective leader epoch for the given fetch pffset.
-   * @param leaderLogStartOffset log-start-offset on the leader.
-   */
-  protected def buildRemoteLogAuxState(partition: TopicPartition,
-                                       currentLeaderEpoch: Int,
-                                       fetchOffset: Long,
-                                       epochForFetchOffset: Int,
-                                       leaderLogStartOffset: Long): Long
-
   override def shutdown(): Unit = {
     initiateShutdown()
     inLock(partitionMapLock) {
@@ -410,12 +396,7 @@ abstract class AbstractFetcherThread(name: String,
                 case Errors.OFFSET_OUT_OF_RANGE =>
                   if (!handleOutOfRangeError(topicPartition, currentFetchState, fetchPartitionData.currentLeaderEpoch))
                     partitionsWithError += topicPartition
-                case Errors.OFFSET_MOVED_TO_TIERED_STORAGE =>
-                  debug(s"Received error ${Errors.OFFSET_MOVED_TO_TIERED_STORAGE}, " +
-                    s"at fetch offset: ${currentFetchState.fetchOffset}, " + s"topic-partition: $topicPartition")
-                  if (!handleOffsetsMovedToTieredStorage(topicPartition, currentFetchState,
-                    fetchPartitionData.currentLeaderEpoch, partitionData.logStartOffset()))
-                    partitionsWithError += topicPartition
+
                 case Errors.UNKNOWN_LEADER_EPOCH =>
                   debug(s"Remote broker has a smaller leader epoch for partition $topicPartition than " +
                     s"this replica's current leader epoch of ${currentFetchState.currentLeaderEpoch}.")
@@ -425,6 +406,12 @@ abstract class AbstractFetcherThread(name: String,
                   if (onPartitionFenced(topicPartition, fetchPartitionData.currentLeaderEpoch))
                     partitionsWithError += topicPartition
 
+                case Errors.OFFSET_MOVED_TO_TIERED_STORAGE =>
+                  debug(s"Received error ${Errors.OFFSET_MOVED_TO_TIERED_STORAGE}, " +
+                    s"at fetch offset: ${currentFetchState.fetchOffset}, " + s"topic-partition: $topicPartition")
+                  if (!handleOffsetsMovedToTieredStorage(topicPartition, currentFetchState, fetchPartitionData.currentLeaderEpoch, partitionData))
+                    partitionsWithError += topicPartition
+
                 case Errors.NOT_LEADER_OR_FOLLOWER =>
                   debug(s"Remote broker is not the leader for partition $topicPartition, which could indicate " +
                     "that the partition is being moved")
@@ -644,25 +631,9 @@ abstract class AbstractFetcherThread(name: String,
   }
 
   /**
-   * It returns the next fetch state. It fetches the  log-start-offset or local-log-start-offset based on
-   * `fetchFromLocalLogStartOffset` flag. This is used in truncation by passing it to the given `truncateAndBuild`
-   * function.
-   *
-   * @param topicPartition               topic partition
-   * @param topicId                      topic id
-   * @param currentLeaderEpoch           current leader epoch maintained by this follower replica.
-   * @param truncateAndBuild             Function to truncate for the given epoch and offset. It returns the next fetch offset value.
-   * @param fetchFromLocalLogStartOffset Whether to fetch from local-log-start-offset or log-start-offset. If true, it
-   *                                     requests the local-log-start-offset from the leader, else it requests
-   *                                     log-start-offset from the leader. This is used in sending the value to the
-   *                                     given `truncateAndBuild` function.
-   * @return next PartitionFetchState
+   * Handle a partition whose offset is out of range and return a new fetch offset.
    */
-  private def fetchOffsetAndApplyTruncateAndBuild(topicPartition: TopicPartition,
-                                                  topicId: Option[Uuid],
-                                                  currentLeaderEpoch: Int,
-                                                  truncateAndBuild: => OffsetAndEpoch => Long,
-                                                  fetchFromLocalLogStartOffset: Boolean = true): PartitionFetchState = {
+  private def fetchOffsetAndTruncate(topicPartition: TopicPartition, topicId: Option[Uuid], currentLeaderEpoch: Int): PartitionFetchState = {
     val replicaEndOffset = logEndOffset(topicPartition)
 
     /**
@@ -696,33 +667,25 @@ abstract class AbstractFetcherThread(name: String,
        * produced to the new leader. While the old leader is trying to handle the OffsetOutOfRangeException and query
        * the log end offset of the new leader, the new leader's log end offset becomes higher than the follower's log end offset.
        *
-       * In the first case, the follower's current log end offset is smaller than the leader's log start offset
-       * (or leader's local log start offset).
-       * So the follower should truncate all its logs, roll out a new segment and start to fetch from the current
-       * leader's log start offset(or leader's local log start offset).
+       * In the first case, if the follower's current log end offset is smaller than the leader's log start offset, the
+       * follower should truncate all its logs, roll out a new segment and start to fetch from the current leader's log
+       * start offset since the data are all stale.
        * In the second case, the follower should just keep the current log segments and retry the fetch. In the second
        * case, there will be some inconsistency of data between old and new leader. We are not solving it here.
        * If users want to have strong consistency guarantees, appropriate configurations needs to be set for both
        * brokers and producers.
        *
        * Putting the two cases together, the follower should fetch from the higher one of its replica log end offset
-       * and the current leader's (local-log-start-offset or) log start offset.
+       * and the current leader's log start offset.
        */
-      val offsetAndEpoch = if (fetchFromLocalLogStartOffset)
-        leader.fetchEarliestLocalOffset(topicPartition, currentLeaderEpoch) else
-        leader.fetchEarliestOffset(topicPartition, currentLeaderEpoch)
+      val offsetAndEpoch = leader.fetchEarliestOffset(topicPartition, currentLeaderEpoch)
       val leaderStartOffset = offsetAndEpoch.offset
       warn(s"Reset fetch offset for partition $topicPartition from $replicaEndOffset to current " +
         s"leader's start offset $leaderStartOffset")
-      val offsetToFetch =
-        if (leaderStartOffset > replicaEndOffset) {
-          // Only truncate log when current leader's log start offset (local log start offset if >= 3.4 version incaseof
-          // OffsetMovedToTieredStorage error) is greater than follower's log end offset.
-          // truncateAndBuild returns offset value from which it needs to start fetching.
-          truncateAndBuild(offsetAndEpoch)
-        } else {
-          replicaEndOffset
-        }
+      val offsetToFetch = Math.max(leaderStartOffset, replicaEndOffset)
+      // Only truncate log when current leader's log start offset is greater than follower's log end offset.
+      if (leaderStartOffset > replicaEndOffset)
+        truncateFullyAndStartAt(topicPartition, leaderStartOffset)
 
       val initialLag = leaderEndOffset - offsetToFetch
       fetcherLagStats.getAndMaybePut(topicPartition).lag = initialLag
@@ -731,23 +694,6 @@ abstract class AbstractFetcherThread(name: String,
     }
   }
 
-  /**
-   * Handle a partition whose offset is out of range and return a new fetch offset.
-   */
-  private def fetchOffsetAndTruncate(topicPartition: TopicPartition, topicId: Option[Uuid], currentLeaderEpoch: Int): PartitionFetchState = {
-    fetchOffsetAndApplyTruncateAndBuild(topicPartition, topicId, currentLeaderEpoch,
-      offsetAndEpoch => {
-        val leaderLogStartOffset = offsetAndEpoch.offset
-        truncateFullyAndStartAt(topicPartition, leaderLogStartOffset)
-        leaderLogStartOffset
-      },
-      // In this case, it will fetch from leader's log-start-offset like earlier instead of fetching from
-      // local-log-start-offset. This handles both the scenarios of whether tiered storage is enabled or not.
-      // If tiered storage is enabled, the next fetch result of fetching from log-start-offset may result in
-      // OffsetMovedToTieredStorage error and it will handle building the remote log state.
-      fetchFromLocalLogStartOffset = false)
-  }
-
   /**
    * Handles the out of range error for the given topic partition.
    *
@@ -798,21 +744,20 @@ abstract class AbstractFetcherThread(name: String,
    * Returns false if there was a retriable error.
    *
    * @param topicPartition topic partition
-   * @param fetchState current partition fetch state.
-   * @param leaderEpochInRequest current leader epoch sent in the fetch request.
-   * @param leaderLogStartOffset log-start-offset in the leader replica.
+   * @param fetchState current partition fetch state
+   * @param leaderEpochInRequest current leader epoch sent in the fetch request
+   * @param fetchPartitionData the fetch response data for this topic partition
    */
   private def handleOffsetsMovedToTieredStorage(topicPartition: TopicPartition,
                                                 fetchState: PartitionFetchState,
                                                 leaderEpochInRequest: Optional[Integer],
-                                                leaderLogStartOffset: Long): Boolean = {
+                                                fetchPartitionData: PartitionData): Boolean = {
     try {
-      val newFetchState = fetchOffsetAndApplyTruncateAndBuild(topicPartition, fetchState.topicId, fetchState.currentLeaderEpoch,
-        offsetAndEpoch => {
-          val leaderLocalLogStartOffset = offsetAndEpoch.offset
-          buildRemoteLogAuxState(topicPartition, fetchState.currentLeaderEpoch, leaderLocalLogStartOffset, offsetAndEpoch.leaderEpoch(), leaderLogStartOffset)
-        })
+      val newFetchState = fetchTierStateMachine.start(topicPartition, fetchState, fetchPartitionData)
+
+      // TODO: use fetchTierStateMachine.maybeAdvanceState when implementing async tiering logic in KAFKA-13560
 
+      fetcherLagStats.getAndMaybePut(topicPartition).lag = newFetchState.lag.getOrElse(0)
       partitionStates.updateAndMoveToEnd(topicPartition, newFetchState)
       debug(s"Current offset ${fetchState.fetchOffset} for partition $topicPartition is " +
         s"out of range or moved to remote tier. Reset fetch offset to ${newFetchState.fetchOffset}")
diff --git a/core/src/main/scala/kafka/server/ReplicaAlterLogDirsThread.scala b/core/src/main/scala/kafka/server/ReplicaAlterLogDirsThread.scala
index e003ae1c76f..cae9193fba1 100644
--- a/core/src/main/scala/kafka/server/ReplicaAlterLogDirsThread.scala
+++ b/core/src/main/scala/kafka/server/ReplicaAlterLogDirsThread.scala
@@ -36,6 +36,7 @@ class ReplicaAlterLogDirsThread(name: String,
                                 clientId = name,
                                 leader = leader,
                                 failedPartitions,
+                                fetchTierStateMachine = new ReplicaAlterLogDirsTierStateMachine(),
                                 fetchBackOffMs = fetchBackOffMs,
                                 isInterruptible = false,
                                 brokerTopicStats) {
@@ -122,14 +123,4 @@ class ReplicaAlterLogDirsThread(name: String,
     val partition = replicaMgr.getPartitionOrException(topicPartition)
     partition.truncateFullyAndStartAt(offset, isFuture = true)
   }
-
-  override protected def buildRemoteLogAuxState(partition: TopicPartition,
-                                                currentLeaderEpoch: Int,
-                                                fetchOffset: Long,
-                                                epochForFetchOffset: Int,
-                                                leaderLogStartOffset: Long): Long = {
-    // JBOD is not supported with tiered storage.
-    throw new UnsupportedOperationException();
-  }
-
 }
diff --git a/core/src/main/scala/kafka/server/ReplicaFetcherThread.scala b/core/src/main/scala/kafka/server/ReplicaFetcherThread.scala
index 4a653c43552..ae75fb571b4 100644
--- a/core/src/main/scala/kafka/server/ReplicaFetcherThread.scala
+++ b/core/src/main/scala/kafka/server/ReplicaFetcherThread.scala
@@ -17,24 +17,14 @@
 
 package kafka.server
 
-import kafka.log.remote.RemoteLogManager
 import kafka.log.LeaderOffsetIncremented
-import org.apache.kafka.common.message.OffsetForLeaderEpochResponseData.EpochEndOffset
-import org.apache.kafka.common.protocol.Errors
 import org.apache.kafka.common.record.MemoryRecords
-import org.apache.kafka.common.requests._
-import org.apache.kafka.common.utils.Utils
-import org.apache.kafka.common.{KafkaException, TopicPartition}
-import org.apache.kafka.server.common.CheckpointFile.CheckpointReadBuffer
+import org.apache.kafka.common.requests.FetchResponse
+import org.apache.kafka.common.TopicPartition
 import org.apache.kafka.server.common.OffsetAndEpoch
 import org.apache.kafka.server.common.MetadataVersion
-import org.apache.kafka.server.log.remote.storage.{RemoteLogSegmentMetadata, RemoteStorageException, RemoteStorageManager}
-import org.apache.kafka.storage.internals.checkpoint.LeaderEpochCheckpointFile
-import org.apache.kafka.storage.internals.log.{EpochEntry, LogAppendInfo, LogFileUtils}
+import org.apache.kafka.storage.internals.log.LogAppendInfo
 
-import java.io.{BufferedReader, File, InputStreamReader}
-import java.nio.charset.StandardCharsets
-import java.nio.file.{Files, StandardCopyOption}
 import scala.collection.mutable
 
 class ReplicaFetcherThread(name: String,
@@ -49,6 +39,7 @@ class ReplicaFetcherThread(name: String,
                                 clientId = name,
                                 leader = leader,
                                 failedPartitions,
+                                fetchTierStateMachine = new ReplicaFetcherTierStateMachine(leader, replicaMgr),
                                 fetchBackOffMs = brokerConfig.replicaFetchBackoffMs,
                                 isInterruptible = false,
                                 replicaMgr.brokerTopicStats) {
@@ -203,141 +194,4 @@ class ReplicaFetcherThread(name: String,
     val partition = replicaMgr.getPartitionOrException(topicPartition)
     partition.truncateFullyAndStartAt(offset, isFuture = false)
   }
-
-  private def buildProducerSnapshotFile(snapshotFile: File, remoteLogSegmentMetadata: RemoteLogSegmentMetadata, rlm: RemoteLogManager): Unit = {
-    val tmpSnapshotFile = new File(snapshotFile.getAbsolutePath + ".tmp")
-    // Copy it to snapshot file in atomic manner.
-    Files.copy(rlm.storageManager().fetchIndex(remoteLogSegmentMetadata, RemoteStorageManager.IndexType.PRODUCER_SNAPSHOT),
-      tmpSnapshotFile.toPath, StandardCopyOption.REPLACE_EXISTING)
-    Utils.atomicMoveWithFallback(tmpSnapshotFile.toPath, snapshotFile.toPath, false)
-  }
-
-  /**
-   * It tries to build the required state for this partition from leader and remote storage so that it can start
-   * fetching records from the leader.
-   */
-  override protected def buildRemoteLogAuxState(partition: TopicPartition,
-                                                currentLeaderEpoch: Int,
-                                                leaderLocalLogStartOffset: Long,
-                                                epochForLeaderLocalLogStartOffset: Int,
-                                                leaderLogStartOffset: Long): Long = {
-
-    def fetchEarlierEpochEndOffset(epoch: Int): EpochEndOffset = {
-      val previousEpoch = epoch - 1
-      // Find the end-offset for the epoch earlier to the given epoch from the leader
-      val partitionsWithEpochs = Map(partition -> new EpochData().setPartition(partition.partition())
-        .setCurrentLeaderEpoch(currentLeaderEpoch)
-        .setLeaderEpoch(previousEpoch))
-      val maybeEpochEndOffset = leader.fetchEpochEndOffsets(partitionsWithEpochs).get(partition)
-      if (maybeEpochEndOffset.isEmpty) {
-        throw new KafkaException("No response received for partition: " + partition);
-      }
-
-      val epochEndOffset = maybeEpochEndOffset.get
-      if (epochEndOffset.errorCode() != Errors.NONE.code()) {
-        throw Errors.forCode(epochEndOffset.errorCode()).exception()
-      }
-
-      epochEndOffset
-    }
-
-    val log = replicaMgr.localLogOrException(partition)
-    val nextOffset = {
-      if (log.remoteStorageSystemEnable && log.config.remoteLogConfig.remoteStorageEnable) {
-        if (replicaMgr.remoteLogManager.isEmpty) throw new IllegalStateException("RemoteLogManager is not yet instantiated")
-
-        val rlm = replicaMgr.remoteLogManager.get
-
-        // Find the respective leader epoch for (leaderLocalLogStartOffset - 1). We need to build the leader epoch cache
-        // until that offset
-        val previousOffsetToLeaderLocalLogStartOffset = leaderLocalLogStartOffset - 1
-        val targetEpoch: Int = {
-          // If the existing epoch is 0, no need to fetch from earlier epoch as the desired offset(leaderLogStartOffset - 1)
-          // will have the same epoch.
-          if (epochForLeaderLocalLogStartOffset == 0) {
-            epochForLeaderLocalLogStartOffset
-          } else {
-            // Fetch the earlier epoch/end-offset(exclusive) from the leader.
-            val earlierEpochEndOffset = fetchEarlierEpochEndOffset(epochForLeaderLocalLogStartOffset)
-            // Check if the target offset lies with in the range of earlier epoch. Here, epoch's end-offset is exclusive.
-            if (earlierEpochEndOffset.endOffset > previousOffsetToLeaderLocalLogStartOffset) {
-              // Always use the leader epoch from returned earlierEpochEndOffset.
-              // This gives the respective leader epoch, that will handle any gaps in epochs.
-              // For ex, leader epoch cache contains:
-              // leader-epoch   start-offset
-              //  0 		          20
-              //  1 		          85
-              //  <2> - gap no messages were appended in this leader epoch.
-              //  3 		          90
-              //  4 		          98
-              // There is a gap in leader epoch. For leaderLocalLogStartOffset as 90, leader-epoch is 3.
-              // fetchEarlierEpochEndOffset(2) will return leader-epoch as 1, end-offset as 90.
-              // So, for offset 89, we should return leader epoch as 1 like below.
-              earlierEpochEndOffset.leaderEpoch()
-            } else epochForLeaderLocalLogStartOffset
-          }
-        }
-
-        val maybeRlsm = rlm.fetchRemoteLogSegmentMetadata(partition, targetEpoch, previousOffsetToLeaderLocalLogStartOffset)
-
-        if (maybeRlsm.isPresent) {
-          val remoteLogSegmentMetadata = maybeRlsm.get()
-          // Build leader epoch cache, producer snapshots until remoteLogSegmentMetadata.endOffset() and start
-          // segments from (remoteLogSegmentMetadata.endOffset() + 1)
-          val nextOffset = remoteLogSegmentMetadata.endOffset() + 1
-
-          // Truncate the existing local log before restoring the leader epoch cache and producer snapshots.
-          truncateFullyAndStartAt(partition, nextOffset)
-
-          // Build leader epoch cache.
-          log.maybeIncrementLogStartOffset(leaderLogStartOffset, LeaderOffsetIncremented)
-          val epochs = readLeaderEpochCheckpoint(rlm, remoteLogSegmentMetadata)
-          log.leaderEpochCache.foreach { cache =>
-            cache.assign(epochs)
-          }
-
-          debug(s"Updated the epoch cache from remote tier till offset: $leaderLocalLogStartOffset " +
-            s"with size: ${epochs.size} for $partition")
-
-          // Restore producer snapshot
-          val snapshotFile = LogFileUtils.producerSnapshotFile(log.dir, nextOffset)
-          buildProducerSnapshotFile(snapshotFile, remoteLogSegmentMetadata, rlm)
-
-          // Reload producer snapshots.
-          log.producerStateManager.truncateFullyAndReloadSnapshots()
-          log.loadProducerState(nextOffset)
-          debug(s"Built the leader epoch cache and producer snapshots from remote tier for $partition, with " +
-            s"active producers size: ${log.producerStateManager.activeProducers.size}, " +
-            s"leaderLogStartOffset: $leaderLogStartOffset, and logEndOffset: $nextOffset")
-
-          // Return the offset from which next fetch should happen.
-          nextOffset
-        } else {
-          throw new RemoteStorageException(s"Couldn't build the state from remote store for partition: $partition, " +
-            s"currentLeaderEpoch: $currentLeaderEpoch, leaderLocalLogStartOffset: $leaderLocalLogStartOffset, " +
-            s"leaderLogStartOffset: $leaderLogStartOffset, epoch: $targetEpoch as the previous remote log segment " +
-            s"metadata was not found")
-        }
-      } else {
-        // If the tiered storage is not enabled throw an exception back so tht it will retry until the tiered storage
-        // is set as expected.
-        throw new RemoteStorageException(s"Couldn't build the state from remote store for partition $partition, as " +
-          s"remote log storage is not yet enabled")
-      }
-    }
-
-    nextOffset
-  }
-
-  private def readLeaderEpochCheckpoint(rlm: RemoteLogManager, remoteLogSegmentMetadata: RemoteLogSegmentMetadata): java.util.List[EpochEntry] = {
-    val inputStream = rlm.storageManager().fetchIndex(remoteLogSegmentMetadata, RemoteStorageManager.IndexType.LEADER_EPOCH)
-    val bufferedReader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8))
-    try {
-      val readBuffer = new CheckpointReadBuffer[EpochEntry]("", bufferedReader,  0, LeaderEpochCheckpointFile.FORMATTER)
-      readBuffer.read()
-    } finally {
-      bufferedReader.close()
-    }
-  }
-
 }
diff --git a/core/src/test/scala/unit/kafka/server/AbstractFetcherManagerTest.scala b/core/src/test/scala/unit/kafka/server/AbstractFetcherManagerTest.scala
index 25bf9633438..11b7ceb2df4 100644
--- a/core/src/test/scala/unit/kafka/server/AbstractFetcherManagerTest.scala
+++ b/core/src/test/scala/unit/kafka/server/AbstractFetcherManagerTest.scala
@@ -21,6 +21,7 @@ import kafka.cluster.BrokerEndPoint
 import kafka.server.AbstractFetcherThread.{ReplicaFetch, ResultWithPartitions}
 import kafka.utils.Implicits.MapExtensionMethods
 import kafka.utils.TestUtils
+import org.apache.kafka.common.message.FetchResponseData.PartitionData
 import org.apache.kafka.common.message.OffsetForLeaderEpochResponseData.EpochEndOffset
 import org.apache.kafka.common.requests.FetchRequest
 import org.apache.kafka.common.utils.Utils
@@ -32,6 +33,7 @@ import org.junit.jupiter.api.Assertions._
 import org.junit.jupiter.api.{BeforeEach, Test}
 import org.mockito.Mockito.{mock, verify, when}
 
+import java.util.Optional
 import scala.collection.{Map, Set, mutable}
 import scala.jdk.CollectionConverters._
 
@@ -235,7 +237,7 @@ class AbstractFetcherManagerTest {
     val failedTopicPartitions = makeTopicPartition(2, 5, "topic_failed")
     val fetcherManager = new AbstractFetcherManager[AbstractFetcherThread]("fetcher-manager", "fetcher-manager", currentFetcherSize) {
       override def createFetcherThread(fetcherId: Int, sourceBroker: BrokerEndPoint): AbstractFetcherThread = {
-        new TestResizeFetcherThread(sourceBroker, failedPartitions)
+        new TestResizeFetcherThread(sourceBroker, failedPartitions, new MockResizeFetcherTierStateMachine)
       }
     }
     try {
@@ -311,12 +313,21 @@ class AbstractFetcherManagerTest {
     override def fetchEarliestLocalOffset(topicPartition: TopicPartition, currentLeaderEpoch: Int): OffsetAndEpoch = new OffsetAndEpoch(1L, 0)
   }
 
-  private class TestResizeFetcherThread(sourceBroker: BrokerEndPoint, failedPartitions: FailedPartitions)
+  private class MockResizeFetcherTierStateMachine extends TierStateMachine {
+    override def start(topicPartition: TopicPartition, currentFetchState: PartitionFetchState, fetchPartitionData: PartitionData): PartitionFetchState = {
+      throw new UnsupportedOperationException("Materializing tier state is not supported in this test.")
+    }
+
+    override def maybeAdvanceState(tp: TopicPartition, currentFetchState: PartitionFetchState): Optional[PartitionFetchState] = Optional.empty[PartitionFetchState]
+  }
+
+  private class TestResizeFetcherThread(sourceBroker: BrokerEndPoint, failedPartitions: FailedPartitions, fetchTierStateMachine: TierStateMachine)
     extends AbstractFetcherThread(
       name = "test-resize-fetcher",
       clientId = "mock-fetcher",
       leader = new MockLeaderEndPoint(sourceBroker),
       failedPartitions,
+      fetchTierStateMachine,
       fetchBackOffMs = 0,
       brokerTopicStats = new BrokerTopicStats) {
 
@@ -337,8 +348,6 @@ class AbstractFetcherManagerTest {
     override protected def endOffsetForEpoch(topicPartition: TopicPartition, epoch: Int): Option[OffsetAndEpoch] = Some(new OffsetAndEpoch(1, 0))
 
     override protected val isOffsetForLeaderEpochSupported: Boolean = false
-
-    override protected def buildRemoteLogAuxState(partition: TopicPartition, currentLeaderEpoch: Int, fetchOffset: Long, epochForFetchOffset: Int, leaderLogStartOffset: Long): Long = 1
   }
 
 }
diff --git a/core/src/test/scala/unit/kafka/server/AbstractFetcherThreadTest.scala b/core/src/test/scala/unit/kafka/server/AbstractFetcherThreadTest.scala
index 9c8d5323c97..e959251b6ea 100644
--- a/core/src/test/scala/unit/kafka/server/AbstractFetcherThreadTest.scala
+++ b/core/src/test/scala/unit/kafka/server/AbstractFetcherThreadTest.scala
@@ -76,7 +76,9 @@ class AbstractFetcherThreadTest {
   @Test
   def testMetricsRemovedOnShutdown(): Unit = {
     val partition = new TopicPartition("topic", 0)
-    val fetcher = new MockFetcherThread(new MockLeaderEndPoint)
+    val mockLeaderEndpoint = new MockLeaderEndPoint()
+    val mockTierStateMachine = new MockTierStateMachine(mockLeaderEndpoint)
+    val fetcher = new MockFetcherThread(mockLeaderEndpoint, mockTierStateMachine)
 
     // add one partition to create the consumer lag metric
     fetcher.setReplicaState(partition, PartitionState(leaderEpoch = 0))
@@ -104,7 +106,9 @@ class AbstractFetcherThreadTest {
   @Test
   def testConsumerLagRemovedWithPartition(): Unit = {
     val partition = new TopicPartition("topic", 0)
-    val fetcher = new MockFetcherThread(new MockLeaderEndPoint)
+    val mockLeaderEndpoint = new MockLeaderEndPoint()
+    val mockTierStateMachine = new MockTierStateMachine(mockLeaderEndpoint)
+    val fetcher = new MockFetcherThread(mockLeaderEndpoint, mockTierStateMachine)
 
     // add one partition to create the consumer lag metric
     fetcher.setReplicaState(partition, PartitionState(leaderEpoch = 0))
@@ -127,7 +131,9 @@ class AbstractFetcherThreadTest {
   @Test
   def testSimpleFetch(): Unit = {
     val partition = new TopicPartition("topic", 0)
-    val fetcher = new MockFetcherThread(new MockLeaderEndPoint)
+    val mockLeaderEndpoint = new MockLeaderEndPoint()
+    val mockTierStateMachine = new MockTierStateMachine(mockLeaderEndpoint)
+    val fetcher = new MockFetcherThread(mockLeaderEndpoint, mockTierStateMachine)
 
     fetcher.setReplicaState(partition, PartitionState(leaderEpoch = 0))
     fetcher.addPartitions(Map(partition -> initialFetchState(topicIds.get(partition.topic), 0L, leaderEpoch = 0)))
@@ -150,11 +156,13 @@ class AbstractFetcherThreadTest {
     val partition = new TopicPartition("topic", 0)
     val fetchBackOffMs = 250
 
-    val fetcher = new MockFetcherThread(new MockLeaderEndPoint {
+    val mockLeaderEndpoint = new MockLeaderEndPoint {
       override def fetch(fetchRequest: FetchRequest.Builder): Map[TopicPartition, FetchData] = {
         throw new UnknownTopicIdException("Topic ID was unknown as expected for this test")
       }
-    }, fetchBackOffMs = fetchBackOffMs)
+    }
+    val mockTierStateMachine = new MockTierStateMachine(mockLeaderEndpoint)
+    val fetcher = new MockFetcherThread(mockLeaderEndpoint, mockTierStateMachine, fetchBackOffMs = fetchBackOffMs)
 
     fetcher.setReplicaState(partition, PartitionState(leaderEpoch = 0))
     fetcher.addPartitions(Map(partition -> initialFetchState(Some(Uuid.randomUuid()), 0L, leaderEpoch = 0)))
@@ -190,13 +198,15 @@ class AbstractFetcherThreadTest {
     val partition3 = new TopicPartition("topic3", 0)
     val fetchBackOffMs = 250
 
-    val fetcher = new MockFetcherThread(new MockLeaderEndPoint {
+    val mockLeaderEndPoint = new MockLeaderEndPoint {
       override def fetch(fetchRequest: FetchRequest.Builder): Map[TopicPartition, FetchData] = {
         Map(partition1 -> new FetchData().setErrorCode(Errors.UNKNOWN_TOPIC_ID.code),
           partition2 -> new FetchData().setErrorCode(Errors.INCONSISTENT_TOPIC_ID.code),
           partition3 -> new FetchData().setErrorCode(Errors.NONE.code))
       }
-    }, fetchBackOffMs = fetchBackOffMs)
+    }
+    val mockTierStateMachine = new MockTierStateMachine(mockLeaderEndPoint)
+    val fetcher = new MockFetcherThread(mockLeaderEndPoint, mockTierStateMachine, fetchBackOffMs = fetchBackOffMs)
 
     fetcher.setReplicaState(partition1, PartitionState(leaderEpoch = 0))
     fetcher.addPartitions(Map(partition1 -> initialFetchState(Some(Uuid.randomUuid()), 0L, leaderEpoch = 0)))
@@ -231,7 +241,9 @@ class AbstractFetcherThreadTest {
   @Test
   def testFencedTruncation(): Unit = {
     val partition = new TopicPartition("topic", 0)
-    val fetcher = new MockFetcherThread(new MockLeaderEndPoint)
+    val mockLeaderEndpoint = new MockLeaderEndPoint()
+    val mockTierStateMachine = new MockTierStateMachine(mockLeaderEndpoint)
+    val fetcher = new MockFetcherThread(mockLeaderEndpoint, mockTierStateMachine)
 
     fetcher.setReplicaState(partition, PartitionState(leaderEpoch = 0))
     fetcher.addPartitions(Map(partition -> initialFetchState(topicIds.get(partition.topic), 0L, leaderEpoch = 0)))
@@ -257,7 +269,9 @@ class AbstractFetcherThreadTest {
   @Test
   def testFencedFetch(): Unit = {
     val partition = new TopicPartition("topic", 0)
-    val fetcher = new MockFetcherThread(new MockLeaderEndPoint)
+    val mockLeaderEndpoint = new MockLeaderEndPoint()
+    val mockTierStateMachine = new MockTierStateMachine(mockLeaderEndpoint)
+    val fetcher = new MockFetcherThread(mockLeaderEndpoint, mockTierStateMachine)
 
     val replicaState = PartitionState(leaderEpoch = 0)
     fetcher.setReplicaState(partition, replicaState)
@@ -288,7 +302,9 @@ class AbstractFetcherThreadTest {
   @Test
   def testUnknownLeaderEpochInTruncation(): Unit = {
     val partition = new TopicPartition("topic", 0)
-    val fetcher = new MockFetcherThread(new MockLeaderEndPoint)
+    val mockLeaderEndpoint = new MockLeaderEndPoint()
+    val mockTierStateMachine = new MockTierStateMachine(mockLeaderEndpoint)
+    val fetcher = new MockFetcherThread(mockLeaderEndpoint, mockTierStateMachine)
 
     // The replica's leader epoch is ahead of the leader
     val replicaState = PartitionState(leaderEpoch = 1)
@@ -319,7 +335,9 @@ class AbstractFetcherThreadTest {
   @Test
   def testUnknownLeaderEpochWhileFetching(): Unit = {
     val partition = new TopicPartition("topic", 0)
-    val fetcher = new MockFetcherThread(new MockLeaderEndPoint)
+    val mockLeaderEndpoint = new MockLeaderEndPoint()
+    val mockTierStateMachine = new MockTierStateMachine(mockLeaderEndpoint)
+    val fetcher = new MockFetcherThread(mockLeaderEndpoint, mockTierStateMachine)
 
     // This test is contrived because it shouldn't be possible to to see unknown leader epoch
     // in the Fetching state as the leader must validate the follower's epoch when it checks
@@ -360,7 +378,9 @@ class AbstractFetcherThreadTest {
   @Test
   def testTruncation(): Unit = {
     val partition = new TopicPartition("topic", 0)
-    val fetcher = new MockFetcherThread(new MockLeaderEndPoint)
+    val mockLeaderEndpoint = new MockLeaderEndPoint()
+    val mockTierStateMachine = new MockTierStateMachine(mockLeaderEndpoint)
+    val fetcher = new MockFetcherThread(mockLeaderEndpoint, mockTierStateMachine)
 
     val replicaLog = Seq(
       mkBatch(baseOffset = 0, leaderEpoch = 0, new SimpleRecord("a".getBytes)),
@@ -394,11 +414,14 @@ class AbstractFetcherThreadTest {
   def testTruncateToHighWatermarkIfLeaderEpochRequestNotSupported(): Unit = {
     val highWatermark = 2L
     val partition = new TopicPartition("topic", 0)
-    val fetcher = new MockFetcherThread(new MockLeaderEndPoint {
-        override def fetchEpochEndOffsets(partitions: Map[TopicPartition, EpochData]): Map[TopicPartition, EpochEndOffset] =
-          throw new UnsupportedOperationException
-        override val isTruncationOnFetchSupported: Boolean = false
-    }) {
+    val mockLeaderEndPoint = new MockLeaderEndPoint {
+      override def fetchEpochEndOffsets(partitions: Map[TopicPartition, EpochData]): Map[TopicPartition, EpochEndOffset] =
+        throw new UnsupportedOperationException
+
+      override val isTruncationOnFetchSupported: Boolean = false
+    }
+    val mockTierStateMachine = new MockTierStateMachine(mockLeaderEndPoint)
+    val fetcher = new MockFetcherThread(mockLeaderEndPoint, mockTierStateMachine) {
         override def truncate(topicPartition: TopicPartition, truncationState: OffsetTruncationState): Unit = {
           assertEquals(highWatermark, truncationState.offset)
           assertTrue(truncationState.truncationCompleted)
@@ -428,10 +451,12 @@ class AbstractFetcherThreadTest {
   def testTruncateToHighWatermarkIfLeaderEpochInfoNotAvailable(): Unit = {
     val highWatermark = 2L
     val partition = new TopicPartition("topic", 0)
-    val fetcher = new MockFetcherThread(new MockLeaderEndPoint {
-        override def fetchEpochEndOffsets(partitions: Map[TopicPartition, EpochData]): Map[TopicPartition, EpochEndOffset] =
-          throw new UnsupportedOperationException
-      }) {
+    val mockLeaderEndPoint = new MockLeaderEndPoint {
+      override def fetchEpochEndOffsets(partitions: Map[TopicPartition, EpochData]): Map[TopicPartition, EpochEndOffset] =
+        throw new UnsupportedOperationException
+    }
+    val mockTierStateMachine = new MockTierStateMachine(mockLeaderEndPoint)
+    val fetcher = new MockFetcherThread(mockLeaderEndPoint, mockTierStateMachine) {
         override def truncate(topicPartition: TopicPartition, truncationState: OffsetTruncationState): Unit = {
           assertEquals(highWatermark, truncationState.offset)
           assertTrue(truncationState.truncationCompleted)
@@ -462,7 +487,10 @@ class AbstractFetcherThreadTest {
   def testTruncateToHighWatermarkDuringRemovePartitions(): Unit = {
     val highWatermark = 2L
     val partition = new TopicPartition("topic", 0)
-    val fetcher = new MockFetcherThread(new MockLeaderEndPoint) {
+
+    val mockLeaderEndpoint = new MockLeaderEndPoint()
+    val mockTierStateMachine = new MockTierStateMachine(mockLeaderEndpoint)
+    val fetcher = new MockFetcherThread(mockLeaderEndpoint, mockTierStateMachine) {
       override def truncateToHighWatermark(partitions: Set[TopicPartition]): Unit = {
         removePartitions(Set(partition))
         super.truncateToHighWatermark(partitions)
@@ -492,7 +520,9 @@ class AbstractFetcherThreadTest {
     val partition = new TopicPartition("topic", 0)
 
     var truncations = 0
-    val fetcher = new MockFetcherThread(new MockLeaderEndPoint) {
+    val mockLeaderEndpoint = new MockLeaderEndPoint()
+    val mockTierStateMachine = new MockTierStateMachine(mockLeaderEndpoint)
+    val fetcher = new MockFetcherThread(mockLeaderEndpoint, mockTierStateMachine) {
       override def truncate(topicPartition: TopicPartition, truncationState: OffsetTruncationState): Unit = {
         truncations += 1
         super.truncate(topicPartition, truncationState)
@@ -535,7 +565,9 @@ class AbstractFetcherThreadTest {
     assumeTrue(truncateOnFetch)
     val partition = new TopicPartition("topic", 0)
     var truncations = 0
-    val fetcher = new MockFetcherThread(new MockLeaderEndPoint) {
+    val mockLeaderEndpoint = new MockLeaderEndPoint
+    val mockTierStateMachine = new MockTierStateMachine(mockLeaderEndpoint)
+    val fetcher = new MockFetcherThread(mockLeaderEndpoint, mockTierStateMachine) {
       override def truncate(topicPartition: TopicPartition, truncationState: OffsetTruncationState): Unit = {
         truncations += 1
         super.truncate(topicPartition, truncationState)
@@ -575,7 +607,9 @@ class AbstractFetcherThreadTest {
   @Test
   def testFollowerFetchOutOfRangeHigh(): Unit = {
     val partition = new TopicPartition("topic", 0)
-    val fetcher = new MockFetcherThread(new MockLeaderEndPoint)
+    val mockLeaderEndpoint = new MockLeaderEndPoint
+    val mockTierStateMachine = new MockTierStateMachine(mockLeaderEndpoint)
+    val fetcher = new MockFetcherThread(mockLeaderEndpoint, mockTierStateMachine)
 
     val replicaLog = Seq(
       mkBatch(baseOffset = 0, leaderEpoch = 0, new SimpleRecord("a".getBytes)),
@@ -616,7 +650,6 @@ class AbstractFetcherThreadTest {
   @Test
   def testFollowerFetchMovedToTieredStore(): Unit = {
     val partition = new TopicPartition("topic", 0)
-    val fetcher = new MockFetcherThread(new MockLeaderEndPoint)
 
     val replicaLog = Seq(
       mkBatch(baseOffset = 0, leaderEpoch = 0, new SimpleRecord("a".getBytes)),
@@ -624,6 +657,11 @@ class AbstractFetcherThreadTest {
       mkBatch(baseOffset = 2, leaderEpoch = 4, new SimpleRecord("c".getBytes)))
 
     val replicaState = PartitionState(replicaLog, leaderEpoch = 5, highWatermark = 0L, rlmEnabled = true)
+
+    val mockLeaderEndpoint = new MockLeaderEndPoint
+    val mockTierStateMachine = new MockTierStateMachine(mockLeaderEndpoint)
+    val fetcher = new MockFetcherThread(mockLeaderEndpoint, mockTierStateMachine)
+
     fetcher.setReplicaState(partition, replicaState)
     fetcher.addPartitions(Map(partition -> initialFetchState(topicIds.get(partition.topic), 3L, leaderEpoch = 5)))
 
@@ -633,7 +671,6 @@ class AbstractFetcherThreadTest {
       mkBatch(baseOffset = 7, leaderEpoch = 5, new SimpleRecord("h".getBytes)),
       mkBatch(baseOffset = 8, leaderEpoch = 5, new SimpleRecord("i".getBytes)))
 
-
     val leaderState = PartitionState(leaderLog, leaderEpoch = 5, highWatermark = 8L, rlmEnabled = true)
     // Overriding the log start offset to zero to mock the scenario of segments 0-4 having moved to the remote store.
     leaderState.logStartOffset = 0
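
A brief, hedged walkthrough of what this scenario sets up (offsets inferred from the leader state above, not quoted from the test's assertions): the follower asks for offset 3 while the leader's local log now begins at offset 5, so the fetch is answered with OFFSET_MOVED_TO_TIERED_STORAGE and the tier state machine takes over.

    // Hypothetical arithmetic, assuming MockLeaderEndPoint derives these values from
    // the leader PartitionState built above (first local batch at 5, last batch at 8).
    val leaderEndOffset = 9L                                // one past the last leader batch
    val earliestLocalOffset = 5L                            // first offset still in the leader's local log
    val initialLag = leaderEndOffset - earliestLocalOffset  // 4
    // MockTierStateMachine.start (added later in this diff) truncates the follower fully,
    // resumes fetching at earliestLocalOffset, and carries initialLag in the returned
    // PartitionFetchState.
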
@@ -664,16 +701,14 @@ class AbstractFetcherThreadTest {
   def testFencedOffsetResetAfterMovedToRemoteTier(): Unit = {
     val partition = new TopicPartition("topic", 0)
     var isErrorHandled = false
-    val fetcher = new MockFetcherThread(new MockLeaderEndPoint) {
-      override protected def buildRemoteLogAuxState(partition: TopicPartition,
-                                                    currentLeaderEpoch: Int,
-                                                    fetchOffset: Long,
-                                                    epochForFetchOffset: Int,
-                                                    leaderLogStartOffset: Long): Long = {
+    val mockLeaderEndpoint = new MockLeaderEndPoint
+    val mockTierStateMachine = new MockTierStateMachine(mockLeaderEndpoint) {
+      override def start(topicPartition: TopicPartition, currentFetchState: PartitionFetchState, fetchPartitionData: FetchResponseData.PartitionData): PartitionFetchState = {
         isErrorHandled = true
-        throw new FencedLeaderEpochException(s"Epoch $currentLeaderEpoch is fenced")
+        throw new FencedLeaderEpochException(s"Epoch ${currentFetchState.currentLeaderEpoch} is fenced")
       }
     }
+    val fetcher = new MockFetcherThread(mockLeaderEndpoint, mockTierStateMachine)
 
     val replicaLog = Seq(
       mkBatch(baseOffset = 1, leaderEpoch = 2, new SimpleRecord("b".getBytes)),
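
As the hunk above shows, a test that previously overrode buildRemoteLogAuxState on the fetcher now overrides start on the tier state machine instead. A minimal sketch of that pattern, reusing the mock names from the hunk above; the no-op body is purely illustrative, while the method signature follows the override shown in this diff:

    // Hypothetical no-op override of the new hook: report the tiered-storage
    // transition as handled without changing the fetch state.
    val noOpTierStateMachine = new MockTierStateMachine(mockLeaderEndpoint) {
      override def start(topicPartition: TopicPartition,
                         currentFetchState: PartitionFetchState,
                         fetchPartitionData: FetchResponseData.PartitionData): PartitionFetchState = {
        currentFetchState
      }
    }
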
@@ -704,16 +739,19 @@ class AbstractFetcherThreadTest {
     val partition = new TopicPartition("topic", 0)
     var fetchedEarliestOffset = false
 
-    val fetcher = new MockFetcherThread(new MockLeaderEndPoint {
+    val mockLeaderEndPoint = new MockLeaderEndPoint {
       override def fetchEarliestOffset(topicPartition: TopicPartition, leaderEpoch: Int): OffsetAndEpoch = {
         fetchedEarliestOffset = true
         throw new FencedLeaderEpochException(s"Epoch $leaderEpoch is fenced")
       }
+
       override def fetchEarliestLocalOffset(topicPartition: TopicPartition, leaderEpoch: Int): OffsetAndEpoch = {
         fetchedEarliestOffset = true
         throw new FencedLeaderEpochException(s"Epoch $leaderEpoch is fenced")
       }
-    })
+    }
+    val mockTierStateMachine = new MockTierStateMachine(mockLeaderEndPoint)
+    val fetcher = new MockFetcherThread(mockLeaderEndPoint, mockTierStateMachine)
 
     val replicaLog = Seq()
     val replicaState = PartitionState(replicaLog, leaderEpoch = 4, highWatermark = 0L)
@@ -738,7 +776,9 @@ class AbstractFetcherThreadTest {
   @Test
   def testFollowerFetchOutOfRangeLow(): Unit = {
     val partition = new TopicPartition("topic", 0)
-    val fetcher = new MockFetcherThread(new MockLeaderEndPoint)
+    val mockLeaderEndpoint = new MockLeaderEndPoint
+    val mockTierStateMachine = new MockTierStateMachine(mockLeaderEndpoint)
+    val fetcher = new MockFetcherThread(mockLeaderEndpoint, mockTierStateMachine)
 
     // The follower begins from an offset which is behind the leader's log start offset
     val replicaLog = Seq(
@@ -779,14 +819,16 @@ class AbstractFetcherThreadTest {
   @Test
   def testRetryAfterUnknownLeaderEpochInLatestOffsetFetch(): Unit = {
     val partition = new TopicPartition("topic", 0)
-    val fetcher: MockFetcherThread = new MockFetcherThread(new MockLeaderEndPoint {
+    val mockLeaderEndPoint = new MockLeaderEndPoint {
       val tries = new AtomicInteger(0)
       override def fetchLatestOffset(topicPartition: TopicPartition, leaderEpoch: Int): OffsetAndEpoch = {
         if (tries.getAndIncrement() == 0)
           throw new UnknownLeaderEpochException("Unexpected leader epoch")
         super.fetchLatestOffset(topicPartition, leaderEpoch)
       }
-    })
+    }
+    val mockTierStateMachine = new MockTierStateMachine(mockLeaderEndPoint)
+    val fetcher: MockFetcherThread = new MockFetcherThread(mockLeaderEndPoint, mockTierStateMachine)
 
     // The follower begins from an offset which is behind the leader's log start offset
     val replicaLog = Seq(
@@ -821,7 +863,7 @@ class AbstractFetcherThreadTest {
   def testCorruptMessage(): Unit = {
     val partition = new TopicPartition("topic", 0)
 
-    val fetcher = new MockFetcherThread(new MockLeaderEndPoint {
+    val mockLeaderEndPoint = new MockLeaderEndPoint {
       var fetchedOnce = false
       override def fetch(fetchRequest: FetchRequest.Builder): Map[TopicPartition, FetchData] = {
         val fetchedData = super.fetch(fetchRequest)
@@ -834,7 +876,9 @@ class AbstractFetcherThreadTest {
         }
         fetchedData
       }
-    })
+    }
+    val mockTierStateMachine = new MockTierStateMachine(mockLeaderEndPoint)
+    val fetcher = new MockFetcherThread(mockLeaderEndPoint, mockTierStateMachine)
 
     fetcher.setReplicaState(partition, PartitionState(leaderEpoch = 0))
     fetcher.addPartitions(Map(partition -> initialFetchState(topicIds.get(partition.topic), 0L, leaderEpoch = 0)))
@@ -875,8 +919,9 @@ class AbstractFetcherThreadTest {
     val initialLeaderEpochOnFollower = 0
     val nextLeaderEpochOnFollower = initialLeaderEpochOnFollower + 1
 
-    val fetcher = new MockFetcherThread(new MockLeaderEndPoint {
+    val mockLeaderEndpoint = new MockLeaderEndPoint {
       var fetchEpochsFromLeaderOnce = false
+
       override def fetchEpochEndOffsets(partitions: Map[TopicPartition, EpochData]): Map[TopicPartition, EpochEndOffset] = {
         val fetchedEpochs = super.fetchEpochEndOffsets(partitions)
         if (!fetchEpochsFromLeaderOnce) {
@@ -885,7 +930,9 @@ class AbstractFetcherThreadTest {
         }
         fetchedEpochs
       }
-    })
+    }
+    val mockTierStateMachine = new MockTierStateMachine(mockLeaderEndpoint)
+    val fetcher = new MockFetcherThread(mockLeaderEndpoint, mockTierStateMachine)
 
     def changeLeaderEpochWhileFetchEpoch(): Unit = {
       fetcher.removePartitions(Set(partition))
@@ -928,13 +975,15 @@ class AbstractFetcherThreadTest {
     val initialLeaderEpochOnFollower = 0
     val nextLeaderEpochOnFollower = initialLeaderEpochOnFollower + 1
 
-    val fetcher = new MockFetcherThread(new MockLeaderEndPoint {
+    val mockLeaderEndpoint = new MockLeaderEndPoint {
       override def fetchEpochEndOffsets(partitions: Map[TopicPartition, EpochData]): Map[TopicPartition, EpochEndOffset] = {
         val fetchedEpochs = super.fetchEpochEndOffsets(partitions)
         responseCallback.apply()
         fetchedEpochs
       }
-    })
+    }
+    val mockTierStateMachine = new MockTierStateMachine(mockLeaderEndpoint)
+    val fetcher = new MockFetcherThread(mockLeaderEndpoint, mockTierStateMachine)
 
     def changeLeaderEpochDuringFetchEpoch(): Unit = {
       // leader epoch changes while fetching epochs from leader
@@ -972,7 +1021,7 @@ class AbstractFetcherThreadTest {
   @Test
   def testTruncationThrowsExceptionIfLeaderReturnsPartitionsNotRequestedInFetchEpochs(): Unit = {
     val partition = new TopicPartition("topic", 0)
-    val fetcher = new MockFetcherThread(new MockLeaderEndPoint {
+    val mockLeaderEndPoint = new MockLeaderEndPoint {
       override def fetchEpochEndOffsets(partitions: Map[TopicPartition, EpochData]): Map[TopicPartition, EpochEndOffset] = {
         val unrequestedTp = new TopicPartition("topic2", 0)
         super.fetchEpochEndOffsets(partitions).toMap + (unrequestedTp -> new EpochEndOffset()
@@ -981,7 +1030,9 @@ class AbstractFetcherThreadTest {
           .setLeaderEpoch(0)
           .setEndOffset(0))
       }
-    })
+    }
+    val mockTierStateMachine = new MockTierStateMachine(mockLeaderEndPoint)
+    val fetcher = new MockFetcherThread(mockLeaderEndPoint, mockTierStateMachine)
 
     fetcher.setReplicaState(partition, PartitionState(leaderEpoch = 0))
     fetcher.addPartitions(Map(partition -> initialFetchState(topicIds.get(partition.topic), 0L, leaderEpoch = 0)), forceTruncation = true)
@@ -994,7 +1045,9 @@ class AbstractFetcherThreadTest {
 
   @Test
   def testFetcherThreadHandlingPartitionFailureDuringAppending(): Unit = {
-    val fetcherForAppend = new MockFetcherThread(new MockLeaderEndPoint) {
+    val mockLeaderEndpoint = new MockLeaderEndPoint
+    val mockTierStateMachine = new MockTierStateMachine(mockLeaderEndpoint)
+    val fetcherForAppend = new MockFetcherThread(mockLeaderEndpoint, mockTierStateMachine) {
       override def processPartitionData(topicPartition: TopicPartition, fetchOffset: Long, partitionData: FetchData): Option[LogAppendInfo] = {
         if (topicPartition == partition1) {
           throw new KafkaException()
@@ -1008,7 +1061,9 @@ class AbstractFetcherThreadTest {
 
   @Test
   def testFetcherThreadHandlingPartitionFailureDuringTruncation(): Unit = {
-    val fetcherForTruncation = new MockFetcherThread(new MockLeaderEndPoint) {
+    val mockLeaderEndpoint = new MockLeaderEndPoint
+    val mockTierStateMachine = new MockTierStateMachine(mockLeaderEndpoint)
+    val fetcherForTruncation = new MockFetcherThread(mockLeaderEndpoint, mockTierStateMachine) {
       override def truncate(topicPartition: TopicPartition, truncationState: OffsetTruncationState): Unit = {
         if(topicPartition == partition1)
           throw new Exception()
@@ -1057,7 +1112,9 @@ class AbstractFetcherThreadTest {
   @Test
   def testDivergingEpochs(): Unit = {
     val partition = new TopicPartition("topic", 0)
-    val fetcher = new MockFetcherThread(new MockLeaderEndPoint)
+    val mockLeaderEndpoint = new MockLeaderEndPoint
+    val mockTierStateMachine = new MockTierStateMachine(mockLeaderEndpoint)
+    val fetcher = new MockFetcherThread(mockLeaderEndpoint, mockTierStateMachine)
 
     val replicaLog = Seq(
       mkBatch(baseOffset = 0, leaderEpoch = 0, new SimpleRecord("a".getBytes)),
@@ -1097,7 +1154,9 @@ class AbstractFetcherThreadTest {
 
     var truncateCalls = 0
     var processPartitionDataCalls = 0
-    val fetcher = new MockFetcherThread(new MockLeaderEndPoint) {
+    val mockLeaderEndpoint = new MockLeaderEndPoint
+    val mockTierStateMachine = new MockTierStateMachine(mockLeaderEndpoint)
+    val fetcher = new MockFetcherThread(mockLeaderEndpoint, mockTierStateMachine) {
       override def processPartitionData(topicPartition: TopicPartition, fetchOffset: Long, partitionData: FetchData): Option[LogAppendInfo] = {
         processPartitionDataCalls += 1
         super.processPartitionData(topicPartition, fetchOffset, partitionData)
@@ -1163,7 +1222,9 @@ class AbstractFetcherThreadTest {
   @Test
   def testMaybeUpdateTopicIds(): Unit = {
     val partition = new TopicPartition("topic1", 0)
-    val fetcher = new MockFetcherThread(new MockLeaderEndPoint)
+    val mockLeaderEndpoint = new MockLeaderEndPoint
+    val mockTierStateMachine = new MockTierStateMachine(mockLeaderEndpoint)
+    val fetcher = new MockFetcherThread(mockLeaderEndpoint, mockTierStateMachine)
 
     // Start with no topic IDs
     fetcher.setReplicaState(partition, PartitionState(leaderEpoch = 0))
@@ -1405,6 +1466,30 @@ class AbstractFetcherThreadTest {
     }
   }
 
+  class MockTierStateMachine(leader: LeaderEndPoint) extends ReplicaFetcherTierStateMachine(leader, null) {
+
+    var fetcher: MockFetcherThread = null
+    override def start(topicPartition: TopicPartition,
+                       currentFetchState: PartitionFetchState,
+                       fetchPartitionData: FetchResponseData.PartitionData): PartitionFetchState = {
+      val leaderEndOffset = leader.fetchLatestOffset(topicPartition, currentFetchState.currentLeaderEpoch).offset
+      val offsetToFetch = leader.fetchEarliestLocalOffset(topicPartition, currentFetchState.currentLeaderEpoch).offset
+      val initialLag = leaderEndOffset - offsetToFetch
+      fetcher.truncateFullyAndStartAt(topicPartition, offsetToFetch)
+      PartitionFetchState(currentFetchState.topicId, offsetToFetch, Option.apply(initialLag), currentFetchState.currentLeaderEpoch,
+        Fetching, Some(currentFetchState.currentLeaderEpoch))
+    }
+
+    override def maybeAdvanceState(topicPartition: TopicPartition,
+                                   currentFetchState: PartitionFetchState): Optional[PartitionFetchState] = {
+      Optional.of(currentFetchState)
+    }
+
+    def setFetcher(mockFetcherThread: MockFetcherThread): Unit = {
+      fetcher = mockFetcherThread
+    }
+  }
+
   class PartitionState(var log: mutable.Buffer[RecordBatch],
                        var leaderEpoch: Int,
                        var logStartOffset: Long,
@@ -1425,17 +1510,24 @@ class AbstractFetcherThreadTest {
     }
   }
 
-  class MockFetcherThread(val mockLeader : MockLeaderEndPoint, val replicaId: Int = 0, val leaderId: Int = 1, fetchBackOffMs: Int = 0)
+  class MockFetcherThread(val mockLeader : MockLeaderEndPoint,
+                          val mockTierStateMachine: MockTierStateMachine,
+                          val replicaId: Int = 0,
+                          val leaderId: Int = 1,
+                          fetchBackOffMs: Int = 0)
     extends AbstractFetcherThread("mock-fetcher",
       clientId = "mock-fetcher",
       leader = mockLeader,
       failedPartitions,
+      mockTierStateMachine,
       fetchBackOffMs = fetchBackOffMs,
       brokerTopicStats = new BrokerTopicStats) {
 
     private val replicaPartitionStates = mutable.Map[TopicPartition, PartitionState]()
     private var latestEpochDefault: Option[Int] = Some(0)
 
+    mockTierStateMachine.setFetcher(this)
+
     def setReplicaState(topicPartition: TopicPartition, state: PartitionState): Unit = {
       replicaPartitionStates.put(topicPartition, state)
     }
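
One hedged note on the wiring above: the fetcher and the mock tier state machine reference each other, since start needs the fetcher to truncate the replica log, so MockFetcherThread registers itself via setFetcher from its constructor rather than taking the dependency the other way around. Restated as a plain snippet (names exactly as in this diff):

    // Build the leader endpoint and state machine first, then the fetcher; the
    // fetcher's constructor completes the cycle by calling setFetcher(this).
    val leader = new MockLeaderEndPoint()
    val tierStateMachine = new MockTierStateMachine(leader)
    val fetcher = new MockFetcherThread(leader, tierStateMachine)
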
@@ -1554,18 +1646,6 @@ class AbstractFetcherThreadTest {
     }
 
     override protected val isOffsetForLeaderEpochSupported: Boolean = true
-
-    override protected def buildRemoteLogAuxState(topicPartition: TopicPartition,
-                                                  currentLeaderEpoch: Int,
-                                                  fetchOffset: Long,
-                                                  epochForFetchOffset: Int,
-                                                  leaderLogStartOffset: Long): Long = {
-      truncateFullyAndStartAt(topicPartition, fetchOffset)
-      replicaPartitionState(topicPartition).logStartOffset = leaderLogStartOffset
-      // skipped building leader epoch cache and producer snapshots as they are not verified.
-      leaderLogStartOffset
-    }
-
   }
 
 }