Posted to commits@kafka.apache.org by ju...@apache.org on 2022/06/03 16:12:27 UTC

[kafka] branch trunk updated: KAFKA-13803: Refactor Leader API Access (#12005)

This is an automated email from the ASF dual-hosted git repository.

junrao pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 3467036e01 KAFKA-13803: Refactor Leader API Access (#12005)
3467036e01 is described below

commit 3467036e017adc3ac0919bbc0c067b1bb1b621f3
Author: Rittika Adhikari <ri...@gmail.com>
AuthorDate: Fri Jun 3 09:12:06 2022 -0700

    KAFKA-13803: Refactor Leader API Access (#12005)
    
    This PR refactors the leader API access in the follower fetch path.
    
    Added a LeaderEndPoint interface which serves all access to the leader.
    
    Added a LocalLeaderEndPoint and a RemoteLeaderEndPoint, which implement the LeaderEndPoint interface to handle fetches from the leader in local and remote storage, respectively.
    
    Reviewers: David Jacot <dj...@confluent.io>, Kowshik Prakasam <kp...@confluent.io>, Jun Rao <ju...@gmail.com>
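
A minimal sketch of the new wiring, mirroring the ReplicaFetcherManager.createFetcherThread change
further down in this diff; it assumes brokerConfig, replicaManager, quotaManager, metrics, time,
sourceBroker, fetcherId, threadName and failedPartitions are in scope, as they are in that method.
The point of the refactor is that the fetcher thread no longer talks to the leader directly but
through a LeaderEndPoint:

        val logContext = new LogContext(s"[ReplicaFetcher replicaId=${brokerConfig.brokerId}, " +
          s"leaderId=${sourceBroker.id}, fetcherId=$fetcherId] ")
        // BrokerBlockingSender is the renamed ReplicaFetcherBlockingSend: the raw blocking channel to the leader broker.
        val sender = new BrokerBlockingSender(sourceBroker, brokerConfig, metrics, time, fetcherId,
          s"broker-${brokerConfig.brokerId}-fetcher-$fetcherId", logContext)
        val sessionHandler = new FetchSessionHandler(logContext, sourceBroker.id)
        // RemoteLeaderEndPoint implements LeaderEndPoint over the blocking sender; LocalLeaderEndPoint
        // (used by ReplicaAlterLogDirsManager) serves the same interface from the local ReplicaManager.
        val leader: LeaderEndPoint = new RemoteLeaderEndPoint(logContext.logPrefix, sender, sessionHandler,
          brokerConfig, replicaManager, quotaManager)
        new ReplicaFetcherThread(threadName, leader, brokerConfig, failedPartitions, replicaManager,
          quotaManager, logContext.logPrefix)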
---
 checkstyle/suppressions.xml                        |   2 +-
 .../kafka/server/AbstractFetcherManager.scala      |   6 +-
 .../scala/kafka/server/AbstractFetcherThread.scala |  33 +-
 ...ockingSend.scala => BrokerBlockingSender.scala} |  22 +-
 core/src/main/scala/kafka/server/KafkaConfig.scala |  32 +-
 .../main/scala/kafka/server/LeaderEndPoint.scala   | 106 +++
 .../scala/kafka/server/LocalLeaderEndPoint.scala   | 236 +++++++
 .../scala/kafka/server/RemoteLeaderEndPoint.scala  | 226 +++++++
 .../kafka/server/ReplicaAlterLogDirsManager.scala  |   5 +-
 .../kafka/server/ReplicaAlterLogDirsThread.scala   | 213 +-----
 .../scala/kafka/server/ReplicaFetcherManager.scala |  13 +-
 .../scala/kafka/server/ReplicaFetcherThread.scala  | 246 +------
 .../kafka/server/AbstractFetcherManagerTest.scala  |  45 +-
 .../kafka/server/AbstractFetcherThreadTest.scala   | 741 +++++++++++----------
 .../server/ReplicaAlterLogDirsThreadTest.scala     | 176 ++---
 .../kafka/server/ReplicaFetcherThreadTest.scala    | 265 +++++---
 .../unit/kafka/server/ReplicaManagerTest.scala     |  28 +-
 .../server/epoch/LeaderEpochIntegrationTest.scala  |   4 +-
 ...BlockingSend.scala => MockBlockingSender.scala} |  14 +-
 .../jmh/fetcher/ReplicaFetcherThreadBenchmark.java | 105 +--
 20 files changed, 1447 insertions(+), 1071 deletions(-)

diff --git a/checkstyle/suppressions.xml b/checkstyle/suppressions.xml
index 10af4e66c2..a111ce74cf 100644
--- a/checkstyle/suppressions.xml
+++ b/checkstyle/suppressions.xml
@@ -311,7 +311,7 @@
               files="(RemoteLogManagerConfig).java"/>
 
     <!-- benchmarks -->
-    <suppress checks="ClassDataAbstractionCoupling"
+    <suppress checks="(ClassDataAbstractionCoupling|ClassFanOutComplexity)"
               files="(ReplicaFetcherThreadBenchmark).java"/>
 
 </suppressions>
diff --git a/core/src/main/scala/kafka/server/AbstractFetcherManager.scala b/core/src/main/scala/kafka/server/AbstractFetcherManager.scala
index 0843fe8164..ddc45693f8 100755
--- a/core/src/main/scala/kafka/server/AbstractFetcherManager.scala
+++ b/core/src/main/scala/kafka/server/AbstractFetcherManager.scala
@@ -68,7 +68,7 @@ abstract class AbstractFetcherManager[T <: AbstractFetcherThread](val name: Stri
         if (id.fetcherId >= newSize)
           thread.shutdown()
         partitionStates.forKeyValue { (topicPartition, currentFetchState) =>
-            val initialFetchState = InitialFetchState(currentFetchState.topicId, thread.sourceBroker,
+            val initialFetchState = InitialFetchState(currentFetchState.topicId, thread.leader.brokerEndPoint(),
               currentLeaderEpoch = currentFetchState.currentLeaderEpoch,
               initOffset = currentFetchState.fetchOffset)
             allRemovedPartitionsMap += topicPartition -> initialFetchState
@@ -139,7 +139,7 @@ abstract class AbstractFetcherManager[T <: AbstractFetcherThread](val name: Stri
       for ((brokerAndFetcherId, initialFetchOffsets) <- partitionsPerFetcher) {
         val brokerIdAndFetcherId = BrokerIdAndFetcherId(brokerAndFetcherId.broker.id, brokerAndFetcherId.fetcherId)
         val fetcherThread = fetcherThreadMap.get(brokerIdAndFetcherId) match {
-          case Some(currentFetcherThread) if currentFetcherThread.sourceBroker == brokerAndFetcherId.broker =>
+          case Some(currentFetcherThread) if currentFetcherThread.leader.brokerEndPoint() == brokerAndFetcherId.broker =>
             // reuse the fetcher thread
             currentFetcherThread
           case Some(f) =>
@@ -163,7 +163,7 @@ abstract class AbstractFetcherManager[T <: AbstractFetcherThread](val name: Stri
   protected def addPartitionsToFetcherThread(fetcherThread: T,
                                              initialOffsetAndEpochs: collection.Map[TopicPartition, InitialFetchState]): Unit = {
     fetcherThread.addPartitions(initialOffsetAndEpochs)
-    info(s"Added fetcher to broker ${fetcherThread.sourceBroker.id} for partitions $initialOffsetAndEpochs")
+    info(s"Added fetcher to broker ${fetcherThread.leader.brokerEndPoint().id} for partitions $initialOffsetAndEpochs")
   }
 
   /**
diff --git a/core/src/main/scala/kafka/server/AbstractFetcherThread.scala b/core/src/main/scala/kafka/server/AbstractFetcherThread.scala
index 8aa3f1e032..2ae3f45023 100755
--- a/core/src/main/scala/kafka/server/AbstractFetcherThread.scala
+++ b/core/src/main/scala/kafka/server/AbstractFetcherThread.scala
@@ -17,7 +17,6 @@
 
 package kafka.server
 
-import kafka.cluster.BrokerEndPoint
 import kafka.common.ClientIdAndBroker
 import kafka.log.LogAppendInfo
 import kafka.metrics.KafkaMetricsGroup
@@ -51,7 +50,7 @@ import scala.math._
  */
 abstract class AbstractFetcherThread(name: String,
                                      clientId: String,
-                                     val sourceBroker: BrokerEndPoint,
+                                     val leader: LeaderEndPoint,
                                      failedPartitions: FailedPartitions,
                                      fetchBackOffMs: Int = 0,
                                      isInterruptible: Boolean = true,
@@ -65,7 +64,7 @@ abstract class AbstractFetcherThread(name: String,
   protected val partitionMapLock = new ReentrantLock
   private val partitionMapCond = partitionMapLock.newCondition()
 
-  private val metricId = ClientIdAndBroker(clientId, sourceBroker.host, sourceBroker.port)
+  private val metricId = ClientIdAndBroker(clientId, leader.brokerEndPoint().host, leader.brokerEndPoint().port)
   val fetcherStats = new FetcherStats(metricId)
   val fetcherLagStats = new FetcherLagStats(metricId)
 
@@ -80,8 +79,6 @@ abstract class AbstractFetcherThread(name: String,
 
   protected def truncateFullyAndStartAt(topicPartition: TopicPartition, offset: Long): Unit
 
-  protected def buildFetch(partitionMap: Map[TopicPartition, PartitionFetchState]): ResultWithPartitions[Option[ReplicaFetch]]
-
   protected def latestEpoch(topicPartition: TopicPartition): Option[Int]
 
   protected def logStartOffset(topicPartition: TopicPartition): Long
@@ -90,18 +87,8 @@ abstract class AbstractFetcherThread(name: String,
 
   protected def endOffsetForEpoch(topicPartition: TopicPartition, epoch: Int): Option[OffsetAndEpoch]
 
-  protected def fetchEpochEndOffsets(partitions: Map[TopicPartition, EpochData]): Map[TopicPartition, EpochEndOffset]
-
-  protected def fetchFromLeader(fetchRequest: FetchRequest.Builder): Map[TopicPartition, FetchData]
-
-  protected def fetchEarliestOffsetFromLeader(topicPartition: TopicPartition, currentLeaderEpoch: Int): Long
-
-  protected def fetchLatestOffsetFromLeader(topicPartition: TopicPartition, currentLeaderEpoch: Int): Long
-
   protected val isOffsetForLeaderEpochSupported: Boolean
 
-  protected val isTruncationOnFetchSupported: Boolean
-
   override def shutdown(): Unit = {
     initiateShutdown()
     inLock(partitionMapLock) {
@@ -121,7 +108,7 @@ abstract class AbstractFetcherThread(name: String,
 
   private def maybeFetch(): Unit = {
     val fetchRequestOpt = inLock(partitionMapLock) {
-      val ResultWithPartitions(fetchRequestOpt, partitionsWithError) = buildFetch(partitionStates.partitionStateMap.asScala)
+      val ResultWithPartitions(fetchRequestOpt, partitionsWithError) = leader.buildFetch(partitionStates.partitionStateMap.asScala)
 
       handlePartitionsWithErrors(partitionsWithError, "maybeFetch")
 
@@ -209,7 +196,7 @@ abstract class AbstractFetcherThread(name: String,
     *   occur during truncation.
     */
   private def truncateToEpochEndOffsets(latestEpochsForPartitions: Map[TopicPartition, EpochData]): Unit = {
-    val endOffsets = fetchEpochEndOffsets(latestEpochsForPartitions)
+    val endOffsets = leader.fetchEpochEndOffsets(latestEpochsForPartitions)
     //Ensure we hold a lock during truncation.
     inLock(partitionMapLock) {
       //Check no leadership and no leader epoch changes happened whilst we were unlocked, fetching epochs
@@ -319,7 +306,7 @@ abstract class AbstractFetcherThread(name: String,
 
     try {
       trace(s"Sending fetch request $fetchRequest")
-      responseData = fetchFromLeader(fetchRequest)
+      responseData = leader.fetch(fetchRequest)
     } catch {
       case t: Throwable =>
         if (isRunning) {
@@ -364,7 +351,7 @@ abstract class AbstractFetcherThread(name: String,
                         fetcherStats.byteRate.mark(validBytes)
                       }
                     }
-                    if (isTruncationOnFetchSupported) {
+                    if (leader.isTruncationOnFetchSupported) {
                       FetchResponse.divergingEpoch(partitionData).ifPresent { divergingEpoch =>
                         divergingEndOffsets += topicPartition -> new EpochEndOffset()
                           .setPartition(topicPartition.partition)
@@ -482,7 +469,7 @@ abstract class AbstractFetcherThread(name: String,
       currentState
     } else if (initialFetchState.initOffset < 0) {
       fetchOffsetAndTruncate(tp, initialFetchState.topicId, initialFetchState.currentLeaderEpoch)
-    } else if (isTruncationOnFetchSupported) {
+    } else if (leader.isTruncationOnFetchSupported) {
       // With old message format, `latestEpoch` will be empty and we use Truncating state
       // to truncate to high watermark.
       val lastFetchedEpoch = latestEpoch(tp)
@@ -537,7 +524,7 @@ abstract class AbstractFetcherThread(name: String,
         val maybeTruncationComplete = fetchOffsets.get(topicPartition) match {
           case Some(offsetTruncationState) =>
             val lastFetchedEpoch = latestEpoch(topicPartition)
-            val state = if (isTruncationOnFetchSupported || offsetTruncationState.truncationCompleted)
+            val state = if (leader.isTruncationOnFetchSupported || offsetTruncationState.truncationCompleted)
               Fetching
             else
               Truncating
@@ -669,7 +656,7 @@ abstract class AbstractFetcherThread(name: String,
      *
      * There is a potential for a mismatch between the logs of the two replicas here. We don't fix this mismatch as of now.
      */
-    val leaderEndOffset = fetchLatestOffsetFromLeader(topicPartition, currentLeaderEpoch)
+    val leaderEndOffset = leader.fetchLatestOffset(topicPartition, currentLeaderEpoch)
     if (leaderEndOffset < replicaEndOffset) {
       warn(s"Reset fetch offset for partition $topicPartition from $replicaEndOffset to current " +
         s"leader's latest offset $leaderEndOffset")
@@ -700,7 +687,7 @@ abstract class AbstractFetcherThread(name: String,
        * Putting the two cases together, the follower should fetch from the higher one of its replica log end offset
        * and the current leader's log start offset.
        */
-      val leaderStartOffset = fetchEarliestOffsetFromLeader(topicPartition, currentLeaderEpoch)
+      val leaderStartOffset = leader.fetchEarliestOffset(topicPartition, currentLeaderEpoch)
       warn(s"Reset fetch offset for partition $topicPartition from $replicaEndOffset to current " +
         s"leader's start offset $leaderStartOffset")
       val offsetToFetch = Math.max(leaderStartOffset, replicaEndOffset)
diff --git a/core/src/main/scala/kafka/server/ReplicaFetcherBlockingSend.scala b/core/src/main/scala/kafka/server/BrokerBlockingSender.scala
similarity index 87%
rename from core/src/main/scala/kafka/server/ReplicaFetcherBlockingSend.scala
rename to core/src/main/scala/kafka/server/BrokerBlockingSender.scala
index fd69b5a8a2..7d9fb0512a 100644
--- a/core/src/main/scala/kafka/server/ReplicaFetcherBlockingSend.scala
+++ b/core/src/main/scala/kafka/server/BrokerBlockingSender.scala
@@ -33,6 +33,8 @@ import scala.jdk.CollectionConverters._
 
 trait BlockingSend {
 
+  def brokerEndPoint(): BrokerEndPoint
+
   def sendRequest(requestBuilder: AbstractRequest.Builder[_ <: AbstractRequest]): ClientResponse
 
   def initiateClose(): Unit
@@ -40,13 +42,13 @@ trait BlockingSend {
   def close(): Unit
 }
 
-class ReplicaFetcherBlockingSend(sourceBroker: BrokerEndPoint,
-                                 brokerConfig: KafkaConfig,
-                                 metrics: Metrics,
-                                 time: Time,
-                                 fetcherId: Int,
-                                 clientId: String,
-                                 logContext: LogContext) extends BlockingSend {
+class BrokerBlockingSender(sourceBroker: BrokerEndPoint,
+                           brokerConfig: KafkaConfig,
+                           metrics: Metrics,
+                           time: Time,
+                           fetcherId: Int,
+                           clientId: String,
+                           logContext: LogContext) extends BlockingSend {
 
   private val sourceNode = new Node(sourceBroker.id, sourceBroker.host, sourceBroker.port)
   private val socketTimeout: Int = brokerConfig.replicaSocketTimeoutMs
@@ -99,6 +101,8 @@ class ReplicaFetcherBlockingSend(sourceBroker: BrokerEndPoint,
     (networkClient, reconfigurableChannelBuilder)
   }
 
+  override def brokerEndPoint(): BrokerEndPoint = sourceBroker
+
   override def sendRequest(requestBuilder: Builder[_ <: AbstractRequest]): ClientResponse = {
     try {
       if (!NetworkClientUtils.awaitReady(networkClient, sourceNode, time, socketTimeout))
@@ -124,4 +128,8 @@ class ReplicaFetcherBlockingSend(sourceBroker: BrokerEndPoint,
   def close(): Unit = {
     networkClient.close()
   }
+
+  override def toString: String = {
+    s"BrokerBlockingSender(sourceBroker=$sourceBroker, fetcherId=$fetcherId)"
+  }
 }
diff --git a/core/src/main/scala/kafka/server/KafkaConfig.scala b/core/src/main/scala/kafka/server/KafkaConfig.scala
index eb545cea96..3ed2e3e54d 100755
--- a/core/src/main/scala/kafka/server/KafkaConfig.scala
+++ b/core/src/main/scala/kafka/server/KafkaConfig.scala
@@ -20,7 +20,6 @@ package kafka.server
 import java.util
 import java.util.concurrent.TimeUnit
 import java.util.{Collections, Locale, Properties}
-
 import kafka.cluster.EndPoint
 import kafka.coordinator.group.OffsetConfig
 import kafka.coordinator.transaction.{TransactionLog, TransactionStateManager}
@@ -1793,6 +1792,37 @@ class KafkaConfig private(doLog: Boolean, val props: java.util.Map[_, _], dynami
   val interBrokerProtocolVersionString = getString(KafkaConfig.InterBrokerProtocolVersionProp)
   val interBrokerProtocolVersion = MetadataVersion.fromVersionString(interBrokerProtocolVersionString)
 
+  val fetchRequestVersion: Short =
+    if (interBrokerProtocolVersion.isAtLeast(IBP_3_1_IV0)) 13
+    else if (interBrokerProtocolVersion.isAtLeast(IBP_2_7_IV1)) 12
+    else if (interBrokerProtocolVersion.isAtLeast(IBP_2_3_IV1)) 11
+    else if (interBrokerProtocolVersion.isAtLeast(IBP_2_1_IV2)) 10
+    else if (interBrokerProtocolVersion.isAtLeast(IBP_2_0_IV1)) 8
+    else if (interBrokerProtocolVersion.isAtLeast(IBP_1_1_IV0)) 7
+    else if (interBrokerProtocolVersion.isAtLeast(IBP_0_11_0_IV1)) 5
+    else if (interBrokerProtocolVersion.isAtLeast(IBP_0_11_0_IV0)) 4
+    else if (interBrokerProtocolVersion.isAtLeast(IBP_0_10_1_IV1)) 3
+    else if (interBrokerProtocolVersion.isAtLeast(IBP_0_10_0_IV0)) 2
+    else if (interBrokerProtocolVersion.isAtLeast(IBP_0_9_0)) 1
+    else 0
+
+  val offsetForLeaderEpochRequestVersion: Short =
+    if (interBrokerProtocolVersion.isAtLeast(IBP_2_8_IV0)) 4
+    else if (interBrokerProtocolVersion.isAtLeast(IBP_2_3_IV1)) 3
+    else if (interBrokerProtocolVersion.isAtLeast(IBP_2_1_IV1)) 2
+    else if (interBrokerProtocolVersion.isAtLeast(IBP_2_0_IV0)) 1
+    else 0
+
+  val listOffsetRequestVersion: Short =
+    if (interBrokerProtocolVersion.isAtLeast(IBP_3_0_IV1)) 7
+    else if (interBrokerProtocolVersion.isAtLeast(IBP_2_8_IV0)) 6
+    else if (interBrokerProtocolVersion.isAtLeast(IBP_2_2_IV1)) 5
+    else if (interBrokerProtocolVersion.isAtLeast(IBP_2_1_IV1)) 4
+    else if (interBrokerProtocolVersion.isAtLeast(IBP_2_0_IV1)) 3
+    else if (interBrokerProtocolVersion.isAtLeast(IBP_0_11_0_IV0)) 2
+    else if (interBrokerProtocolVersion.isAtLeast(IBP_0_10_1_IV2)) 1
+    else 0
+
   /** ********* Controlled shutdown configuration ***********/
   val controlledShutdownMaxRetries = getInt(KafkaConfig.ControlledShutdownMaxRetriesProp)
   val controlledShutdownRetryBackoffMs = getLong(KafkaConfig.ControlledShutdownRetryBackoffMsProp)
diff --git a/core/src/main/scala/kafka/server/LeaderEndPoint.scala b/core/src/main/scala/kafka/server/LeaderEndPoint.scala
new file mode 100644
index 0000000000..70d2149dab
--- /dev/null
+++ b/core/src/main/scala/kafka/server/LeaderEndPoint.scala
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package kafka.server
+
+import kafka.cluster.BrokerEndPoint
+import kafka.server.AbstractFetcherThread.{ReplicaFetch, ResultWithPartitions}
+import org.apache.kafka.common.TopicPartition
+import org.apache.kafka.common.requests.FetchRequest
+import org.apache.kafka.common.message.OffsetForLeaderEpochResponseData.EpochEndOffset
+import org.apache.kafka.common.message.{FetchResponseData, OffsetForLeaderEpochRequestData}
+
+import scala.collection.Map
+
+/**
+ * This trait defines the APIs to be used to access a broker that is a leader.
+ */
+trait LeaderEndPoint {
+
+  type FetchData = FetchResponseData.PartitionData
+  type EpochData = OffsetForLeaderEpochRequestData.OffsetForLeaderPartition
+
+  /**
+   * A boolean specifying if truncation when fetching from the leader is supported
+   */
+  def isTruncationOnFetchSupported: Boolean
+
+  /**
+   * Initiate closing access to fetches from leader.
+   */
+  def initiateClose(): Unit
+
+  /**
+   * Closes access to fetches from leader.
+   * `initiateClose` must be called prior to invoking `close`.
+   */
+  def close(): Unit
+
+  /**
+   * The specific broker (host:port) we want to connect to.
+   */
+  def brokerEndPoint(): BrokerEndPoint
+
+  /**
+   * Given a fetchRequest, carries out the expected request and returns
+   * the results from fetching from the leader.
+   *
+   * @param fetchRequest The fetch request we want to carry out
+   *
+   * @return A map of topic partition -> fetch data
+   */
+  def fetch(fetchRequest: FetchRequest.Builder): Map[TopicPartition, FetchData]
+
+  /**
+   * Fetches the log start offset of the given topic partition from the leader.
+   *
+   * @param topicPartition The topic partition that we want to fetch from
+   * @param currentLeaderEpoch An int representing the current leader epoch of the requester
+   *
+   * @return A long representing the earliest offset in the leader's topic partition.
+   */
+  def fetchEarliestOffset(topicPartition: TopicPartition, currentLeaderEpoch: Int): Long
+
+  /**
+   * Fetches the log end offset of the given topic partition from the leader.
+   *
+   * @param topicPartition The topic partition that we want to fetch from
+   * @param currentLeaderEpoch An int representing the current leader epoch of the requester
+   *
+   * @return A long representing the latest offset in the leader's topic partition.
+   */
+  def fetchLatestOffset(topicPartition: TopicPartition, currentLeaderEpoch: Int): Long
+
+  /**
+   * Fetches offset for leader epoch from the leader for each given topic partition
+   *
+   * @param partitions A map of topic partition -> leader epoch of the replica
+   *
+   * @return A map of topic partition -> end offset for a requested leader epoch
+   */
+  def fetchEpochEndOffsets(partitions: Map[TopicPartition, EpochData]): Map[TopicPartition, EpochEndOffset]
+
+  /**
+   * Builds a fetch request, given a partition map.
+   *
+   * @param partitions A map of topic partitions to their respective partition fetch state
+   *
+   * @return A ResultWithPartitions, used to create the fetchRequest for fetch.
+   */
+  def buildFetch(partitions: Map[TopicPartition, PartitionFetchState]): ResultWithPartitions[Option[ReplicaFetch]]
+
+}
diff --git a/core/src/main/scala/kafka/server/LocalLeaderEndPoint.scala b/core/src/main/scala/kafka/server/LocalLeaderEndPoint.scala
new file mode 100644
index 0000000000..1080c8e073
--- /dev/null
+++ b/core/src/main/scala/kafka/server/LocalLeaderEndPoint.scala
@@ -0,0 +1,236 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package kafka.server
+
+import kafka.api.Request
+import kafka.cluster.BrokerEndPoint
+import kafka.server.AbstractFetcherThread.{ReplicaFetch, ResultWithPartitions}
+import kafka.server.QuotaFactory.UnboundedQuota
+import kafka.utils.Logging
+import org.apache.kafka.common.errors.KafkaStorageException
+import org.apache.kafka.common.{TopicIdPartition, TopicPartition, Uuid}
+import org.apache.kafka.common.message.FetchResponseData
+import org.apache.kafka.common.requests.OffsetsForLeaderEpochResponse.UNDEFINED_EPOCH
+import org.apache.kafka.common.message.OffsetForLeaderEpochResponseData.EpochEndOffset
+import org.apache.kafka.common.protocol.{ApiKeys, Errors}
+import org.apache.kafka.common.requests.{FetchRequest, FetchResponse, RequestUtils}
+
+import java.util
+import java.util.Optional
+import scala.collection.{Map, Seq, Set, mutable}
+import scala.compat.java8.OptionConverters.RichOptionForJava8
+import scala.jdk.CollectionConverters._
+
+/**
+ * Facilitates fetches from a local replica leader.
+ *
+ * @param sourceBroker The broker (host:port) that we want to connect to
+ * @param brokerConfig A config file with broker related configurations
+ * @param replicaManager A ReplicaManager
+ * @param quota The quota, used when building a fetch request
+ */
+class LocalLeaderEndPoint(sourceBroker: BrokerEndPoint,
+                          brokerConfig: KafkaConfig,
+                          replicaManager: ReplicaManager,
+                          quota: ReplicaQuota) extends LeaderEndPoint with Logging {
+
+  private val replicaId = brokerConfig.brokerId
+  private val maxBytes = brokerConfig.replicaFetchResponseMaxBytes
+  private val fetchSize = brokerConfig.replicaFetchMaxBytes
+  private var inProgressPartition: Option[TopicPartition] = None
+
+  override val isTruncationOnFetchSupported: Boolean = false
+
+  override def initiateClose(): Unit = {} // do nothing
+
+  override def close(): Unit = {} // do nothing
+
+  override def brokerEndPoint(): BrokerEndPoint = sourceBroker
+
+  override def fetch(fetchRequest: FetchRequest.Builder): collection.Map[TopicPartition, FetchData] = {
+    var partitionData: Seq[(TopicPartition, FetchData)] = null
+    val request = fetchRequest.build()
+
+    // We can build the map from the request since it contains topic IDs and names.
+    // Only one ID can be associated with a name and vice versa.
+    val topicNames = new mutable.HashMap[Uuid, String]()
+    request.data.topics.forEach { topic =>
+      topicNames.put(topic.topicId, topic.topic)
+    }
+
+    def processResponseCallback(responsePartitionData: Seq[(TopicIdPartition, FetchPartitionData)]): Unit = {
+      partitionData = responsePartitionData.map { case (tp, data) =>
+        val abortedTransactions = data.abortedTransactions.map(_.asJava).orNull
+        val lastStableOffset = data.lastStableOffset.getOrElse(FetchResponse.INVALID_LAST_STABLE_OFFSET)
+        tp.topicPartition -> new FetchResponseData.PartitionData()
+          .setPartitionIndex(tp.topicPartition.partition)
+          .setErrorCode(data.error.code)
+          .setHighWatermark(data.highWatermark)
+          .setLastStableOffset(lastStableOffset)
+          .setLogStartOffset(data.logStartOffset)
+          .setAbortedTransactions(abortedTransactions)
+          .setRecords(data.records)
+      }
+    }
+
+    val fetchData = request.fetchData(topicNames.asJava)
+
+    val fetchParams = FetchParams(
+      requestVersion = request.version,
+      maxWaitMs = 0L, // timeout is 0 so that the callback will be executed immediately
+      replicaId = Request.FutureLocalReplicaId,
+      minBytes = request.minBytes,
+      maxBytes = request.maxBytes,
+      isolation = FetchLogEnd,
+      clientMetadata = None
+    )
+
+    replicaManager.fetchMessages(
+      params = fetchParams,
+      fetchInfos = fetchData.asScala.toSeq,
+      quota = UnboundedQuota,
+      responseCallback = processResponseCallback
+    )
+
+    if (partitionData == null)
+      throw new IllegalStateException(s"Failed to fetch data for partitions ${fetchData.keySet().toArray.mkString(",")}")
+
+    partitionData.toMap
+  }
+
+  override def fetchEarliestOffset(topicPartition: TopicPartition, currentLeaderEpoch: Int): Long = {
+    val partition = replicaManager.getPartitionOrException(topicPartition)
+    partition.localLogOrException.logStartOffset
+  }
+
+  override def fetchLatestOffset(topicPartition: TopicPartition, currentLeaderEpoch: Int): Long = {
+    val partition = replicaManager.getPartitionOrException(topicPartition)
+    partition.localLogOrException.logEndOffset
+  }
+
+  override def fetchEpochEndOffsets(partitions: collection.Map[TopicPartition, EpochData]): Map[TopicPartition, EpochEndOffset] = {
+    partitions.map { case (tp, epochData) =>
+      try {
+        val endOffset = if (epochData.leaderEpoch == UNDEFINED_EPOCH) {
+          new EpochEndOffset()
+            .setPartition(tp.partition)
+            .setErrorCode(Errors.NONE.code)
+        } else {
+          val partition = replicaManager.getPartitionOrException(tp)
+          partition.lastOffsetForLeaderEpoch(
+            currentLeaderEpoch = RequestUtils.getLeaderEpoch(epochData.currentLeaderEpoch),
+            leaderEpoch = epochData.leaderEpoch,
+            fetchOnlyFromLeader = false)
+        }
+        tp -> endOffset
+      } catch {
+        case t: Throwable =>
+          warn(s"Error when getting EpochEndOffset for $tp", t)
+          tp -> new EpochEndOffset()
+            .setPartition(tp.partition)
+            .setErrorCode(Errors.forException(t).code)
+      }
+    }
+  }
+
+  override def buildFetch(partitions: Map[TopicPartition, PartitionFetchState]): ResultWithPartitions[Option[ReplicaFetch]] = {
+    // Only include replica in the fetch request if it is not throttled.
+    if (quota.isQuotaExceeded) {
+      ResultWithPartitions(None, Set.empty)
+    } else {
+      selectPartitionToFetch(partitions) match {
+        case Some((tp, fetchState)) =>
+          buildFetchForPartition(tp, fetchState)
+        case None =>
+          ResultWithPartitions(None, Set.empty)
+      }
+    }
+  }
+
+  private def selectPartitionToFetch(partitions: Map[TopicPartition, PartitionFetchState]): Option[(TopicPartition, PartitionFetchState)] = {
+    // Only move one partition at a time to increase its catch-up rate and thus reduce the time spent on
+    // moving any given replica. Replicas are selected in ascending order (lexicographically by topic) from the
+    // partitions that are ready to fetch. Once selected, we will continue fetching the same partition until it
+    // becomes unavailable or is removed.
+
+    inProgressPartition.foreach { tp =>
+      val fetchStateOpt = partitions.get(tp)
+      fetchStateOpt.filter(_.isReadyForFetch).foreach { fetchState =>
+        return Some((tp, fetchState))
+      }
+    }
+
+    inProgressPartition = None
+
+    val nextPartitionOpt = nextReadyPartition(partitions)
+    nextPartitionOpt.foreach { case (tp, fetchState) =>
+      inProgressPartition = Some(tp)
+      info(s"Beginning/resuming copy of partition $tp from offset ${fetchState.fetchOffset}. " +
+        s"Including this partition, there are ${partitions.size} remaining partitions to copy by this thread.")
+    }
+    nextPartitionOpt
+  }
+
+  private def buildFetchForPartition(topicPartition: TopicPartition, fetchState: PartitionFetchState): ResultWithPartitions[Option[ReplicaFetch]] = {
+    val requestMap = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData]
+    val partitionsWithError = mutable.Set[TopicPartition]()
+
+    try {
+      val logStartOffset = replicaManager.futureLocalLogOrException(topicPartition).logStartOffset
+      val lastFetchedEpoch = if (isTruncationOnFetchSupported)
+        fetchState.lastFetchedEpoch.map(_.asInstanceOf[Integer]).asJava
+      else
+        Optional.empty[Integer]
+      val topicId = fetchState.topicId.getOrElse(Uuid.ZERO_UUID)
+      requestMap.put(topicPartition, new FetchRequest.PartitionData(topicId, fetchState.fetchOffset, logStartOffset,
+        fetchSize, Optional.of(fetchState.currentLeaderEpoch), lastFetchedEpoch))
+    } catch {
+      case e: KafkaStorageException =>
+        debug(s"Failed to build fetch for $topicPartition", e)
+        partitionsWithError += topicPartition
+    }
+
+    val fetchRequestOpt = if (requestMap.isEmpty) {
+      None
+    } else {
+      val version: Short = if (fetchState.topicId.isEmpty)
+        12
+      else
+        ApiKeys.FETCH.latestVersion
+      // Set maxWait and minBytes to 0 because the response should return immediately if
+      // the future log has caught up with the current log of the partition
+      val requestBuilder = FetchRequest.Builder.forReplica(version, replicaId, 0, 0, requestMap).setMaxBytes(maxBytes)
+      Some(ReplicaFetch(requestMap, requestBuilder))
+    }
+
+    ResultWithPartitions(fetchRequestOpt, partitionsWithError)
+  }
+
+  private def nextReadyPartition(partitions: Map[TopicPartition, PartitionFetchState]): Option[(TopicPartition, PartitionFetchState)] = {
+    partitions.filter { case (_, partitionFetchState) =>
+      partitionFetchState.isReadyForFetch
+    }.reduceLeftOption { (left, right) =>
+      if ((left._1.topic < right._1.topic) || (left._1.topic == right._1.topic && left._1.partition < right._1.partition))
+        left
+      else
+        right
+    }
+  }
+
+  override def toString: String = s"LocalLeaderEndPoint"
+}
diff --git a/core/src/main/scala/kafka/server/RemoteLeaderEndPoint.scala b/core/src/main/scala/kafka/server/RemoteLeaderEndPoint.scala
new file mode 100644
index 0000000000..a9ac51315a
--- /dev/null
+++ b/core/src/main/scala/kafka/server/RemoteLeaderEndPoint.scala
@@ -0,0 +1,226 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package kafka.server
+
+import kafka.cluster.BrokerEndPoint
+
+import java.util.{Collections, Optional}
+import kafka.server.AbstractFetcherThread.{ReplicaFetch, ResultWithPartitions}
+import kafka.utils.Implicits.MapExtensionMethods
+import kafka.utils.Logging
+import org.apache.kafka.clients.FetchSessionHandler
+import org.apache.kafka.common.errors.KafkaStorageException
+import org.apache.kafka.common.{TopicPartition, Uuid}
+import org.apache.kafka.common.message.ListOffsetsRequestData.{ListOffsetsPartition, ListOffsetsTopic}
+import org.apache.kafka.common.message.OffsetForLeaderEpochRequestData.{OffsetForLeaderTopic, OffsetForLeaderTopicCollection}
+import org.apache.kafka.common.message.OffsetForLeaderEpochResponseData.EpochEndOffset
+import org.apache.kafka.common.protocol.Errors
+import org.apache.kafka.common.requests.{FetchRequest, FetchResponse, ListOffsetsRequest, ListOffsetsResponse, OffsetsForLeaderEpochRequest, OffsetsForLeaderEpochResponse}
+import org.apache.kafka.server.common.MetadataVersion.IBP_0_10_1_IV2
+
+import scala.jdk.CollectionConverters._
+import scala.collection.{Map, mutable}
+import scala.compat.java8.OptionConverters.RichOptionForJava8
+
+/**
+ * Facilitates fetches from a remote replica leader.
+ *
+ * @param logPrefix The log prefix
+ * @param blockingSender The raw leader endpoint used to communicate with the leader
+ * @param fetchSessionHandler A FetchSessionHandler to track the partitions in the session
+ * @param brokerConfig Broker configuration
+ * @param replicaManager A ReplicaManager
+ * @param quota The quota, used when building a fetch request
+ */
+class RemoteLeaderEndPoint(logPrefix: String,
+                           blockingSender: BlockingSend,
+                           private[server] val fetchSessionHandler: FetchSessionHandler, // visible for testing
+                           brokerConfig: KafkaConfig,
+                           replicaManager: ReplicaManager,
+                           quota: ReplicaQuota) extends LeaderEndPoint with Logging {
+
+  this.logIdent = logPrefix
+
+  private val maxWait = brokerConfig.replicaFetchWaitMaxMs
+  private val minBytes = brokerConfig.replicaFetchMinBytes
+  private val maxBytes = brokerConfig.replicaFetchResponseMaxBytes
+  private val fetchSize = brokerConfig.replicaFetchMaxBytes
+
+  override val isTruncationOnFetchSupported = brokerConfig.interBrokerProtocolVersion.isTruncationOnFetchSupported
+
+  override def initiateClose(): Unit = blockingSender.initiateClose()
+
+  override def close(): Unit = blockingSender.close()
+
+  override def brokerEndPoint(): BrokerEndPoint = blockingSender.brokerEndPoint()
+
+  override def fetch(fetchRequest: FetchRequest.Builder): collection.Map[TopicPartition, FetchData] = {
+    val clientResponse = try {
+      blockingSender.sendRequest(fetchRequest)
+    } catch {
+      case t: Throwable =>
+        fetchSessionHandler.handleError(t)
+        throw t
+    }
+    val fetchResponse = clientResponse.responseBody.asInstanceOf[FetchResponse]
+    if (!fetchSessionHandler.handleResponse(fetchResponse, clientResponse.requestHeader().apiVersion())) {
+      // If we had a session topic ID related error, throw it, otherwise return an empty fetch data map.
+      if (fetchResponse.error == Errors.FETCH_SESSION_TOPIC_ID_ERROR) {
+        throw Errors.forCode(fetchResponse.error().code()).exception()
+      } else {
+        Map.empty
+      }
+    } else {
+      fetchResponse.responseData(fetchSessionHandler.sessionTopicNames, clientResponse.requestHeader().apiVersion()).asScala
+    }
+  }
+
+  override def fetchEarliestOffset(topicPartition: TopicPartition, currentLeaderEpoch: Int): Long = {
+    fetchOffset(topicPartition, currentLeaderEpoch, ListOffsetsRequest.EARLIEST_TIMESTAMP)
+  }
+
+  override def fetchLatestOffset(topicPartition: TopicPartition, currentLeaderEpoch: Int): Long = {
+    fetchOffset(topicPartition, currentLeaderEpoch, ListOffsetsRequest.LATEST_TIMESTAMP)
+  }
+
+  private def fetchOffset(topicPartition: TopicPartition, currentLeaderEpoch: Int, earliestOrLatest: Long): Long = {
+    val topic = new ListOffsetsTopic()
+      .setName(topicPartition.topic)
+      .setPartitions(Collections.singletonList(
+        new ListOffsetsPartition()
+          .setPartitionIndex(topicPartition.partition)
+          .setCurrentLeaderEpoch(currentLeaderEpoch)
+          .setTimestamp(earliestOrLatest)))
+    val requestBuilder = ListOffsetsRequest.Builder.forReplica(brokerConfig.listOffsetRequestVersion, brokerConfig.brokerId)
+      .setTargetTimes(Collections.singletonList(topic))
+
+    val clientResponse = blockingSender.sendRequest(requestBuilder)
+    val response = clientResponse.responseBody.asInstanceOf[ListOffsetsResponse]
+    val responsePartition = response.topics.asScala.find(_.name == topicPartition.topic).get
+      .partitions.asScala.find(_.partitionIndex == topicPartition.partition).get
+
+    Errors.forCode(responsePartition.errorCode) match {
+      case Errors.NONE =>
+        if (brokerConfig.interBrokerProtocolVersion.isAtLeast(IBP_0_10_1_IV2))
+          responsePartition.offset
+        else
+          responsePartition.oldStyleOffsets.get(0)
+      case error => throw error.exception
+    }
+  }
+
+  override def fetchEpochEndOffsets(partitions: Map[TopicPartition, EpochData]): Map[TopicPartition, EpochEndOffset] = {
+    if (partitions.isEmpty) {
+      debug("Skipping leaderEpoch request since all partitions do not have an epoch")
+      return Map.empty
+    }
+
+    val topics = new OffsetForLeaderTopicCollection(partitions.size)
+    partitions.forKeyValue { (topicPartition, epochData) =>
+      var topic = topics.find(topicPartition.topic)
+      if (topic == null) {
+        topic = new OffsetForLeaderTopic().setTopic(topicPartition.topic)
+        topics.add(topic)
+      }
+      topic.partitions.add(epochData)
+    }
+
+    val epochRequest = OffsetsForLeaderEpochRequest.Builder.forFollower(
+      brokerConfig.offsetForLeaderEpochRequestVersion, topics, brokerConfig.brokerId)
+    debug(s"Sending offset for leader epoch request $epochRequest")
+
+    try {
+      val response = blockingSender.sendRequest(epochRequest)
+      val responseBody = response.responseBody.asInstanceOf[OffsetsForLeaderEpochResponse]
+      debug(s"Received leaderEpoch response $response")
+      responseBody.data.topics.asScala.flatMap { offsetForLeaderTopicResult =>
+        offsetForLeaderTopicResult.partitions.asScala.map { offsetForLeaderPartitionResult =>
+          val tp = new TopicPartition(offsetForLeaderTopicResult.topic, offsetForLeaderPartitionResult.partition)
+          tp -> offsetForLeaderPartitionResult
+        }
+      }.toMap
+    } catch {
+      case t: Throwable =>
+        warn(s"Error when sending leader epoch request for $partitions", t)
+
+        // if we get any unexpected exception, mark all partitions with an error
+        val error = Errors.forException(t)
+        partitions.map { case (tp, _) =>
+          tp -> new EpochEndOffset()
+            .setPartition(tp.partition)
+            .setErrorCode(error.code)
+        }
+    }
+  }
+
+  override def buildFetch(partitions: Map[TopicPartition, PartitionFetchState]): ResultWithPartitions[Option[ReplicaFetch]] = {
+    val partitionsWithError = mutable.Set[TopicPartition]()
+
+    val builder = fetchSessionHandler.newBuilder(partitions.size, false)
+    partitions.forKeyValue { (topicPartition, fetchState) =>
+      // We will not include a replica in the fetch request if it should be throttled.
+      if (fetchState.isReadyForFetch && !shouldFollowerThrottle(quota, fetchState, topicPartition)) {
+        try {
+          val logStartOffset = replicaManager.localLogOrException(topicPartition).logStartOffset
+          val lastFetchedEpoch = if (isTruncationOnFetchSupported)
+            fetchState.lastFetchedEpoch.map(_.asInstanceOf[Integer]).asJava
+          else
+            Optional.empty[Integer]
+          builder.add(topicPartition, new FetchRequest.PartitionData(
+            fetchState.topicId.getOrElse(Uuid.ZERO_UUID),
+            fetchState.fetchOffset,
+            logStartOffset,
+            fetchSize,
+            Optional.of(fetchState.currentLeaderEpoch),
+            lastFetchedEpoch))
+        } catch {
+          case _: KafkaStorageException =>
+            // The replica has already been marked offline due to log directory failure and the original failure should have already been logged.
+            // This partition should be removed from ReplicaFetcherThread soon by ReplicaManager.handleLogDirFailure()
+            partitionsWithError += topicPartition
+        }
+      }
+    }
+
+    val fetchData = builder.build()
+    val fetchRequestOpt = if (fetchData.sessionPartitions.isEmpty && fetchData.toForget.isEmpty) {
+      None
+    } else {
+      val version: Short = if (brokerConfig.fetchRequestVersion >= 13 && !fetchData.canUseTopicIds) 12 else brokerConfig.fetchRequestVersion
+      val requestBuilder = FetchRequest.Builder
+        .forReplica(version, brokerConfig.brokerId, maxWait, minBytes, fetchData.toSend)
+        .setMaxBytes(maxBytes)
+        .removed(fetchData.toForget)
+        .replaced(fetchData.toReplace)
+        .metadata(fetchData.metadata)
+      Some(ReplicaFetch(fetchData.sessionPartitions(), requestBuilder))
+    }
+
+    ResultWithPartitions(fetchRequestOpt, partitionsWithError)
+  }
+
+  /**
+   *  To avoid ISR thrashing, we only throttle a replica on the follower if it's in the throttled replica list,
+   *  the quota is exceeded and the replica is not in sync.
+   */
+  private def shouldFollowerThrottle(quota: ReplicaQuota, fetchState: PartitionFetchState, topicPartition: TopicPartition): Boolean = {
+    !fetchState.isReplicaInSync && quota.isThrottled(topicPartition) && quota.isQuotaExceeded
+  }
+
+  override def toString: String = s"RemoteLeaderEndPoint(blockingSender=$blockingSender)"
+}
diff --git a/core/src/main/scala/kafka/server/ReplicaAlterLogDirsManager.scala b/core/src/main/scala/kafka/server/ReplicaAlterLogDirsManager.scala
index b45a76620c..0613449e07 100644
--- a/core/src/main/scala/kafka/server/ReplicaAlterLogDirsManager.scala
+++ b/core/src/main/scala/kafka/server/ReplicaAlterLogDirsManager.scala
@@ -31,8 +31,9 @@ class ReplicaAlterLogDirsManager(brokerConfig: KafkaConfig,
 
   override def createFetcherThread(fetcherId: Int, sourceBroker: BrokerEndPoint): ReplicaAlterLogDirsThread = {
     val threadName = s"ReplicaAlterLogDirsThread-$fetcherId"
-    new ReplicaAlterLogDirsThread(threadName, sourceBroker, brokerConfig, failedPartitions, replicaManager,
-      quotaManager, brokerTopicStats)
+    val leader = new LocalLeaderEndPoint(sourceBroker, brokerConfig, replicaManager, quotaManager)
+    new ReplicaAlterLogDirsThread(threadName, leader, failedPartitions, replicaManager,
+      quotaManager, brokerTopicStats, brokerConfig.replicaFetchBackoffMs)
   }
 
   override protected def addPartitionsToFetcherThread(fetcherThread: ReplicaAlterLogDirsThread,
diff --git a/core/src/main/scala/kafka/server/ReplicaAlterLogDirsThread.scala b/core/src/main/scala/kafka/server/ReplicaAlterLogDirsThread.scala
index 4a6a6e070c..10eae83b99 100644
--- a/core/src/main/scala/kafka/server/ReplicaAlterLogDirsThread.scala
+++ b/core/src/main/scala/kafka/server/ReplicaAlterLogDirsThread.scala
@@ -17,44 +17,27 @@
 
 package kafka.server
 
-import kafka.api.Request
-import kafka.cluster.BrokerEndPoint
 import kafka.log.{LeaderOffsetIncremented, LogAppendInfo}
-import kafka.server.AbstractFetcherThread.{ReplicaFetch, ResultWithPartitions}
-import kafka.server.QuotaFactory.UnboundedQuota
-import org.apache.kafka.common.{TopicIdPartition, TopicPartition, Uuid}
-import org.apache.kafka.common.errors.KafkaStorageException
-import org.apache.kafka.common.message.FetchResponseData
-import org.apache.kafka.common.message.OffsetForLeaderEpochResponseData.EpochEndOffset
-import org.apache.kafka.common.protocol.{ApiKeys, Errors}
-import org.apache.kafka.common.requests.OffsetsForLeaderEpochResponse.UNDEFINED_EPOCH
-import org.apache.kafka.common.requests.{FetchRequest, FetchResponse, RequestUtils}
-import java.util
-import java.util.Optional
-import scala.collection.{Map, Seq, Set, mutable}
-import scala.compat.java8.OptionConverters._
-import scala.jdk.CollectionConverters._
+import org.apache.kafka.common.TopicPartition
+import org.apache.kafka.common.requests.FetchResponse
+
+import scala.collection.{Map, Set}
 
 class ReplicaAlterLogDirsThread(name: String,
-                                sourceBroker: BrokerEndPoint,
-                                brokerConfig: KafkaConfig,
+                                leader: LeaderEndPoint,
                                 failedPartitions: FailedPartitions,
                                 replicaMgr: ReplicaManager,
                                 quota: ReplicationQuotaManager,
-                                brokerTopicStats: BrokerTopicStats)
+                                brokerTopicStats: BrokerTopicStats,
+                                fetchBackOffMs: Int)
   extends AbstractFetcherThread(name = name,
                                 clientId = name,
-                                sourceBroker = sourceBroker,
+                                leader = leader,
                                 failedPartitions,
-                                fetchBackOffMs = brokerConfig.replicaFetchBackoffMs,
+                                fetchBackOffMs = fetchBackOffMs,
                                 isInterruptible = false,
                                 brokerTopicStats) {
 
-  private val replicaId = brokerConfig.brokerId
-  private val maxBytes = brokerConfig.replicaFetchResponseMaxBytes
-  private val fetchSize = brokerConfig.replicaFetchMaxBytes
-  private var inProgressPartition: Option[TopicPartition] = None
-
   override protected def latestEpoch(topicPartition: TopicPartition): Option[Int] = {
     replicaMgr.futureLocalLogOrException(topicPartition).latestEpoch
   }
@@ -71,58 +54,6 @@ class ReplicaAlterLogDirsThread(name: String,
     replicaMgr.futureLocalLogOrException(topicPartition).endOffsetForEpoch(epoch)
   }
 
-  def fetchFromLeader(fetchRequest: FetchRequest.Builder): Map[TopicPartition, FetchData] = {
-    var partitionData: Seq[(TopicPartition, FetchData)] = null
-    val request = fetchRequest.build()
-
-    // We can build the map from the request since it contains topic IDs and names.
-    // Only one ID can be associated with a name and vice versa.
-    val topicNames = new mutable.HashMap[Uuid, String]()
-    request.data.topics.forEach { topic =>
-      topicNames.put(topic.topicId, topic.topic)
-    }
-
-
-    def processResponseCallback(responsePartitionData: Seq[(TopicIdPartition, FetchPartitionData)]): Unit = {
-      partitionData = responsePartitionData.map { case (tp, data) =>
-        val abortedTransactions = data.abortedTransactions.map(_.asJava).orNull
-        val lastStableOffset = data.lastStableOffset.getOrElse(FetchResponse.INVALID_LAST_STABLE_OFFSET)
-        tp.topicPartition -> new FetchResponseData.PartitionData()
-          .setPartitionIndex(tp.topicPartition.partition)
-          .setErrorCode(data.error.code)
-          .setHighWatermark(data.highWatermark)
-          .setLastStableOffset(lastStableOffset)
-          .setLogStartOffset(data.logStartOffset)
-          .setAbortedTransactions(abortedTransactions)
-          .setRecords(data.records)
-      }
-    }
-
-    val fetchData = request.fetchData(topicNames.asJava)
-
-    val fetchParams = FetchParams(
-      requestVersion = request.version,
-      maxWaitMs = 0L, // timeout is 0 so that the callback will be executed immediately
-      replicaId = Request.FutureLocalReplicaId,
-      minBytes = request.minBytes,
-      maxBytes = request.maxBytes,
-      isolation = FetchLogEnd,
-      clientMetadata = None
-    )
-
-    replicaMgr.fetchMessages(
-      params = fetchParams,
-      fetchInfos = fetchData.asScala.toSeq,
-      quota = UnboundedQuota,
-      responseCallback = processResponseCallback
-    )
-
-    if (partitionData == null)
-      throw new IllegalStateException(s"Failed to fetch data for partitions ${fetchData.keySet().toArray.mkString(",")}")
-
-    partitionData.toMap
-  }
-
   // process fetched data
   override def processPartitionData(topicPartition: TopicPartition,
                                     fetchOffset: Long,
@@ -164,50 +95,8 @@ class ReplicaAlterLogDirsThread(name: String,
     }
   }
 
-  override protected def fetchEarliestOffsetFromLeader(topicPartition: TopicPartition, leaderEpoch: Int): Long = {
-    val partition = replicaMgr.getPartitionOrException(topicPartition)
-    partition.localLogOrException.logStartOffset
-  }
-
-  override protected def fetchLatestOffsetFromLeader(topicPartition: TopicPartition, leaderEpoch: Int): Long = {
-    val partition = replicaMgr.getPartitionOrException(topicPartition)
-    partition.localLogOrException.logEndOffset
-  }
-
-  /**
-   * Fetches offset for leader epoch from local replica for each given topic partitions
-   * @param partitions map of topic partition -> leader epoch of the future replica
-   * @return map of topic partition -> end offset for a requested leader epoch
-   */
-  override def fetchEpochEndOffsets(partitions: Map[TopicPartition, EpochData]): Map[TopicPartition, EpochEndOffset] = {
-    partitions.map { case (tp, epochData) =>
-      try {
-        val endOffset = if (epochData.leaderEpoch == UNDEFINED_EPOCH) {
-          new EpochEndOffset()
-            .setPartition(tp.partition)
-            .setErrorCode(Errors.NONE.code)
-        } else {
-          val partition = replicaMgr.getPartitionOrException(tp)
-          partition.lastOffsetForLeaderEpoch(
-            currentLeaderEpoch = RequestUtils.getLeaderEpoch(epochData.currentLeaderEpoch),
-            leaderEpoch = epochData.leaderEpoch,
-            fetchOnlyFromLeader = false)
-        }
-        tp -> endOffset
-      } catch {
-        case t: Throwable =>
-          warn(s"Error when getting EpochEndOffset for $tp", t)
-          tp -> new EpochEndOffset()
-            .setPartition(tp.partition)
-            .setErrorCode(Errors.forException(t).code)
-      }
-    }
-  }
-
   override protected val isOffsetForLeaderEpochSupported: Boolean = true
 
-  override protected val isTruncationOnFetchSupported: Boolean = false
-
   /**
    * Truncate the log for each partition based on current replica's returned epoch and offset.
    *
@@ -232,88 +121,4 @@ class ReplicaAlterLogDirsThread(name: String,
     partition.truncateFullyAndStartAt(offset, isFuture = true)
   }
 
-  private def nextReadyPartition(partitionMap: Map[TopicPartition, PartitionFetchState]): Option[(TopicPartition, PartitionFetchState)] = {
-    partitionMap.filter { case (_, partitionFetchState) =>
-      partitionFetchState.isReadyForFetch
-    }.reduceLeftOption { (left, right) =>
-      if ((left._1.topic < right._1.topic) || (left._1.topic == right._1.topic && left._1.partition < right._1.partition))
-        left
-      else
-        right
-    }
-  }
-
-  private def selectPartitionToFetch(partitionMap: Map[TopicPartition, PartitionFetchState]): Option[(TopicPartition, PartitionFetchState)] = {
-    // Only move one partition at a time to increase its catch-up rate and thus reduce the time spent on
-    // moving any given replica. Replicas are selected in ascending order (lexicographically by topic) from the
-    // partitions that are ready to fetch. Once selected, we will continue fetching the same partition until it
-    // becomes unavailable or is removed.
-
-    inProgressPartition.foreach { tp =>
-      val fetchStateOpt = partitionMap.get(tp)
-      fetchStateOpt.filter(_.isReadyForFetch).foreach { fetchState =>
-        return Some((tp, fetchState))
-      }
-    }
-
-    inProgressPartition = None
-
-    val nextPartitionOpt = nextReadyPartition(partitionMap)
-    nextPartitionOpt.foreach { case (tp, fetchState) =>
-      inProgressPartition = Some(tp)
-      info(s"Beginning/resuming copy of partition $tp from offset ${fetchState.fetchOffset}. " +
-        s"Including this partition, there are ${partitionMap.size} remaining partitions to copy by this thread.")
-    }
-    nextPartitionOpt
-  }
-
-  private def buildFetchForPartition(tp: TopicPartition, fetchState: PartitionFetchState): ResultWithPartitions[Option[ReplicaFetch]] = {
-    val requestMap = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData]
-    val partitionsWithError = mutable.Set[TopicPartition]()
-
-    try {
-      val logStartOffset = replicaMgr.futureLocalLogOrException(tp).logStartOffset
-      val lastFetchedEpoch = if (isTruncationOnFetchSupported)
-        fetchState.lastFetchedEpoch.map(_.asInstanceOf[Integer]).asJava
-      else
-        Optional.empty[Integer]
-      val topicId = fetchState.topicId.getOrElse(Uuid.ZERO_UUID)
-      requestMap.put(tp, new FetchRequest.PartitionData(topicId, fetchState.fetchOffset, logStartOffset,
-        fetchSize, Optional.of(fetchState.currentLeaderEpoch), lastFetchedEpoch))
-    } catch {
-      case e: KafkaStorageException =>
-        debug(s"Failed to build fetch for $tp", e)
-        partitionsWithError += tp
-    }
-
-    val fetchRequestOpt = if (requestMap.isEmpty) {
-      None
-    } else {
-      val version: Short = if (fetchState.topicId.isEmpty)
-        12
-      else
-        ApiKeys.FETCH.latestVersion
-      // Set maxWait and minBytes to 0 because the response should return immediately if
-      // the future log has caught up with the current log of the partition
-      val requestBuilder = FetchRequest.Builder.forReplica(version, replicaId, 0, 0, requestMap).setMaxBytes(maxBytes)
-      Some(ReplicaFetch(requestMap, requestBuilder))
-    }
-
-    ResultWithPartitions(fetchRequestOpt, partitionsWithError)
-  }
-
-  def buildFetch(partitionMap: Map[TopicPartition, PartitionFetchState]): ResultWithPartitions[Option[ReplicaFetch]] = {
-    // Only include replica in the fetch request if it is not throttled.
-    if (quota.isQuotaExceeded) {
-      ResultWithPartitions(None, Set.empty)
-    } else {
-      selectPartitionToFetch(partitionMap) match {
-        case Some((tp, fetchState)) =>
-          buildFetchForPartition(tp, fetchState)
-        case None =>
-          ResultWithPartitions(None, Set.empty)
-      }
-    }
-  }
-
 }
diff --git a/core/src/main/scala/kafka/server/ReplicaFetcherManager.scala b/core/src/main/scala/kafka/server/ReplicaFetcherManager.scala
index d547e1b5d7..feb082c5ae 100644
--- a/core/src/main/scala/kafka/server/ReplicaFetcherManager.scala
+++ b/core/src/main/scala/kafka/server/ReplicaFetcherManager.scala
@@ -18,8 +18,9 @@
 package kafka.server
 
 import kafka.cluster.BrokerEndPoint
+import org.apache.kafka.clients.FetchSessionHandler
 import org.apache.kafka.common.metrics.Metrics
-import org.apache.kafka.common.utils.Time
+import org.apache.kafka.common.utils.{LogContext, Time}
 
 class ReplicaFetcherManager(brokerConfig: KafkaConfig,
                             protected val replicaManager: ReplicaManager,
@@ -35,8 +36,14 @@ class ReplicaFetcherManager(brokerConfig: KafkaConfig,
   override def createFetcherThread(fetcherId: Int, sourceBroker: BrokerEndPoint): ReplicaFetcherThread = {
     val prefix = threadNamePrefix.map(tp => s"$tp:").getOrElse("")
     val threadName = s"${prefix}ReplicaFetcherThread-$fetcherId-${sourceBroker.id}"
-    new ReplicaFetcherThread(threadName, fetcherId, sourceBroker, brokerConfig, failedPartitions, replicaManager,
-      metrics, time, quotaManager)
+    val logContext = new LogContext(s"[ReplicaFetcher replicaId=${brokerConfig.brokerId}, leaderId=${sourceBroker.id}, " +
+      s"fetcherId=$fetcherId] ")
+    val endpoint = new BrokerBlockingSender(sourceBroker, brokerConfig, metrics, time, fetcherId,
+      s"broker-${brokerConfig.brokerId}-fetcher-$fetcherId", logContext)
+    val fetchSessionHandler = new FetchSessionHandler(logContext, sourceBroker.id)
+    val leader = new RemoteLeaderEndPoint(logContext.logPrefix, endpoint, fetchSessionHandler, brokerConfig, replicaManager, quotaManager)
+    new ReplicaFetcherThread(threadName, leader, brokerConfig, failedPartitions, replicaManager,
+      quotaManager, logContext.logPrefix)
   }
 
   def shutdown(): Unit = {
diff --git a/core/src/main/scala/kafka/server/ReplicaFetcherThread.scala b/core/src/main/scala/kafka/server/ReplicaFetcherThread.scala
index b598c397cf..86cc6b1b9d 100644
--- a/core/src/main/scala/kafka/server/ReplicaFetcherThread.scala
+++ b/core/src/main/scala/kafka/server/ReplicaFetcherThread.scala
@@ -17,100 +17,29 @@
 
 package kafka.server
 
-import java.util.Collections
-import java.util.Optional
-
-import kafka.cluster.BrokerEndPoint
 import kafka.log.{LeaderOffsetIncremented, LogAppendInfo}
-import kafka.server.AbstractFetcherThread.ReplicaFetch
-import kafka.server.AbstractFetcherThread.ResultWithPartitions
-import kafka.utils.Implicits._
-import org.apache.kafka.clients.FetchSessionHandler
-import org.apache.kafka.common.{TopicPartition, Uuid}
-import org.apache.kafka.common.errors.KafkaStorageException
-import org.apache.kafka.common.message.ListOffsetsRequestData.{ListOffsetsPartition, ListOffsetsTopic}
-import org.apache.kafka.common.message.OffsetForLeaderEpochRequestData.OffsetForLeaderTopic
-import org.apache.kafka.common.message.OffsetForLeaderEpochRequestData.OffsetForLeaderTopicCollection
-import org.apache.kafka.common.message.OffsetForLeaderEpochResponseData.EpochEndOffset
-import org.apache.kafka.common.metrics.Metrics
-import org.apache.kafka.common.protocol.Errors
 import org.apache.kafka.common.record.MemoryRecords
 import org.apache.kafka.common.requests._
-import org.apache.kafka.common.utils.{LogContext, Time}
-import org.apache.kafka.server.common.MetadataVersion._
-
-import scala.jdk.CollectionConverters._
-import scala.collection.{Map, mutable}
-import scala.compat.java8.OptionConverters._
+import org.apache.kafka.common.TopicPartition
 
 class ReplicaFetcherThread(name: String,
-                           fetcherId: Int,
-                           sourceBroker: BrokerEndPoint,
+                           leader: LeaderEndPoint,
                            brokerConfig: KafkaConfig,
                            failedPartitions: FailedPartitions,
                            replicaMgr: ReplicaManager,
-                           metrics: Metrics,
-                           time: Time,
                            quota: ReplicaQuota,
-                           leaderEndpointBlockingSend: Option[BlockingSend] = None)
+                           logPrefix: String)
   extends AbstractFetcherThread(name = name,
                                 clientId = name,
-                                sourceBroker = sourceBroker,
+                                leader = leader,
                                 failedPartitions,
                                 fetchBackOffMs = brokerConfig.replicaFetchBackoffMs,
                                 isInterruptible = false,
                                 replicaMgr.brokerTopicStats) {
 
-  private val replicaId = brokerConfig.brokerId
-  private val logContext = new LogContext(s"[ReplicaFetcher replicaId=$replicaId, leaderId=${sourceBroker.id}, " +
-    s"fetcherId=$fetcherId] ")
-  this.logIdent = logContext.logPrefix
-
-  private val leaderEndpoint = leaderEndpointBlockingSend.getOrElse(
-    new ReplicaFetcherBlockingSend(sourceBroker, brokerConfig, metrics, time, fetcherId,
-      s"broker-$replicaId-fetcher-$fetcherId", logContext))
-
-  // Visible for testing
-  private[server] val fetchRequestVersion: Short =
-    if (brokerConfig.interBrokerProtocolVersion.isAtLeast(IBP_3_1_IV0)) 13
-    else if (brokerConfig.interBrokerProtocolVersion.isAtLeast(IBP_2_7_IV1)) 12
-    else if (brokerConfig.interBrokerProtocolVersion.isAtLeast(IBP_2_3_IV1)) 11
-    else if (brokerConfig.interBrokerProtocolVersion.isAtLeast(IBP_2_1_IV2)) 10
-    else if (brokerConfig.interBrokerProtocolVersion.isAtLeast(IBP_2_0_IV1)) 8
-    else if (brokerConfig.interBrokerProtocolVersion.isAtLeast(IBP_1_1_IV0)) 7
-    else if (brokerConfig.interBrokerProtocolVersion.isAtLeast(IBP_0_11_0_IV1)) 5
-    else if (brokerConfig.interBrokerProtocolVersion.isAtLeast(IBP_0_11_0_IV0)) 4
-    else if (brokerConfig.interBrokerProtocolVersion.isAtLeast(IBP_0_10_1_IV1)) 3
-    else if (brokerConfig.interBrokerProtocolVersion.isAtLeast(IBP_0_10_0_IV0)) 2
-    else if (brokerConfig.interBrokerProtocolVersion.isAtLeast(IBP_0_9_0)) 1
-    else 0
+  this.logIdent = logPrefix
 
-  // Visible for testing
-  private[server] val offsetForLeaderEpochRequestVersion: Short =
-    if (brokerConfig.interBrokerProtocolVersion.isAtLeast(IBP_2_8_IV0)) 4
-    else if (brokerConfig.interBrokerProtocolVersion.isAtLeast(IBP_2_3_IV1)) 3
-    else if (brokerConfig.interBrokerProtocolVersion.isAtLeast(IBP_2_1_IV1)) 2
-    else if (brokerConfig.interBrokerProtocolVersion.isAtLeast(IBP_2_0_IV0)) 1
-    else 0
-
-  // Visible for testing
-  private[server] val listOffsetRequestVersion: Short =
-    if (brokerConfig.interBrokerProtocolVersion.isAtLeast(IBP_3_0_IV1)) 7
-    else if (brokerConfig.interBrokerProtocolVersion.isAtLeast(IBP_2_8_IV0)) 6
-    else if (brokerConfig.interBrokerProtocolVersion.isAtLeast(IBP_2_2_IV1)) 5
-    else if (brokerConfig.interBrokerProtocolVersion.isAtLeast(IBP_2_1_IV1)) 4
-    else if (brokerConfig.interBrokerProtocolVersion.isAtLeast(IBP_2_0_IV1)) 3
-    else if (brokerConfig.interBrokerProtocolVersion.isAtLeast(IBP_0_11_0_IV0)) 2
-    else if (brokerConfig.interBrokerProtocolVersion.isAtLeast(IBP_0_10_1_IV2)) 1
-    else 0
-
-  private val maxWait = brokerConfig.replicaFetchWaitMaxMs
-  private val minBytes = brokerConfig.replicaFetchMinBytes
-  private val maxBytes = brokerConfig.replicaFetchResponseMaxBytes
-  private val fetchSize = brokerConfig.replicaFetchMaxBytes
   override protected val isOffsetForLeaderEpochSupported: Boolean = brokerConfig.interBrokerProtocolVersion.isOffsetForLeaderEpochSupported
-  override protected val isTruncationOnFetchSupported = brokerConfig.interBrokerProtocolVersion.isTruncationOnFetchSupported
-  val fetchSessionHandler = new FetchSessionHandler(logContext, sourceBroker.id)
 
   override protected def latestEpoch(topicPartition: TopicPartition): Option[Int] = {
     replicaMgr.localLogOrException(topicPartition).latestEpoch
@@ -135,10 +64,10 @@ class ReplicaFetcherThread(name: String,
       // to avoid failing the caller, especially during shutdown. We will attempt to close
       // leaderEndpoint after the thread terminates.
       try {
-        leaderEndpoint.initiateClose()
+        leader.initiateClose()
       } catch {
         case t: Throwable =>
-          error(s"Failed to initiate shutdown of leader endpoint $leaderEndpoint after initiating replica fetcher thread shutdown", t)
+          error(s"Failed to initiate shutdown of $leader after initiating replica fetcher thread shutdown", t)
       }
     }
     justShutdown
@@ -150,10 +79,10 @@ class ReplicaFetcherThread(name: String,
     // especially during shutdown. It is safe to catch the exception here without causing a correctness
     // issue because we are going to shut down the thread and will not re-use the leaderEndpoint anyway.
     try {
-      leaderEndpoint.close()
+      leader.close()
     } catch {
       case t: Throwable =>
-        error(s"Failed to close leader endpoint $leaderEndpoint after shutting down replica fetcher thread", t)
+        error(s"Failed to close $leader after shutting down replica fetcher thread", t)
     }
   }
 
@@ -206,115 +135,13 @@ class ReplicaFetcherThread(name: String,
 
   def maybeWarnIfOversizedRecords(records: MemoryRecords, topicPartition: TopicPartition): Unit = {
     // oversized messages don't cause replication to fail from fetch request version 3 (KIP-74)
-    if (fetchRequestVersion <= 2 && records.sizeInBytes > 0 && records.validBytes <= 0)
+    if (brokerConfig.fetchRequestVersion <= 2 && records.sizeInBytes > 0 && records.validBytes <= 0)
       error(s"Replication is failing due to a message that is greater than replica.fetch.max.bytes for partition $topicPartition. " +
         "This generally occurs when the max.message.bytes has been overridden to exceed this value and a suitably large " +
         "message has also been sent. To fix this problem increase replica.fetch.max.bytes in your broker config to be " +
         "equal or larger than your settings for max.message.bytes, both at a broker and topic level.")
   }
 
-
-  override protected def fetchFromLeader(fetchRequest: FetchRequest.Builder): Map[TopicPartition, FetchData] = {
-    val clientResponse = try {
-      leaderEndpoint.sendRequest(fetchRequest)
-    } catch {
-      case t: Throwable =>
-        fetchSessionHandler.handleError(t)
-        throw t
-    }
-    val fetchResponse = clientResponse.responseBody.asInstanceOf[FetchResponse]
-    if (!fetchSessionHandler.handleResponse(fetchResponse, clientResponse.requestHeader().apiVersion())) {
-      // If we had a session topic ID related error, throw it, otherwise return an empty fetch data map.
-      if (fetchResponse.error == Errors.FETCH_SESSION_TOPIC_ID_ERROR) {
-        throw Errors.forCode(fetchResponse.error().code()).exception()
-      } else {
-        Map.empty
-      }
-    } else {
-      fetchResponse.responseData(fetchSessionHandler.sessionTopicNames, clientResponse.requestHeader().apiVersion()).asScala
-    }
-  }
-
-  override protected def fetchEarliestOffsetFromLeader(topicPartition: TopicPartition, currentLeaderEpoch: Int): Long = {
-    fetchOffsetFromLeader(topicPartition, currentLeaderEpoch, ListOffsetsRequest.EARLIEST_TIMESTAMP)
-  }
-
-  override protected def fetchLatestOffsetFromLeader(topicPartition: TopicPartition, currentLeaderEpoch: Int): Long = {
-    fetchOffsetFromLeader(topicPartition, currentLeaderEpoch, ListOffsetsRequest.LATEST_TIMESTAMP)
-  }
-
-  private def fetchOffsetFromLeader(topicPartition: TopicPartition, currentLeaderEpoch: Int, earliestOrLatest: Long): Long = {
-    val topic = new ListOffsetsTopic()
-      .setName(topicPartition.topic)
-      .setPartitions(Collections.singletonList(
-          new ListOffsetsPartition()
-            .setPartitionIndex(topicPartition.partition)
-            .setCurrentLeaderEpoch(currentLeaderEpoch)
-            .setTimestamp(earliestOrLatest)))
-    val requestBuilder = ListOffsetsRequest.Builder.forReplica(listOffsetRequestVersion, replicaId)
-      .setTargetTimes(Collections.singletonList(topic))
-
-    val clientResponse = leaderEndpoint.sendRequest(requestBuilder)
-    val response = clientResponse.responseBody.asInstanceOf[ListOffsetsResponse]
-    val responsePartition = response.topics.asScala.find(_.name == topicPartition.topic).get
-      .partitions.asScala.find(_.partitionIndex == topicPartition.partition).get
-
-     Errors.forCode(responsePartition.errorCode) match {
-      case Errors.NONE =>
-        if (brokerConfig.interBrokerProtocolVersion.isAtLeast(IBP_0_10_1_IV2))
-          responsePartition.offset
-        else
-          responsePartition.oldStyleOffsets.get(0)
-      case error => throw error.exception
-    }
-  }
-
-  override def buildFetch(partitionMap: Map[TopicPartition, PartitionFetchState]): ResultWithPartitions[Option[ReplicaFetch]] = {
-    val partitionsWithError = mutable.Set[TopicPartition]()
-
-    val builder = fetchSessionHandler.newBuilder(partitionMap.size, false)
-    partitionMap.forKeyValue { (topicPartition, fetchState) =>
-      // We will not include a replica in the fetch request if it should be throttled.
-      if (fetchState.isReadyForFetch && !shouldFollowerThrottle(quota, fetchState, topicPartition)) {
-        try {
-          val logStartOffset = this.logStartOffset(topicPartition)
-          val lastFetchedEpoch = if (isTruncationOnFetchSupported)
-            fetchState.lastFetchedEpoch.map(_.asInstanceOf[Integer]).asJava
-          else
-            Optional.empty[Integer]
-          builder.add(topicPartition, new FetchRequest.PartitionData(
-            fetchState.topicId.getOrElse(Uuid.ZERO_UUID),
-            fetchState.fetchOffset,
-            logStartOffset,
-            fetchSize,
-            Optional.of(fetchState.currentLeaderEpoch),
-            lastFetchedEpoch))
-        } catch {
-          case _: KafkaStorageException =>
-            // The replica has already been marked offline due to log directory failure and the original failure should have already been logged.
-            // This partition should be removed from ReplicaFetcherThread soon by ReplicaManager.handleLogDirFailure()
-            partitionsWithError += topicPartition
-        }
-      }
-    }
-
-    val fetchData = builder.build()
-    val fetchRequestOpt = if (fetchData.sessionPartitions.isEmpty && fetchData.toForget.isEmpty) {
-      None
-    } else {
-      val version: Short = if (fetchRequestVersion >= 13 && !fetchData.canUseTopicIds) 12 else fetchRequestVersion
-      val requestBuilder = FetchRequest.Builder
-        .forReplica(version, replicaId, maxWait, minBytes, fetchData.toSend)
-        .setMaxBytes(maxBytes)
-        .removed(fetchData.toForget)
-        .replaced(fetchData.toReplace)
-        .metadata(fetchData.metadata)
-      Some(ReplicaFetch(fetchData.sessionPartitions(), requestBuilder))
-    }
-
-    ResultWithPartitions(fetchRequestOpt, partitionsWithError)
-  }
-
   /**
    * Truncate the log for each partition's epoch based on leader's returned epoch and offset.
    * The logic for finding the truncation offset is implemented in AbstractFetcherThread.getOffsetTruncationState
@@ -340,57 +167,4 @@ class ReplicaFetcherThread(name: String,
     partition.truncateFullyAndStartAt(offset, isFuture = false)
   }
 
-  override def fetchEpochEndOffsets(partitions: Map[TopicPartition, EpochData]): Map[TopicPartition, EpochEndOffset] = {
-
-    if (partitions.isEmpty) {
-      debug("Skipping leaderEpoch request since all partitions do not have an epoch")
-      return Map.empty
-    }
-
-    val topics = new OffsetForLeaderTopicCollection(partitions.size)
-    partitions.forKeyValue { (topicPartition, epochData) =>
-      var topic = topics.find(topicPartition.topic)
-      if (topic == null) {
-        topic = new OffsetForLeaderTopic().setTopic(topicPartition.topic)
-        topics.add(topic)
-      }
-      topic.partitions.add(epochData)
-    }
-
-    val epochRequest = OffsetsForLeaderEpochRequest.Builder.forFollower(
-      offsetForLeaderEpochRequestVersion, topics, brokerConfig.brokerId)
-    debug(s"Sending offset for leader epoch request $epochRequest")
-
-    try {
-      val response = leaderEndpoint.sendRequest(epochRequest)
-      val responseBody = response.responseBody.asInstanceOf[OffsetsForLeaderEpochResponse]
-      debug(s"Received leaderEpoch response $response")
-      responseBody.data.topics.asScala.flatMap { offsetForLeaderTopicResult =>
-        offsetForLeaderTopicResult.partitions.asScala.map { offsetForLeaderPartitionResult =>
-          val tp = new TopicPartition(offsetForLeaderTopicResult.topic, offsetForLeaderPartitionResult.partition)
-          tp -> offsetForLeaderPartitionResult
-        }
-      }.toMap
-    } catch {
-      case t: Throwable =>
-        warn(s"Error when sending leader epoch request for $partitions", t)
-
-        // if we get any unexpected exception, mark all partitions with an error
-        val error = Errors.forException(t)
-        partitions.map { case (tp, _) =>
-          tp -> new EpochEndOffset()
-            .setPartition(tp.partition)
-            .setErrorCode(error.code)
-        }
-    }
-  }
-
-  /**
-   *  To avoid ISR thrashing, we only throttle a replica on the follower if it's in the throttled replica list,
-   *  the quota is exceeded and the replica is not in sync.
-   */
-  private def shouldFollowerThrottle(quota: ReplicaQuota, fetchState: PartitionFetchState, topicPartition: TopicPartition): Boolean = {
-    !fetchState.isReplicaInSync && quota.isThrottled(topicPartition) && quota.isQuotaExceeded
-  }
-
 }
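
Taken together, the hunks above define the new assembly path: ReplicaFetcherManager builds the network sender and the fetch session handler, wraps them in a RemoteLeaderEndPoint, and ReplicaFetcherThread only ever sees that LeaderEndPoint plus a log prefix. The following is a condensed, standalone sketch of the same wiring; the helper name buildFetcherThread and its parameter list are illustrative (and it is assumed to live in the kafka.server package), while every constructor call mirrors the createFetcherThread hunk above.

    import kafka.cluster.BrokerEndPoint
    import org.apache.kafka.clients.FetchSessionHandler
    import org.apache.kafka.common.metrics.Metrics
    import org.apache.kafka.common.utils.{LogContext, Time}

    // Illustrative helper: assemble the leader-access stack the way createFetcherThread now does.
    def buildFetcherThread(fetcherId: Int,
                           sourceBroker: BrokerEndPoint,
                           brokerConfig: KafkaConfig,
                           replicaManager: ReplicaManager,
                           quota: ReplicaQuota,
                           failedPartitions: FailedPartitions,
                           metrics: Metrics,
                           time: Time): ReplicaFetcherThread = {
      val logContext = new LogContext(s"[ReplicaFetcher replicaId=${brokerConfig.brokerId}, " +
        s"leaderId=${sourceBroker.id}, fetcherId=$fetcherId] ")
      // Blocking network client to the leader broker.
      val sender = new BrokerBlockingSender(sourceBroker, brokerConfig, metrics, time, fetcherId,
        s"broker-${brokerConfig.brokerId}-fetcher-$fetcherId", logContext)
      // Incremental fetch session state now lives with the leader endpoint, not the thread.
      val fetchSessionHandler = new FetchSessionHandler(logContext, sourceBroker.id)
      // Every leader interaction (fetch, list offsets, epoch end offsets) goes through this object.
      val leader = new RemoteLeaderEndPoint(logContext.logPrefix, sender, fetchSessionHandler,
        brokerConfig, replicaManager, quota)
      new ReplicaFetcherThread(s"ReplicaFetcherThread-$fetcherId-${sourceBroker.id}", leader,
        brokerConfig, failedPartitions, replicaManager, quota, logContext.logPrefix)
    }

Because the thread receives the LeaderEndPoint instead of constructing its own blocking send and session handler, tests and benchmarks can substitute a different implementation without touching the thread itself, as the test changes below show.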
diff --git a/core/src/test/scala/unit/kafka/server/AbstractFetcherManagerTest.scala b/core/src/test/scala/unit/kafka/server/AbstractFetcherManagerTest.scala
index 5a7a07730f..cb60384a6b 100644
--- a/core/src/test/scala/unit/kafka/server/AbstractFetcherManagerTest.scala
+++ b/core/src/test/scala/unit/kafka/server/AbstractFetcherManagerTest.scala
@@ -65,8 +65,8 @@ class AbstractFetcherManagerTest {
       currentLeaderEpoch = leaderEpoch,
       initOffset = fetchOffset)
 
-    when(fetcher.sourceBroker)
-      .thenReturn(new BrokerEndPoint(0, "localhost", 9092))
+    when(fetcher.leader)
+      .thenReturn(new MockLeaderEndPoint(new BrokerEndPoint(0, "localhost", 9092)))
     when(fetcher.addPartitions(Map(tp -> initialFetchState)))
       .thenReturn(Set(tp))
     when(fetcher.fetchState(tp))
@@ -127,8 +127,8 @@ class AbstractFetcherManagerTest {
       currentLeaderEpoch = leaderEpoch,
       initOffset = fetchOffset)
 
-    when(fetcher.sourceBroker)
-      .thenReturn(new BrokerEndPoint(0, "localhost", 9092))
+    when(fetcher.leader)
+      .thenReturn(new MockLeaderEndPoint(new BrokerEndPoint(0, "localhost", 9092)))
     when(fetcher.addPartitions(Map(tp -> initialFetchState)))
       .thenReturn(Set(tp))
     when(fetcher.isThreadFailed).thenReturn(true)
@@ -174,8 +174,8 @@ class AbstractFetcherManagerTest {
       initOffset = fetchOffset)
 
     // Simulate calls to different fetchers due to different leaders
-    when(fetcher.sourceBroker)
-      .thenReturn(new BrokerEndPoint(0, "localhost", 9092))
+    when(fetcher.leader)
+      .thenReturn(new MockLeaderEndPoint(new BrokerEndPoint(0, "localhost", 9092)))
     when(fetcher.addPartitions(Map(tp1 -> initialFetchState1)))
       .thenReturn(Set(tp1))
     when(fetcher.addPartitions(Map(tp2 -> initialFetchState2)))
@@ -288,11 +288,31 @@ class AbstractFetcherManagerTest {
     Utils.abs(tp.hashCode) % brokerNum
   }
 
+  private class MockLeaderEndPoint(sourceBroker: BrokerEndPoint) extends LeaderEndPoint {
+    override def initiateClose(): Unit = {}
+
+    override def close(): Unit = {}
+
+    override def brokerEndPoint(): BrokerEndPoint = sourceBroker
+
+    override def fetch(fetchRequest: FetchRequest.Builder): Map[TopicPartition, FetchData] = Map.empty
+
+    override def fetchEarliestOffset(topicPartition: TopicPartition, currentLeaderEpoch: Int): Long = 1
+
+    override def fetchLatestOffset(topicPartition: TopicPartition, currentLeaderEpoch: Int): Long = 1
+
+    override def fetchEpochEndOffsets(partitions: Map[TopicPartition, EpochData]): Map[TopicPartition, EpochEndOffset] = Map.empty
+
+    override def buildFetch(partitionMap: Map[TopicPartition, PartitionFetchState]): ResultWithPartitions[Option[ReplicaFetch]] = ResultWithPartitions(None, Set.empty)
+
+    override val isTruncationOnFetchSupported: Boolean = false
+  }
+
   private class TestResizeFetcherThread(sourceBroker: BrokerEndPoint, failedPartitions: FailedPartitions)
     extends AbstractFetcherThread(
       name = "test-resize-fetcher",
       clientId = "mock-fetcher",
-      sourceBroker,
+      leader = new MockLeaderEndPoint(sourceBroker),
       failedPartitions,
       fetchBackOffMs = 0,
       brokerTopicStats = new BrokerTopicStats) {
@@ -305,8 +325,6 @@ class AbstractFetcherManagerTest {
 
     override protected def truncateFullyAndStartAt(topicPartition: TopicPartition, offset: Long): Unit = {}
 
-    override protected def buildFetch(partitionMap: Map[TopicPartition, PartitionFetchState]): ResultWithPartitions[Option[ReplicaFetch]] = ResultWithPartitions(None, Set.empty)
-
     override protected def latestEpoch(topicPartition: TopicPartition): Option[Int] = Some(0)
 
     override protected def logStartOffset(topicPartition: TopicPartition): Long = 1
@@ -315,16 +333,7 @@ class AbstractFetcherManagerTest {
 
     override protected def endOffsetForEpoch(topicPartition: TopicPartition, epoch: Int): Option[OffsetAndEpoch] = Some(OffsetAndEpoch(1, 0))
 
-    override protected def fetchEpochEndOffsets(partitions: Map[TopicPartition, EpochData]): Map[TopicPartition, EpochEndOffset] = Map.empty
-
-    override protected def fetchFromLeader(fetchRequest: FetchRequest.Builder): Map[TopicPartition, FetchData] = Map.empty
-
-    override protected def fetchEarliestOffsetFromLeader(topicPartition: TopicPartition, currentLeaderEpoch: Int): Long = 1
-
-    override protected def fetchLatestOffsetFromLeader(topicPartition: TopicPartition, currentLeaderEpoch: Int): Long = 1
-
     override protected val isOffsetForLeaderEpochSupported: Boolean = false
-    override protected val isTruncationOnFetchSupported: Boolean = false
   }
 
 }
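
MockLeaderEndPoint above is the clearest picture of the new seam: everything a fetcher thread needs from its leader sits behind brokerEndPoint, buildFetch, fetch, fetchEarliestOffset, fetchLatestOffset, fetchEpochEndOffsets and the isTruncationOnFetchSupported flag. Below is a minimal sketch of a caller against that surface; fetchOnce itself is hypothetical, while ReplicaFetch and ResultWithPartitions are imported exactly as elsewhere in this patch and PartitionFetchState is the fetcher-side state type used throughout it.

    import scala.collection.Map
    import org.apache.kafka.common.TopicPartition
    import kafka.server.AbstractFetcherThread.{ReplicaFetch, ResultWithPartitions}

    // Hypothetical caller (kafka.server package assumed): one fetch round expressed purely
    // against the LeaderEndPoint surface. buildFetch decides what, if anything, to request;
    // fetch sends the request to the leader and returns its per-partition response data.
    def fetchOnce(leader: LeaderEndPoint, states: Map[TopicPartition, PartitionFetchState]) = {
      // Partitions that failed request building are dropped in this sketch; the real fetcher
      // hands them back for delayed retry.
      val ResultWithPartitions(fetchRequestOpt, _) = leader.buildFetch(states)
      fetchRequestOpt match {
        case Some(ReplicaFetch(_, fetchRequestBuilder)) => leader.fetch(fetchRequestBuilder)
        case None => Map.empty
      }
    }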
diff --git a/core/src/test/scala/unit/kafka/server/AbstractFetcherThreadTest.scala b/core/src/test/scala/unit/kafka/server/AbstractFetcherThreadTest.scala
index e34069b6f0..cdd17b1af2 100644
--- a/core/src/test/scala/unit/kafka/server/AbstractFetcherThreadTest.scala
+++ b/core/src/test/scala/unit/kafka/server/AbstractFetcherThreadTest.scala
@@ -77,12 +77,13 @@ class AbstractFetcherThreadTest {
   @Test
   def testMetricsRemovedOnShutdown(): Unit = {
     val partition = new TopicPartition("topic", 0)
-    val fetcher = new MockFetcherThread
+    val fetcher = new MockFetcherThread(new MockLeaderEndPoint)
 
     // add one partition to create the consumer lag metric
-    fetcher.setReplicaState(partition, MockFetcherThread.PartitionState(leaderEpoch = 0))
+    fetcher.setReplicaState(partition, PartitionState(leaderEpoch = 0))
     fetcher.addPartitions(Map(partition -> initialFetchState(topicIds.get(partition.topic), 0L, leaderEpoch = 0)))
-    fetcher.setLeaderState(partition, MockFetcherThread.PartitionState(leaderEpoch = 0))
+    fetcher.mockLeader.setLeaderState(partition, PartitionState(leaderEpoch = 0))
+    fetcher.mockLeader.setReplicaPartitionStateCallback(fetcher.replicaPartitionState)
 
     fetcher.start()
 
@@ -104,12 +105,13 @@ class AbstractFetcherThreadTest {
   @Test
   def testConsumerLagRemovedWithPartition(): Unit = {
     val partition = new TopicPartition("topic", 0)
-    val fetcher = new MockFetcherThread
+    val fetcher = new MockFetcherThread(new MockLeaderEndPoint)
 
     // add one partition to create the consumer lag metric
-    fetcher.setReplicaState(partition, MockFetcherThread.PartitionState(leaderEpoch = 0))
+    fetcher.setReplicaState(partition, PartitionState(leaderEpoch = 0))
     fetcher.addPartitions(Map(partition -> initialFetchState(topicIds.get(partition.topic), 0L, leaderEpoch = 0)))
-    fetcher.setLeaderState(partition, MockFetcherThread.PartitionState(leaderEpoch = 0))
+    fetcher.mockLeader.setLeaderState(partition, PartitionState(leaderEpoch = 0))
+    fetcher.mockLeader.setReplicaPartitionStateCallback(fetcher.replicaPartitionState)
 
     fetcher.doWork()
 
@@ -126,15 +128,16 @@ class AbstractFetcherThreadTest {
   @Test
   def testSimpleFetch(): Unit = {
     val partition = new TopicPartition("topic", 0)
-    val fetcher = new MockFetcherThread
+    val fetcher = new MockFetcherThread(new MockLeaderEndPoint)
 
-    fetcher.setReplicaState(partition, MockFetcherThread.PartitionState(leaderEpoch = 0))
+    fetcher.setReplicaState(partition, PartitionState(leaderEpoch = 0))
     fetcher.addPartitions(Map(partition -> initialFetchState(topicIds.get(partition.topic), 0L, leaderEpoch = 0)))
 
     val batch = mkBatch(baseOffset = 0L, leaderEpoch = 0,
       new SimpleRecord("a".getBytes), new SimpleRecord("b".getBytes))
-    val leaderState = MockFetcherThread.PartitionState(Seq(batch), leaderEpoch = 0, highWatermark = 2L)
-    fetcher.setLeaderState(partition, leaderState)
+    val leaderState = PartitionState(Seq(batch), leaderEpoch = 0, highWatermark = 2L)
+    fetcher.mockLeader.setLeaderState(partition, leaderState)
+    fetcher.mockLeader.setReplicaPartitionStateCallback(fetcher.replicaPartitionState)
 
     fetcher.doWork()
 
@@ -148,19 +151,20 @@ class AbstractFetcherThreadTest {
     val partition = new TopicPartition("topic", 0)
     val fetchBackOffMs = 250
 
-    val fetcher = new MockFetcherThread(fetchBackOffMs = fetchBackOffMs) {
-      override def fetchFromLeader(fetchRequest: FetchRequest.Builder): Map[TopicPartition, FetchData] = {
+    val fetcher = new MockFetcherThread(new MockLeaderEndPoint {
+      override def fetch(fetchRequest: FetchRequest.Builder): Map[TopicPartition, FetchData] = {
         throw new UnknownTopicIdException("Topic ID was unknown as expected for this test")
       }
-    }
+    }, fetchBackOffMs = fetchBackOffMs)
 
-    fetcher.setReplicaState(partition, MockFetcherThread.PartitionState(leaderEpoch = 0))
+    fetcher.setReplicaState(partition, PartitionState(leaderEpoch = 0))
     fetcher.addPartitions(Map(partition -> initialFetchState(Some(Uuid.randomUuid()), 0L, leaderEpoch = 0)))
 
     val batch = mkBatch(baseOffset = 0L, leaderEpoch = 0,
       new SimpleRecord("a".getBytes), new SimpleRecord("b".getBytes))
-    val leaderState = MockFetcherThread.PartitionState(Seq(batch), leaderEpoch = 0, highWatermark = 2L)
-    fetcher.setLeaderState(partition, leaderState)
+    val leaderState = PartitionState(Seq(batch), leaderEpoch = 0, highWatermark = 2L)
+    fetcher.mockLeader.setLeaderState(partition, leaderState)
+    fetcher.mockLeader.setReplicaPartitionStateCallback(fetcher.replicaPartitionState)
 
     // Do work for the first time. This should result in all partitions in error.
     val timeBeforeFirst = System.currentTimeMillis()
@@ -187,27 +191,28 @@ class AbstractFetcherThreadTest {
     val partition3 = new TopicPartition("topic3", 0)
     val fetchBackOffMs = 250
 
-    val fetcher = new MockFetcherThread(fetchBackOffMs = fetchBackOffMs) {
-      override def fetchFromLeader(fetchRequest: FetchRequest.Builder): Map[TopicPartition, FetchData] = {
+    val fetcher = new MockFetcherThread(new MockLeaderEndPoint {
+      override def fetch(fetchRequest: FetchRequest.Builder): Map[TopicPartition, FetchData] = {
         Map(partition1 -> new FetchData().setErrorCode(Errors.UNKNOWN_TOPIC_ID.code),
           partition2 -> new FetchData().setErrorCode(Errors.INCONSISTENT_TOPIC_ID.code),
           partition3 -> new FetchData().setErrorCode(Errors.NONE.code))
       }
-    }
+    }, fetchBackOffMs = fetchBackOffMs)
 
-    fetcher.setReplicaState(partition1, MockFetcherThread.PartitionState(leaderEpoch = 0))
+    fetcher.setReplicaState(partition1, PartitionState(leaderEpoch = 0))
     fetcher.addPartitions(Map(partition1 -> initialFetchState(Some(Uuid.randomUuid()), 0L, leaderEpoch = 0)))
-    fetcher.setReplicaState(partition2, MockFetcherThread.PartitionState(leaderEpoch = 0))
+    fetcher.setReplicaState(partition2, PartitionState(leaderEpoch = 0))
     fetcher.addPartitions(Map(partition2 -> initialFetchState(Some(Uuid.randomUuid()), 0L, leaderEpoch = 0)))
-    fetcher.setReplicaState(partition3, MockFetcherThread.PartitionState(leaderEpoch = 0))
+    fetcher.setReplicaState(partition3, PartitionState(leaderEpoch = 0))
     fetcher.addPartitions(Map(partition3 -> initialFetchState(Some(Uuid.randomUuid()), 0L, leaderEpoch = 0)))
 
     val batch = mkBatch(baseOffset = 0L, leaderEpoch = 0,
       new SimpleRecord("a".getBytes), new SimpleRecord("b".getBytes))
-    val leaderState = MockFetcherThread.PartitionState(Seq(batch), leaderEpoch = 0, highWatermark = 2L)
-    fetcher.setLeaderState(partition1, leaderState)
-    fetcher.setLeaderState(partition2, leaderState)
-    fetcher.setLeaderState(partition3, leaderState)
+    val leaderState = PartitionState(Seq(batch), leaderEpoch = 0, highWatermark = 2L)
+    fetcher.mockLeader.setLeaderState(partition1, leaderState)
+    fetcher.mockLeader.setLeaderState(partition2, leaderState)
+    fetcher.mockLeader.setLeaderState(partition3, leaderState)
+    fetcher.mockLeader.setReplicaPartitionStateCallback(fetcher.replicaPartitionState)
 
     fetcher.doWork()
 
@@ -227,15 +232,16 @@ class AbstractFetcherThreadTest {
   @Test
   def testFencedTruncation(): Unit = {
     val partition = new TopicPartition("topic", 0)
-    val fetcher = new MockFetcherThread
+    val fetcher = new MockFetcherThread(new MockLeaderEndPoint)
 
-    fetcher.setReplicaState(partition, MockFetcherThread.PartitionState(leaderEpoch = 0))
+    fetcher.setReplicaState(partition, PartitionState(leaderEpoch = 0))
     fetcher.addPartitions(Map(partition -> initialFetchState(topicIds.get(partition.topic), 0L, leaderEpoch = 0)))
 
     val batch = mkBatch(baseOffset = 0L, leaderEpoch = 1,
       new SimpleRecord("a".getBytes), new SimpleRecord("b".getBytes))
-    val leaderState = MockFetcherThread.PartitionState(Seq(batch), leaderEpoch = 1, highWatermark = 2L)
-    fetcher.setLeaderState(partition, leaderState)
+    val leaderState = PartitionState(Seq(batch), leaderEpoch = 1, highWatermark = 2L)
+    fetcher.mockLeader.setLeaderState(partition, leaderState)
+    fetcher.mockLeader.setReplicaPartitionStateCallback(fetcher.replicaPartitionState)
 
     fetcher.doWork()
 
@@ -252,17 +258,18 @@ class AbstractFetcherThreadTest {
   @Test
   def testFencedFetch(): Unit = {
     val partition = new TopicPartition("topic", 0)
-    val fetcher = new MockFetcherThread
+    val fetcher = new MockFetcherThread(new MockLeaderEndPoint)
 
-    val replicaState = MockFetcherThread.PartitionState(leaderEpoch = 0)
+    val replicaState = PartitionState(leaderEpoch = 0)
     fetcher.setReplicaState(partition, replicaState)
     fetcher.addPartitions(Map(partition -> initialFetchState(topicIds.get(partition.topic), 0L, leaderEpoch = 0)))
 
     val batch = mkBatch(baseOffset = 0L, leaderEpoch = 0,
       new SimpleRecord("a".getBytes),
       new SimpleRecord("b".getBytes))
-    val leaderState = MockFetcherThread.PartitionState(Seq(batch), leaderEpoch = 0, highWatermark = 2L)
-    fetcher.setLeaderState(partition, leaderState)
+    val leaderState = PartitionState(Seq(batch), leaderEpoch = 0, highWatermark = 2L)
+    fetcher.mockLeader.setLeaderState(partition, leaderState)
+    fetcher.mockLeader.setReplicaPartitionStateCallback(fetcher.replicaPartitionState)
 
     fetcher.doWork()
 
@@ -270,7 +277,7 @@ class AbstractFetcherThreadTest {
     assertEquals(2, replicaState.logEndOffset)
 
     // Bump the epoch on the leader
-    fetcher.leaderPartitionState(partition).leaderEpoch += 1
+    fetcher.mockLeader.leaderPartitionState(partition).leaderEpoch += 1
 
     fetcher.doWork()
 
@@ -282,16 +289,17 @@ class AbstractFetcherThreadTest {
   @Test
   def testUnknownLeaderEpochInTruncation(): Unit = {
     val partition = new TopicPartition("topic", 0)
-    val fetcher = new MockFetcherThread
+    val fetcher = new MockFetcherThread(new MockLeaderEndPoint)
 
     // The replica's leader epoch is ahead of the leader
-    val replicaState = MockFetcherThread.PartitionState(leaderEpoch = 1)
+    val replicaState = PartitionState(leaderEpoch = 1)
     fetcher.setReplicaState(partition, replicaState)
     fetcher.addPartitions(Map(partition -> initialFetchState(topicIds.get(partition.topic), 0L, leaderEpoch = 1)), forceTruncation = true)
 
     val batch = mkBatch(baseOffset = 0L, leaderEpoch = 0, new SimpleRecord("a".getBytes))
-    val leaderState = MockFetcherThread.PartitionState(Seq(batch), leaderEpoch = 0, highWatermark = 2L)
-    fetcher.setLeaderState(partition, leaderState)
+    val leaderState = PartitionState(Seq(batch), leaderEpoch = 0, highWatermark = 2L)
+    fetcher.mockLeader.setLeaderState(partition, leaderState)
+    fetcher.mockLeader.setReplicaPartitionStateCallback(fetcher.replicaPartitionState)
 
     fetcher.doWork()
 
@@ -300,7 +308,7 @@ class AbstractFetcherThreadTest {
     assertEquals(Some(Truncating), fetcher.fetchState(partition).map(_.state))
 
     // Bump the epoch on the leader
-    fetcher.leaderPartitionState(partition).leaderEpoch += 1
+    fetcher.mockLeader.leaderPartitionState(partition).leaderEpoch += 1
 
     // Now we can make progress
     fetcher.doWork()
@@ -312,21 +320,22 @@ class AbstractFetcherThreadTest {
   @Test
   def testUnknownLeaderEpochWhileFetching(): Unit = {
     val partition = new TopicPartition("topic", 0)
-    val fetcher = new MockFetcherThread
+    val fetcher = new MockFetcherThread(new MockLeaderEndPoint)
 
     // This test is contrived because it shouldn't be possible to see unknown leader epoch
     // in the Fetching state as the leader must validate the follower's epoch when it checks
     // the truncation offset.
 
-    val replicaState = MockFetcherThread.PartitionState(leaderEpoch = 1)
+    val replicaState = PartitionState(leaderEpoch = 1)
     fetcher.setReplicaState(partition, replicaState)
     fetcher.addPartitions(Map(partition -> initialFetchState(topicIds.get(partition.topic), 0L, leaderEpoch = 1)))
 
-    val leaderState = MockFetcherThread.PartitionState(Seq(
+    val leaderState = PartitionState(Seq(
       mkBatch(baseOffset = 0L, leaderEpoch = 0, new SimpleRecord("a".getBytes)),
       mkBatch(baseOffset = 1L, leaderEpoch = 0, new SimpleRecord("b".getBytes))
     ), leaderEpoch = 1, highWatermark = 2L)
-    fetcher.setLeaderState(partition, leaderState)
+    fetcher.mockLeader.setLeaderState(partition, leaderState)
+    fetcher.mockLeader.setReplicaPartitionStateCallback(fetcher.replicaPartitionState)
 
     fetcher.doWork()
 
@@ -335,7 +344,7 @@ class AbstractFetcherThreadTest {
     assertEquals(Some(Fetching), fetcher.fetchState(partition).map(_.state))
 
     // Somehow the leader epoch rewinds
-    fetcher.leaderPartitionState(partition).leaderEpoch = 0
+    fetcher.mockLeader.leaderPartitionState(partition).leaderEpoch = 0
 
     // We are stuck at the current offset
     fetcher.doWork()
@@ -343,7 +352,7 @@ class AbstractFetcherThreadTest {
     assertEquals(Some(Fetching), fetcher.fetchState(partition).map(_.state))
 
     // After returning to the right epoch, we can continue fetching
-    fetcher.leaderPartitionState(partition).leaderEpoch = 1
+    fetcher.mockLeader.leaderPartitionState(partition).leaderEpoch = 1
     fetcher.doWork()
     assertEquals(2, replicaState.logEndOffset)
     assertEquals(Some(Fetching), fetcher.fetchState(partition).map(_.state))
@@ -352,14 +361,14 @@ class AbstractFetcherThreadTest {
   @Test
   def testTruncation(): Unit = {
     val partition = new TopicPartition("topic", 0)
-    val fetcher = new MockFetcherThread
+    val fetcher = new MockFetcherThread(new MockLeaderEndPoint)
 
     val replicaLog = Seq(
       mkBatch(baseOffset = 0, leaderEpoch = 0, new SimpleRecord("a".getBytes)),
       mkBatch(baseOffset = 1, leaderEpoch = 2, new SimpleRecord("b".getBytes)),
       mkBatch(baseOffset = 2, leaderEpoch = 4, new SimpleRecord("c".getBytes)))
 
-    val replicaState = MockFetcherThread.PartitionState(replicaLog, leaderEpoch = 5, highWatermark = 0L)
+    val replicaState = PartitionState(replicaLog, leaderEpoch = 5, highWatermark = 0L)
     fetcher.setReplicaState(partition, replicaState)
     fetcher.addPartitions(Map(partition -> initialFetchState(topicIds.get(partition.topic), 3L, leaderEpoch = 5)))
 
@@ -368,12 +377,13 @@ class AbstractFetcherThreadTest {
       mkBatch(baseOffset = 1, leaderEpoch = 3, new SimpleRecord("b".getBytes)),
       mkBatch(baseOffset = 2, leaderEpoch = 5, new SimpleRecord("c".getBytes)))
 
-    val leaderState = MockFetcherThread.PartitionState(leaderLog, leaderEpoch = 5, highWatermark = 2L)
-    fetcher.setLeaderState(partition, leaderState)
+    val leaderState = PartitionState(leaderLog, leaderEpoch = 5, highWatermark = 2L)
+    fetcher.mockLeader.setLeaderState(partition, leaderState)
+    fetcher.mockLeader.setReplicaPartitionStateCallback(fetcher.replicaPartitionState)
 
     TestUtils.waitUntilTrue(() => {
       fetcher.doWork()
-      fetcher.replicaPartitionState(partition).log == fetcher.leaderPartitionState(partition).log
+      fetcher.replicaPartitionState(partition).log == fetcher.mockLeader.leaderPartitionState(partition).log
     }, "Failed to reconcile leader and follower logs")
 
     assertEquals(leaderState.logStartOffset, replicaState.logStartOffset)
@@ -385,29 +395,28 @@ class AbstractFetcherThreadTest {
   def testTruncateToHighWatermarkIfLeaderEpochRequestNotSupported(): Unit = {
     val highWatermark = 2L
     val partition = new TopicPartition("topic", 0)
-    val fetcher = new MockFetcherThread {
-      override def truncate(topicPartition: TopicPartition, truncationState: OffsetTruncationState): Unit = {
-        assertEquals(highWatermark, truncationState.offset)
-        assertTrue(truncationState.truncationCompleted)
-        super.truncate(topicPartition, truncationState)
+    val fetcher = new MockFetcherThread(new MockLeaderEndPoint {
+        override def fetchEpochEndOffsets(partitions: Map[TopicPartition, EpochData]): Map[TopicPartition, EpochEndOffset] =
+          throw new UnsupportedOperationException
+        override val isTruncationOnFetchSupported: Boolean = false
+    }) {
+        override def truncate(topicPartition: TopicPartition, truncationState: OffsetTruncationState): Unit = {
+          assertEquals(highWatermark, truncationState.offset)
+          assertTrue(truncationState.truncationCompleted)
+          super.truncate(topicPartition, truncationState)
+        }
+        override protected val isOffsetForLeaderEpochSupported: Boolean = false
       }
 
-      override def fetchEpochEndOffsets(partitions: Map[TopicPartition, EpochData]): Map[TopicPartition, EpochEndOffset] =
-        throw new UnsupportedOperationException
-
-      override protected val isOffsetForLeaderEpochSupported: Boolean = false
-
-      override protected val isTruncationOnFetchSupported: Boolean = false
-    }
-
     val replicaLog = Seq(
       mkBatch(baseOffset = 0, leaderEpoch = 0, new SimpleRecord("a".getBytes)),
       mkBatch(baseOffset = 1, leaderEpoch = 2, new SimpleRecord("b".getBytes)),
       mkBatch(baseOffset = 2, leaderEpoch = 4, new SimpleRecord("c".getBytes)))
 
-    val replicaState = MockFetcherThread.PartitionState(replicaLog, leaderEpoch = 5, highWatermark)
+    val replicaState = PartitionState(replicaLog, leaderEpoch = 5, highWatermark)
     fetcher.setReplicaState(partition, replicaState)
     fetcher.addPartitions(Map(partition -> initialFetchState(topicIds.get(partition.topic), highWatermark, leaderEpoch = 5)))
+    fetcher.mockLeader.setReplicaPartitionStateCallback(fetcher.replicaPartitionState)
 
     fetcher.doWork()
 
@@ -420,27 +429,28 @@ class AbstractFetcherThreadTest {
   def testTruncateToHighWatermarkIfLeaderEpochInfoNotAvailable(): Unit = {
     val highWatermark = 2L
     val partition = new TopicPartition("topic", 0)
-    val fetcher = new MockFetcherThread {
-      override def truncate(topicPartition: TopicPartition, truncationState: OffsetTruncationState): Unit = {
-        assertEquals(highWatermark, truncationState.offset)
-        assertTrue(truncationState.truncationCompleted)
-        super.truncate(topicPartition, truncationState)
-      }
-
-      override def fetchEpochEndOffsets(partitions: Map[TopicPartition, EpochData]): Map[TopicPartition, EpochEndOffset] =
-        throw new UnsupportedOperationException
+    val fetcher = new MockFetcherThread(new MockLeaderEndPoint {
+        override def fetchEpochEndOffsets(partitions: Map[TopicPartition, EpochData]): Map[TopicPartition, EpochEndOffset] =
+          throw new UnsupportedOperationException
+      }) {
+        override def truncate(topicPartition: TopicPartition, truncationState: OffsetTruncationState): Unit = {
+          assertEquals(highWatermark, truncationState.offset)
+          assertTrue(truncationState.truncationCompleted)
+          super.truncate(topicPartition, truncationState)
+        }
 
-      override def latestEpoch(topicPartition: TopicPartition): Option[Int] = None
-    }
+        override def latestEpoch(topicPartition: TopicPartition): Option[Int] = None
+      }
 
     val replicaLog = Seq(
       mkBatch(baseOffset = 0, leaderEpoch = 0, new SimpleRecord("a".getBytes)),
       mkBatch(baseOffset = 1, leaderEpoch = 2, new SimpleRecord("b".getBytes)),
       mkBatch(baseOffset = 2, leaderEpoch = 4, new SimpleRecord("c".getBytes)))
 
-    val replicaState = MockFetcherThread.PartitionState(replicaLog, leaderEpoch = 5, highWatermark)
+    val replicaState = PartitionState(replicaLog, leaderEpoch = 5, highWatermark)
     fetcher.setReplicaState(partition, replicaState)
     fetcher.addPartitions(Map(partition -> initialFetchState(topicIds.get(partition.topic), highWatermark, leaderEpoch = 5)))
+    fetcher.mockLeader.setReplicaPartitionStateCallback(fetcher.replicaPartitionState)
 
     fetcher.doWork()
 
@@ -453,7 +463,7 @@ class AbstractFetcherThreadTest {
   def testTruncateToHighWatermarkDuringRemovePartitions(): Unit = {
     val highWatermark = 2L
     val partition = new TopicPartition("topic", 0)
-    val fetcher = new MockFetcherThread {
+    val fetcher = new MockFetcherThread(new MockLeaderEndPoint) {
       override def truncateToHighWatermark(partitions: Set[TopicPartition]): Unit = {
         removePartitions(Set(partition))
         super.truncateToHighWatermark(partitions)
@@ -467,9 +477,10 @@ class AbstractFetcherThreadTest {
       mkBatch(baseOffset = 1, leaderEpoch = 2, new SimpleRecord("b".getBytes)),
       mkBatch(baseOffset = 2, leaderEpoch = 4, new SimpleRecord("c".getBytes)))
 
-    val replicaState = MockFetcherThread.PartitionState(replicaLog, leaderEpoch = 5, highWatermark)
+    val replicaState = PartitionState(replicaLog, leaderEpoch = 5, highWatermark)
     fetcher.setReplicaState(partition, replicaState)
     fetcher.addPartitions(Map(partition -> initialFetchState(topicIds.get(partition.topic), highWatermark, leaderEpoch = 5)))
+    fetcher.mockLeader.setReplicaPartitionStateCallback(fetcher.replicaPartitionState)
 
     fetcher.doWork()
 
@@ -482,14 +493,14 @@ class AbstractFetcherThreadTest {
     val partition = new TopicPartition("topic", 0)
 
     var truncations = 0
-    val fetcher = new MockFetcherThread {
+    val fetcher = new MockFetcherThread(new MockLeaderEndPoint) {
       override def truncate(topicPartition: TopicPartition, truncationState: OffsetTruncationState): Unit = {
         truncations += 1
         super.truncate(topicPartition, truncationState)
       }
     }
 
-    val replicaState = MockFetcherThread.PartitionState(leaderEpoch = 5)
+    val replicaState = PartitionState(leaderEpoch = 5)
     fetcher.setReplicaState(partition, replicaState)
     fetcher.addPartitions(Map(partition -> initialFetchState(topicIds.get(partition.topic), 0L, leaderEpoch = 5)), forceTruncation = true)
 
@@ -498,8 +509,9 @@ class AbstractFetcherThreadTest {
       mkBatch(baseOffset = 1, leaderEpoch = 3, new SimpleRecord("b".getBytes)),
       mkBatch(baseOffset = 2, leaderEpoch = 5, new SimpleRecord("c".getBytes)))
 
-    val leaderState = MockFetcherThread.PartitionState(leaderLog, leaderEpoch = 5, highWatermark = 2L)
-    fetcher.setLeaderState(partition, leaderState)
+    val leaderState = PartitionState(leaderLog, leaderEpoch = 5, highWatermark = 2L)
+    fetcher.mockLeader.setLeaderState(partition, leaderState)
+    fetcher.mockLeader.setReplicaPartitionStateCallback(fetcher.replicaPartitionState)
 
     // Do one round of truncation
     fetcher.doWork()
@@ -524,7 +536,7 @@ class AbstractFetcherThreadTest {
     assumeTrue(truncateOnFetch)
     val partition = new TopicPartition("topic", 0)
     var truncations = 0
-    val fetcher = new MockFetcherThread {
+    val fetcher = new MockFetcherThread(new MockLeaderEndPoint) {
       override def truncate(topicPartition: TopicPartition, truncationState: OffsetTruncationState): Unit = {
         truncations += 1
         super.truncate(topicPartition, truncationState)
@@ -535,7 +547,7 @@ class AbstractFetcherThreadTest {
       mkBatch(baseOffset = 1, leaderEpoch = 2, new SimpleRecord("b".getBytes)),
       mkBatch(baseOffset = 2, leaderEpoch = 4, new SimpleRecord("c".getBytes)))
 
-    val replicaState = MockFetcherThread.PartitionState(replicaLog, leaderEpoch = 5, highWatermark = 2L)
+    val replicaState = PartitionState(replicaLog, leaderEpoch = 5, highWatermark = 2L)
     fetcher.setReplicaState(partition, replicaState)
 
     // Verify that truncation based on fetch response is performed if partition is owned by fetcher thread
@@ -564,14 +576,14 @@ class AbstractFetcherThreadTest {
   @Test
   def testFollowerFetchOutOfRangeHigh(): Unit = {
     val partition = new TopicPartition("topic", 0)
-    val fetcher = new MockFetcherThread()
+    val fetcher = new MockFetcherThread(new MockLeaderEndPoint)
 
     val replicaLog = Seq(
       mkBatch(baseOffset = 0, leaderEpoch = 0, new SimpleRecord("a".getBytes)),
       mkBatch(baseOffset = 1, leaderEpoch = 2, new SimpleRecord("b".getBytes)),
       mkBatch(baseOffset = 2, leaderEpoch = 4, new SimpleRecord("c".getBytes)))
 
-    val replicaState = MockFetcherThread.PartitionState(replicaLog, leaderEpoch = 4, highWatermark = 0L)
+    val replicaState = PartitionState(replicaLog, leaderEpoch = 4, highWatermark = 0L)
     fetcher.setReplicaState(partition, replicaState)
     fetcher.addPartitions(Map(partition -> initialFetchState(topicIds.get(partition.topic), 3L, leaderEpoch = 4)))
 
@@ -580,8 +592,9 @@ class AbstractFetcherThreadTest {
       mkBatch(baseOffset = 1, leaderEpoch = 2, new SimpleRecord("b".getBytes)),
       mkBatch(baseOffset = 2, leaderEpoch = 4, new SimpleRecord("c".getBytes)))
 
-    val leaderState = MockFetcherThread.PartitionState(leaderLog, leaderEpoch = 4, highWatermark = 2L)
-    fetcher.setLeaderState(partition, leaderState)
+    val leaderState = PartitionState(leaderLog, leaderEpoch = 4, highWatermark = 2L)
+    fetcher.mockLeader.setLeaderState(partition, leaderState)
+    fetcher.mockLeader.setReplicaPartitionStateCallback(fetcher.replicaPartitionState)
 
     // initial truncation and verify that the log end offset is updated
     fetcher.doWork()
@@ -605,23 +618,24 @@ class AbstractFetcherThreadTest {
   def testFencedOffsetResetAfterOutOfRange(): Unit = {
     val partition = new TopicPartition("topic", 0)
     var fetchedEarliestOffset = false
-    val fetcher = new MockFetcherThread() {
-      override protected def fetchEarliestOffsetFromLeader(topicPartition: TopicPartition, leaderEpoch: Int): Long = {
+    val fetcher = new MockFetcherThread(new MockLeaderEndPoint {
+      override def fetchEarliestOffset(topicPartition: TopicPartition, leaderEpoch: Int): Long = {
         fetchedEarliestOffset = true
         throw new FencedLeaderEpochException(s"Epoch $leaderEpoch is fenced")
       }
-    }
+    })
 
     val replicaLog = Seq()
-    val replicaState = MockFetcherThread.PartitionState(replicaLog, leaderEpoch = 4, highWatermark = 0L)
+    val replicaState = PartitionState(replicaLog, leaderEpoch = 4, highWatermark = 0L)
     fetcher.setReplicaState(partition, replicaState)
     fetcher.addPartitions(Map(partition -> initialFetchState(topicIds.get(partition.topic), 0L, leaderEpoch = 4)))
 
     val leaderLog = Seq(
       mkBatch(baseOffset = 1, leaderEpoch = 2, new SimpleRecord("b".getBytes)),
       mkBatch(baseOffset = 2, leaderEpoch = 4, new SimpleRecord("c".getBytes)))
-    val leaderState = MockFetcherThread.PartitionState(leaderLog, leaderEpoch = 4, highWatermark = 2L)
-    fetcher.setLeaderState(partition, leaderState)
+    val leaderState = PartitionState(leaderLog, leaderEpoch = 4, highWatermark = 2L)
+    fetcher.mockLeader.setLeaderState(partition, leaderState)
+    fetcher.mockLeader.setReplicaPartitionStateCallback(fetcher.replicaPartitionState)
 
     // After the out of range error, we get a fenced error and remove the partition and mark as failed
     fetcher.doWork()
@@ -634,21 +648,22 @@ class AbstractFetcherThreadTest {
   @Test
   def testFollowerFetchOutOfRangeLow(): Unit = {
     val partition = new TopicPartition("topic", 0)
-    val fetcher = new MockFetcherThread
+    val fetcher = new MockFetcherThread(new MockLeaderEndPoint)
 
     // The follower begins from an offset which is behind the leader's log start offset
     val replicaLog = Seq(
       mkBatch(baseOffset = 0, leaderEpoch = 0, new SimpleRecord("a".getBytes)))
 
-    val replicaState = MockFetcherThread.PartitionState(replicaLog, leaderEpoch = 0, highWatermark = 0L)
+    val replicaState = PartitionState(replicaLog, leaderEpoch = 0, highWatermark = 0L)
     fetcher.setReplicaState(partition, replicaState)
     fetcher.addPartitions(Map(partition -> initialFetchState(topicIds.get(partition.topic), 3L, leaderEpoch = 0)))
 
     val leaderLog = Seq(
       mkBatch(baseOffset = 2, leaderEpoch = 4, new SimpleRecord("c".getBytes)))
 
-    val leaderState = MockFetcherThread.PartitionState(leaderLog, leaderEpoch = 0, highWatermark = 2L)
-    fetcher.setLeaderState(partition, leaderState)
+    val leaderState = PartitionState(leaderLog, leaderEpoch = 0, highWatermark = 2L)
+    fetcher.mockLeader.setLeaderState(partition, leaderState)
+    fetcher.mockLeader.setReplicaPartitionStateCallback(fetcher.replicaPartitionState)
 
     // initial truncation and verify that the log start offset is updated
     fetcher.doWork()
@@ -663,7 +678,7 @@ class AbstractFetcherThreadTest {
 
     TestUtils.waitUntilTrue(() => {
       fetcher.doWork()
-      fetcher.replicaPartitionState(partition).log == fetcher.leaderPartitionState(partition).log
+      fetcher.replicaPartitionState(partition).log == fetcher.mockLeader.leaderPartitionState(partition).log
     }, "Failed to reconcile leader and follower logs")
 
     assertEquals(leaderState.logStartOffset, replicaState.logStartOffset)
@@ -674,28 +689,29 @@ class AbstractFetcherThreadTest {
   @Test
   def testRetryAfterUnknownLeaderEpochInLatestOffsetFetch(): Unit = {
     val partition = new TopicPartition("topic", 0)
-    val fetcher: MockFetcherThread = new MockFetcherThread {
+    val fetcher: MockFetcherThread = new MockFetcherThread(new MockLeaderEndPoint {
       val tries = new AtomicInteger(0)
-      override protected def fetchLatestOffsetFromLeader(topicPartition: TopicPartition, leaderEpoch: Int): Long = {
+      override def fetchLatestOffset(topicPartition: TopicPartition, leaderEpoch: Int): Long = {
         if (tries.getAndIncrement() == 0)
           throw new UnknownLeaderEpochException("Unexpected leader epoch")
-        super.fetchLatestOffsetFromLeader(topicPartition, leaderEpoch)
+        super.fetchLatestOffset(topicPartition, leaderEpoch)
       }
-    }
+    })
 
     // The follower begins from an offset which is behind the leader's log start offset
     val replicaLog = Seq(
       mkBatch(baseOffset = 0, leaderEpoch = 0, new SimpleRecord("a".getBytes)))
 
-    val replicaState = MockFetcherThread.PartitionState(replicaLog, leaderEpoch = 0, highWatermark = 0L)
+    val replicaState = PartitionState(replicaLog, leaderEpoch = 0, highWatermark = 0L)
     fetcher.setReplicaState(partition, replicaState)
     fetcher.addPartitions(Map(partition -> initialFetchState(topicIds.get(partition.topic), 3L, leaderEpoch = 0)))
 
     val leaderLog = Seq(
       mkBatch(baseOffset = 2, leaderEpoch = 4, new SimpleRecord("c".getBytes)))
 
-    val leaderState = MockFetcherThread.PartitionState(leaderLog, leaderEpoch = 0, highWatermark = 2L)
-    fetcher.setLeaderState(partition, leaderState)
+    val leaderState = PartitionState(leaderLog, leaderEpoch = 0, highWatermark = 2L)
+    fetcher.mockLeader.setLeaderState(partition, leaderState)
+    fetcher.mockLeader.setReplicaPartitionStateCallback(fetcher.replicaPartitionState)
 
     // initial truncation and initial error response handling
     fetcher.doWork()
@@ -703,7 +719,7 @@ class AbstractFetcherThreadTest {
 
     TestUtils.waitUntilTrue(() => {
       fetcher.doWork()
-      fetcher.replicaPartitionState(partition).log == fetcher.leaderPartitionState(partition).log
+      fetcher.replicaPartitionState(partition).log == fetcher.mockLeader.leaderPartitionState(partition).log
     }, "Failed to reconcile leader and follower logs")
 
     assertEquals(leaderState.logStartOffset, replicaState.logStartOffset)
@@ -715,10 +731,10 @@ class AbstractFetcherThreadTest {
   def testCorruptMessage(): Unit = {
     val partition = new TopicPartition("topic", 0)
 
-    val fetcher = new MockFetcherThread {
+    val fetcher = new MockFetcherThread(new MockLeaderEndPoint {
       var fetchedOnce = false
-      override def fetchFromLeader(fetchRequest: FetchRequest.Builder): Map[TopicPartition, FetchData] = {
-        val fetchedData = super.fetchFromLeader(fetchRequest)
+      override def fetch(fetchRequest: FetchRequest.Builder): Map[TopicPartition, FetchData] = {
+        val fetchedData = super.fetch(fetchRequest)
         if (!fetchedOnce) {
           val records = fetchedData.head._2.records.asInstanceOf[MemoryRecords]
           val buffer = records.buffer()
@@ -728,15 +744,16 @@ class AbstractFetcherThreadTest {
         }
         fetchedData
       }
-    }
+    })
 
-    fetcher.setReplicaState(partition, MockFetcherThread.PartitionState(leaderEpoch = 0))
+    fetcher.setReplicaState(partition, PartitionState(leaderEpoch = 0))
     fetcher.addPartitions(Map(partition -> initialFetchState(topicIds.get(partition.topic), 0L, leaderEpoch = 0)))
+    fetcher.mockLeader.setReplicaPartitionStateCallback(fetcher.replicaPartitionState)
 
     val batch = mkBatch(baseOffset = 0L, leaderEpoch = 0,
       new SimpleRecord("a".getBytes), new SimpleRecord("b".getBytes))
-    val leaderState = MockFetcherThread.PartitionState(Seq(batch), leaderEpoch = 0, highWatermark = 2L)
-    fetcher.setLeaderState(partition, leaderState)
+    val leaderState = PartitionState(Seq(batch), leaderEpoch = 0, highWatermark = 2L)
+    fetcher.mockLeader.setLeaderState(partition, leaderState)
 
     fetcher.doWork() // fails with corrupt record
     fetcher.doWork() // should succeed
@@ -768,28 +785,33 @@ class AbstractFetcherThreadTest {
     val initialLeaderEpochOnFollower = 0
     val nextLeaderEpochOnFollower = initialLeaderEpochOnFollower + 1
 
-    val fetcher = new MockFetcherThread {
+    val fetcher = new MockFetcherThread(new MockLeaderEndPoint {
       var fetchEpochsFromLeaderOnce = false
       override def fetchEpochEndOffsets(partitions: Map[TopicPartition, EpochData]): Map[TopicPartition, EpochEndOffset] = {
         val fetchedEpochs = super.fetchEpochEndOffsets(partitions)
         if (!fetchEpochsFromLeaderOnce) {
-          // leader epoch changes while fetching epochs from leader
-          removePartitions(Set(partition))
-          setReplicaState(partition, MockFetcherThread.PartitionState(leaderEpoch = nextLeaderEpochOnFollower))
-          addPartitions(Map(partition -> initialFetchState(topicIds.get(partition.topic), 0L, leaderEpoch = nextLeaderEpochOnFollower)), forceTruncation = true)
+          responseCallback.apply()
           fetchEpochsFromLeaderOnce = true
         }
         fetchedEpochs
       }
+    })
+
+    def changeLeaderEpochWhileFetchEpoch(): Unit = {
+      fetcher.removePartitions(Set(partition))
+      fetcher.setReplicaState(partition, PartitionState(leaderEpoch = nextLeaderEpochOnFollower))
+      fetcher.addPartitions(Map(partition -> initialFetchState(topicIds.get(partition.topic), 0L, leaderEpoch = nextLeaderEpochOnFollower)), forceTruncation = true)
     }
 
-    fetcher.setReplicaState(partition, MockFetcherThread.PartitionState(leaderEpoch = initialLeaderEpochOnFollower))
+    fetcher.setReplicaState(partition, PartitionState(leaderEpoch = initialLeaderEpochOnFollower))
     fetcher.addPartitions(Map(partition -> initialFetchState(topicIds.get(partition.topic), 0L, leaderEpoch = initialLeaderEpochOnFollower)), forceTruncation = true)
 
     val leaderLog = Seq(
       mkBatch(baseOffset = 0, leaderEpoch = initialLeaderEpochOnFollower, new SimpleRecord("c".getBytes)))
-    val leaderState = MockFetcherThread.PartitionState(leaderLog, leaderEpochOnLeader, highWatermark = 0L)
-    fetcher.setLeaderState(partition, leaderState)
+    val leaderState = PartitionState(leaderLog, leaderEpochOnLeader, highWatermark = 0L)
+    fetcher.mockLeader.setLeaderState(partition, leaderState)
+    fetcher.mockLeader.setResponseCallback(changeLeaderEpochWhileFetchEpoch)
+    fetcher.mockLeader.setReplicaPartitionStateCallback(fetcher.replicaPartitionState)
 
     // first round of truncation
     fetcher.doWork()
@@ -800,13 +822,13 @@ class AbstractFetcherThreadTest {
     assertEquals(Option(nextLeaderEpochOnFollower), fetcher.fetchState(partition).map(_.currentLeaderEpoch))
 
     if (leaderEpochOnLeader < nextLeaderEpochOnFollower) {
-      fetcher.setLeaderState(
-        partition, MockFetcherThread.PartitionState(leaderLog, nextLeaderEpochOnFollower, highWatermark = 0L))
+      fetcher.mockLeader.setLeaderState(
+        partition, PartitionState(leaderLog, nextLeaderEpochOnFollower, highWatermark = 0L))
     }
 
     // make sure the fetcher is now able to truncate and fetch
     fetcher.doWork()
-    assertEquals(fetcher.leaderPartitionState(partition).log, fetcher.replicaPartitionState(partition).log)
+    assertEquals(fetcher.mockLeader.leaderPartitionState(partition).log, fetcher.replicaPartitionState(partition).log)
   }
 
   @Test
@@ -816,24 +838,30 @@ class AbstractFetcherThreadTest {
     val initialLeaderEpochOnFollower = 0
     val nextLeaderEpochOnFollower = initialLeaderEpochOnFollower + 1
 
-    val fetcher = new MockFetcherThread {
+    val fetcher = new MockFetcherThread(new MockLeaderEndPoint {
       override def fetchEpochEndOffsets(partitions: Map[TopicPartition, EpochData]): Map[TopicPartition, EpochEndOffset] = {
         val fetchedEpochs = super.fetchEpochEndOffsets(partitions)
-        // leader epoch changes while fetching epochs from leader
-        // at the same time, the replica fetcher manager removes the partition
-        removePartitions(Set(partition))
-        setReplicaState(partition, MockFetcherThread.PartitionState(leaderEpoch = nextLeaderEpochOnFollower))
+        responseCallback.apply()
         fetchedEpochs
       }
+    })
+
+    def changeLeaderEpochDuringFetchEpoch(): Unit = {
+      // leader epoch changes while fetching epochs from leader
+      // at the same time, the replica fetcher manager removes the partition
+      fetcher.removePartitions(Set(partition))
+      fetcher.setReplicaState(partition, PartitionState(leaderEpoch = nextLeaderEpochOnFollower))
     }
 
-    fetcher.setReplicaState(partition, MockFetcherThread.PartitionState(leaderEpoch = initialLeaderEpochOnFollower))
+    fetcher.setReplicaState(partition, PartitionState(leaderEpoch = initialLeaderEpochOnFollower))
     fetcher.addPartitions(Map(partition -> initialFetchState(topicIds.get(partition.topic), 0L, leaderEpoch = initialLeaderEpochOnFollower)))
 
     val leaderLog = Seq(
       mkBatch(baseOffset = 0, leaderEpoch = initialLeaderEpochOnFollower, new SimpleRecord("c".getBytes)))
-    val leaderState = MockFetcherThread.PartitionState(leaderLog, leaderEpochOnLeader, highWatermark = 0L)
-    fetcher.setLeaderState(partition, leaderState)
+    val leaderState = PartitionState(leaderLog, leaderEpochOnLeader, highWatermark = 0L)
+    fetcher.mockLeader.setLeaderState(partition, leaderState)
+    fetcher.mockLeader.setResponseCallback(changeLeaderEpochDuringFetchEpoch)
+    fetcher.mockLeader.setReplicaPartitionStateCallback(fetcher.replicaPartitionState)
 
     // first round of work
     fetcher.doWork()
@@ -843,8 +871,8 @@ class AbstractFetcherThreadTest {
     assertEquals(None, fetcher.fetchState(partition).map(_.state))
     assertEquals(None, fetcher.fetchState(partition).map(_.currentLeaderEpoch))
 
-    fetcher.setLeaderState(
-      partition, MockFetcherThread.PartitionState(leaderLog, nextLeaderEpochOnFollower, highWatermark = 0L))
+    fetcher.mockLeader.setLeaderState(
+      partition, PartitionState(leaderLog, nextLeaderEpochOnFollower, highWatermark = 0L))
 
     // make sure the fetcher is able to continue work
     fetcher.doWork()
@@ -854,7 +882,7 @@ class AbstractFetcherThreadTest {
   @Test
   def testTruncationThrowsExceptionIfLeaderReturnsPartitionsNotRequestedInFetchEpochs(): Unit = {
     val partition = new TopicPartition("topic", 0)
-    val fetcher = new MockFetcherThread {
+    val fetcher = new MockFetcherThread(new MockLeaderEndPoint {
       override def fetchEpochEndOffsets(partitions: Map[TopicPartition, EpochData]): Map[TopicPartition, EpochEndOffset] = {
         val unrequestedTp = new TopicPartition("topic2", 0)
         super.fetchEpochEndOffsets(partitions).toMap + (unrequestedTp -> new EpochEndOffset()
@@ -863,11 +891,12 @@ class AbstractFetcherThreadTest {
           .setLeaderEpoch(0)
           .setEndOffset(0))
       }
-    }
+    })
 
-    fetcher.setReplicaState(partition, MockFetcherThread.PartitionState(leaderEpoch = 0))
+    fetcher.setReplicaState(partition, PartitionState(leaderEpoch = 0))
     fetcher.addPartitions(Map(partition -> initialFetchState(topicIds.get(partition.topic), 0L, leaderEpoch = 0)), forceTruncation = true)
-    fetcher.setLeaderState(partition, MockFetcherThread.PartitionState(leaderEpoch = 0))
+    fetcher.mockLeader.setLeaderState(partition, PartitionState(leaderEpoch = 0))
+    fetcher.mockLeader.setReplicaPartitionStateCallback(fetcher.replicaPartitionState)
 
     // first round of truncation should throw an exception
     assertThrows(classOf[IllegalStateException], () => fetcher.doWork())
@@ -875,7 +904,7 @@ class AbstractFetcherThreadTest {
 
   @Test
   def testFetcherThreadHandlingPartitionFailureDuringAppending(): Unit = {
-    val fetcherForAppend = new MockFetcherThread {
+    val fetcherForAppend = new MockFetcherThread(new MockLeaderEndPoint) {
       override def processPartitionData(topicPartition: TopicPartition, fetchOffset: Long, partitionData: FetchData): Option[LogAppendInfo] = {
         if (topicPartition == partition1) {
           throw new KafkaException()
@@ -889,7 +918,7 @@ class AbstractFetcherThreadTest {
 
   @Test
   def testFetcherThreadHandlingPartitionFailureDuringTruncation(): Unit = {
-    val fetcherForTruncation = new MockFetcherThread {
+    val fetcherForTruncation = new MockFetcherThread(new MockLeaderEndPoint) {
       override def truncate(topicPartition: TopicPartition, truncationState: OffsetTruncationState): Unit = {
         if(topicPartition == partition1)
           throw new Exception()
@@ -903,13 +932,14 @@ class AbstractFetcherThreadTest {
 
   private def verifyFetcherThreadHandlingPartitionFailure(fetcher: MockFetcherThread): Unit = {
 
-    fetcher.setReplicaState(partition1, MockFetcherThread.PartitionState(leaderEpoch = 0))
+    fetcher.setReplicaState(partition1, PartitionState(leaderEpoch = 0))
     fetcher.addPartitions(Map(partition1 -> initialFetchState(topicIds.get(partition1.topic), 0L, leaderEpoch = 0)), forceTruncation = true)
-    fetcher.setLeaderState(partition1, MockFetcherThread.PartitionState(leaderEpoch = 0))
+    fetcher.mockLeader.setLeaderState(partition1, PartitionState(leaderEpoch = 0))
 
-    fetcher.setReplicaState(partition2, MockFetcherThread.PartitionState(leaderEpoch = 0))
+    fetcher.setReplicaState(partition2, PartitionState(leaderEpoch = 0))
     fetcher.addPartitions(Map(partition2 -> initialFetchState(topicIds.get(partition2.topic), 0L, leaderEpoch = 0)), forceTruncation = true)
-    fetcher.setLeaderState(partition2, MockFetcherThread.PartitionState(leaderEpoch = 0))
+    fetcher.mockLeader.setLeaderState(partition2, PartitionState(leaderEpoch = 0))
+    fetcher.mockLeader.setReplicaPartitionStateCallback(fetcher.replicaPartitionState)
 
     // processing data fails for partition1
     fetcher.doWork()
@@ -937,14 +967,14 @@ class AbstractFetcherThreadTest {
   @Test
   def testDivergingEpochs(): Unit = {
     val partition = new TopicPartition("topic", 0)
-    val fetcher = new MockFetcherThread
+    val fetcher = new MockFetcherThread(new MockLeaderEndPoint)
 
     val replicaLog = Seq(
       mkBatch(baseOffset = 0, leaderEpoch = 0, new SimpleRecord("a".getBytes)),
       mkBatch(baseOffset = 1, leaderEpoch = 2, new SimpleRecord("b".getBytes)),
       mkBatch(baseOffset = 2, leaderEpoch = 4, new SimpleRecord("c".getBytes)))
 
-    val replicaState = MockFetcherThread.PartitionState(replicaLog, leaderEpoch = 5, highWatermark = 0L)
+    val replicaState = PartitionState(replicaLog, leaderEpoch = 5, highWatermark = 0L)
     fetcher.setReplicaState(partition, replicaState)
     fetcher.addPartitions(Map(partition -> initialFetchState(topicIds.get(partition.topic), 3L, leaderEpoch = 5)))
     assertEquals(3L, replicaState.logEndOffset)
@@ -955,15 +985,16 @@ class AbstractFetcherThreadTest {
       mkBatch(baseOffset = 1, leaderEpoch = 2, new SimpleRecord("b".getBytes)),
       mkBatch(baseOffset = 2, leaderEpoch = 5, new SimpleRecord("d".getBytes)))
 
-    val leaderState = MockFetcherThread.PartitionState(leaderLog, leaderEpoch = 5, highWatermark = 2L)
-    fetcher.setLeaderState(partition, leaderState)
+    val leaderState = PartitionState(leaderLog, leaderEpoch = 5, highWatermark = 2L)
+    fetcher.mockLeader.setLeaderState(partition, leaderState)
+    fetcher.mockLeader.setReplicaPartitionStateCallback(fetcher.replicaPartitionState)
 
     fetcher.doWork()
     fetcher.verifyLastFetchedEpoch(partition, Some(2))
 
     TestUtils.waitUntilTrue(() => {
       fetcher.doWork()
-      fetcher.replicaPartitionState(partition).log == fetcher.leaderPartitionState(partition).log
+      fetcher.replicaPartitionState(partition).log == fetcher.mockLeader.leaderPartitionState(partition).log
     }, "Failed to reconcile leader and follower logs")
     fetcher.verifyLastFetchedEpoch(partition, Some(5))
   }
@@ -971,10 +1002,10 @@ class AbstractFetcherThreadTest {
   @Test
   def testMaybeUpdateTopicIds(): Unit = {
     val partition = new TopicPartition("topic1", 0)
-    val fetcher = new MockFetcherThread
+    val fetcher = new MockFetcherThread(new MockLeaderEndPoint)
 
     // Start with no topic IDs
-    fetcher.setReplicaState(partition, MockFetcherThread.PartitionState(leaderEpoch = 0))
+    fetcher.setReplicaState(partition, PartitionState(leaderEpoch = 0))
     fetcher.addPartitions(Map(partition -> initialFetchState(None, 0L, leaderEpoch = 0)))
 
     def verifyFetchState(fetchState: Option[PartitionFetchState], expectedTopicId: Option[Uuid]): Unit = {
@@ -994,152 +1025,120 @@ class AbstractFetcherThreadTest {
     assertTrue(fetcher.fetchState(unknownPartition).isEmpty)
   }
 
-  object MockFetcherThread {
-    class PartitionState(var log: mutable.Buffer[RecordBatch],
-                         var leaderEpoch: Int,
-                         var logStartOffset: Long,
-                         var logEndOffset: Long,
-                         var highWatermark: Long)
-
-    object PartitionState {
-      def apply(log: Seq[RecordBatch], leaderEpoch: Int, highWatermark: Long): PartitionState = {
-        val logStartOffset = log.headOption.map(_.baseOffset).getOrElse(0L)
-        val logEndOffset = log.lastOption.map(_.nextOffset).getOrElse(0L)
-        new PartitionState(log.toBuffer, leaderEpoch, logStartOffset, logEndOffset, highWatermark)
-      }
+  class MockLeaderEndPoint(sourceBroker: BrokerEndPoint = new BrokerEndPoint(1, host = "localhost", port = Random.nextInt()))
+    extends LeaderEndPoint {
 
-      def apply(leaderEpoch: Int): PartitionState = {
-        apply(Seq(), leaderEpoch = leaderEpoch, highWatermark = 0L)
-      }
-    }
-  }
+    private val leaderPartitionStates = mutable.Map[TopicPartition, PartitionState]()
+    var responseCallback: () => Unit = () => {}
 
-  class MockFetcherThread(val replicaId: Int = 0, val leaderId: Int = 1, fetchBackOffMs: Int = 0)
-    extends AbstractFetcherThread("mock-fetcher",
-      clientId = "mock-fetcher",
-      sourceBroker = new BrokerEndPoint(leaderId, host = "localhost", port = Random.nextInt()),
-      failedPartitions,
-      fetchBackOffMs = fetchBackOffMs,
-      brokerTopicStats = new BrokerTopicStats) {
+    var replicaPartitionStateCallback: TopicPartition => Option[PartitionState] = { _ => Option.empty }
+    var replicaId: Int = 0
 
-    import MockFetcherThread.PartitionState
+    override val isTruncationOnFetchSupported: Boolean = truncateOnFetch
 
-    private val replicaPartitionStates = mutable.Map[TopicPartition, PartitionState]()
-    private val leaderPartitionStates = mutable.Map[TopicPartition, PartitionState]()
-    private var latestEpochDefault: Option[Int] = Some(0)
+    def leaderPartitionState(topicPartition: TopicPartition): PartitionState = {
+      leaderPartitionStates.getOrElse(topicPartition,
+        throw new IllegalArgumentException(s"Unknown partition $topicPartition"))
+    }
 
     def setLeaderState(topicPartition: TopicPartition, state: PartitionState): Unit = {
       leaderPartitionStates.put(topicPartition, state)
     }
 
-    def setReplicaState(topicPartition: TopicPartition, state: PartitionState): Unit = {
-      replicaPartitionStates.put(topicPartition, state)
+    def setResponseCallback(callback: () => Unit): Unit = {
+      responseCallback = callback
     }
 
-    def replicaPartitionState(topicPartition: TopicPartition): PartitionState = {
-      replicaPartitionStates.getOrElse(topicPartition,
-        throw new IllegalArgumentException(s"Unknown partition $topicPartition"))
+    def setReplicaPartitionStateCallback(callback: TopicPartition => PartitionState): Unit = {
+      replicaPartitionStateCallback = topicPartition => Some(callback(topicPartition))
     }
 
-    def leaderPartitionState(topicPartition: TopicPartition): PartitionState = {
-      leaderPartitionStates.getOrElse(topicPartition,
-        throw new IllegalArgumentException(s"Unknown partition $topicPartition"))
+    def setReplicaId(replicaId: Int): Unit = {
+      this.replicaId = replicaId
     }
 
-    def addPartitions(initialFetchStates: Map[TopicPartition, InitialFetchState], forceTruncation: Boolean): Set[TopicPartition] = {
-      latestEpochDefault = if (forceTruncation) None else Some(0)
-      val partitions = super.addPartitions(initialFetchStates)
-      latestEpochDefault = Some(0)
-      partitions
-    }
+    override def initiateClose(): Unit = {}
 
-    override def processPartitionData(topicPartition: TopicPartition,
-                                      fetchOffset: Long,
-                                      partitionData: FetchData): Option[LogAppendInfo] = {
-      val state = replicaPartitionState(topicPartition)
+    override def close(): Unit = {}
 
-      if (isTruncationOnFetchSupported && FetchResponse.isDivergingEpoch(partitionData)) {
-        val divergingEpoch = partitionData.divergingEpoch
-        truncateOnFetchResponse(Map(topicPartition -> new EpochEndOffset()
-          .setPartition(topicPartition.partition)
-          .setErrorCode(Errors.NONE.code)
-          .setLeaderEpoch(divergingEpoch.epoch)
-          .setEndOffset(divergingEpoch.endOffset)))
-        return None
-      }
+    override def brokerEndPoint(): BrokerEndPoint = sourceBroker
 
-      // Throw exception if the fetchOffset does not match the fetcherThread partition state
-      if (fetchOffset != state.logEndOffset)
-        throw new RuntimeException(s"Offset mismatch for partition $topicPartition: " +
-          s"fetched offset = $fetchOffset, log end offset = ${state.logEndOffset}.")
+    override def fetch(fetchRequest: FetchRequest.Builder): Map[TopicPartition, FetchData] = {
+      fetchRequest.fetchData.asScala.map { case (partition, fetchData) =>
+        val leaderState = leaderPartitionState(partition)
+        val epochCheckError = checkExpectedLeaderEpoch(fetchData.currentLeaderEpoch, leaderState)
+        val divergingEpoch = divergingEpochAndOffset(partition, fetchData.lastFetchedEpoch, fetchData.fetchOffset, leaderState)
 
-      // Now check message's crc
-      val batches = FetchResponse.recordsOrFail(partitionData).batches.asScala
-      var maxTimestamp = RecordBatch.NO_TIMESTAMP
-      var offsetOfMaxTimestamp = -1L
-      var lastOffset = state.logEndOffset
-      var lastEpoch: Option[Int] = None
+        val (error, records) = if (epochCheckError.isDefined) {
+          (epochCheckError.get, MemoryRecords.EMPTY)
+        } else if (fetchData.fetchOffset > leaderState.logEndOffset || fetchData.fetchOffset < leaderState.logStartOffset) {
+          (Errors.OFFSET_OUT_OF_RANGE, MemoryRecords.EMPTY)
+        } else if (divergingEpoch.nonEmpty) {
+          (Errors.NONE, MemoryRecords.EMPTY)
+        } else {
+          // for simplicity, we fetch only one batch at a time
+          val records = leaderState.log.find(_.baseOffset >= fetchData.fetchOffset) match {
+            case Some(batch) =>
+              val buffer = ByteBuffer.allocate(batch.sizeInBytes)
+              batch.writeTo(buffer)
+              buffer.flip()
+              MemoryRecords.readableRecords(buffer)
 
-      for (batch <- batches) {
-        batch.ensureValid()
-        if (batch.maxTimestamp > maxTimestamp) {
-          maxTimestamp = batch.maxTimestamp
-          offsetOfMaxTimestamp = batch.baseOffset
+            case None =>
+              MemoryRecords.EMPTY
+          }
+
+          (Errors.NONE, records)
         }
-        state.log.append(batch)
-        state.logEndOffset = batch.nextOffset
-        lastOffset = batch.lastOffset
-        lastEpoch = Some(batch.partitionLeaderEpoch)
-      }
+        val partitionData = new FetchData()
+          .setPartitionIndex(partition.partition)
+          .setErrorCode(error.code)
+          .setHighWatermark(leaderState.highWatermark)
+          .setLastStableOffset(leaderState.highWatermark)
+          .setLogStartOffset(leaderState.logStartOffset)
+          .setRecords(records)
+        divergingEpoch.foreach(partitionData.setDivergingEpoch)
 
-      state.logStartOffset = partitionData.logStartOffset
-      state.highWatermark = partitionData.highWatermark
+        (partition, partitionData)
+      }.toMap
+    }
 
-      Some(LogAppendInfo(firstOffset = Some(LogOffsetMetadata(fetchOffset)),
-        lastOffset = lastOffset,
-        lastLeaderEpoch = lastEpoch,
-        maxTimestamp = maxTimestamp,
-        offsetOfMaxTimestamp = offsetOfMaxTimestamp,
-        logAppendTime = Time.SYSTEM.milliseconds(),
-        logStartOffset = state.logStartOffset,
-        recordConversionStats = RecordConversionStats.EMPTY,
-        sourceCodec = NoCompressionCodec,
-        targetCodec = NoCompressionCodec,
-        shallowCount = batches.size,
-        validBytes = FetchResponse.recordsSize(partitionData),
-        offsetsMonotonic = true,
-        lastOffsetOfFirstBatch = batches.headOption.map(_.lastOffset).getOrElse(-1)))
+    override def fetchEarliestOffset(topicPartition: TopicPartition, leaderEpoch: Int): Long = {
+      val leaderState = leaderPartitionState(topicPartition)
+      checkLeaderEpochAndThrow(leaderEpoch, leaderState)
+      leaderState.logStartOffset
     }
 
-    override def truncate(topicPartition: TopicPartition, truncationState: OffsetTruncationState): Unit = {
-      val state = replicaPartitionState(topicPartition)
-      state.log = state.log.takeWhile { batch =>
-        batch.lastOffset < truncationState.offset
-      }
-      state.logEndOffset = state.log.lastOption.map(_.lastOffset + 1).getOrElse(state.logStartOffset)
-      state.highWatermark = math.min(state.highWatermark, state.logEndOffset)
+    override def fetchLatestOffset(topicPartition: TopicPartition, leaderEpoch: Int): Long = {
+      val leaderState = leaderPartitionState(topicPartition)
+      checkLeaderEpochAndThrow(leaderEpoch, leaderState)
+      leaderState.logEndOffset
     }
 
-    override def truncateFullyAndStartAt(topicPartition: TopicPartition, offset: Long): Unit = {
-      val state = replicaPartitionState(topicPartition)
-      state.log.clear()
-      state.logStartOffset = offset
-      state.logEndOffset = offset
-      state.highWatermark = offset
+    override def fetchEpochEndOffsets(partitions: Map[TopicPartition, EpochData]): Map[TopicPartition, EpochEndOffset] = {
+      val endOffsets = mutable.Map[TopicPartition, EpochEndOffset]()
+      partitions.forKeyValue { (partition, epochData) =>
+        assert(partition.partition == epochData.partition,
+          "Partition must be consistent between TopicPartition and EpochData")
+        val leaderState = leaderPartitionState(partition)
+        val epochEndOffset = lookupEndOffsetForEpoch(partition, epochData, leaderState)
+        endOffsets.put(partition, epochEndOffset)
+      }
+      endOffsets
     }
 
     override def buildFetch(partitionMap: Map[TopicPartition, PartitionFetchState]): ResultWithPartitions[Option[ReplicaFetch]] = {
       val fetchData = mutable.Map.empty[TopicPartition, FetchRequest.PartitionData]
       partitionMap.foreach { case (partition, state) =>
         if (state.isReadyForFetch) {
-          val replicaState = replicaPartitionState(partition)
+          val replicaState = replicaPartitionStateCallback(partition).getOrElse(throw new IllegalArgumentException(s"Unknown partition $partition"))
           val lastFetchedEpoch = if (isTruncationOnFetchSupported)
             state.lastFetchedEpoch.map(_.asInstanceOf[Integer]).asJava
           else
             Optional.empty[Integer]
           fetchData.put(partition,
             new FetchRequest.PartitionData(state.topicId.getOrElse(Uuid.ZERO_UUID), state.fetchOffset, replicaState.logStartOffset,
-            1024 * 1024, Optional.of[Integer](state.currentLeaderEpoch), lastFetchedEpoch))
+              1024 * 1024, Optional.of[Integer](state.currentLeaderEpoch), lastFetchedEpoch))
         }
       }
       val fetchRequest = FetchRequest.Builder.forReplica(version, replicaId, 0, 1, fetchData.asJava)
@@ -1151,24 +1150,10 @@ class AbstractFetcherThreadTest {
       ResultWithPartitions(fetchRequestOpt, Set.empty)
     }
 
-    override def latestEpoch(topicPartition: TopicPartition): Option[Int] = {
-      val state = replicaPartitionState(topicPartition)
-      state.log.lastOption.map(_.partitionLeaderEpoch).orElse(latestEpochDefault)
-    }
-
-    override def logStartOffset(topicPartition: TopicPartition): Long = replicaPartitionState(topicPartition).logStartOffset
-
-    override def logEndOffset(topicPartition: TopicPartition): Long = replicaPartitionState(topicPartition).logEndOffset
-
-    override def endOffsetForEpoch(topicPartition: TopicPartition, epoch: Int): Option[OffsetAndEpoch] = {
-      val epochData = new EpochData()
-        .setPartition(topicPartition.partition)
-        .setLeaderEpoch(epoch)
-      val result = lookupEndOffsetForEpoch(topicPartition, epochData, replicaPartitionState(topicPartition))
-      if (result.endOffset == UNDEFINED_EPOCH_OFFSET)
-        None
-      else
-        Some(OffsetAndEpoch(result.endOffset, result.leaderEpoch))
+    private def checkLeaderEpochAndThrow(expectedEpoch: Int, partitionState: PartitionState): Unit = {
+      checkExpectedLeaderEpoch(expectedEpoch, partitionState).foreach { error =>
+        throw error.exception()
+      }
     }
 
     private def checkExpectedLeaderEpoch(expectedEpochOpt: Optional[Integer],
@@ -1194,13 +1179,6 @@ class AbstractFetcherThreadTest {
       }
     }
 
-    def verifyLastFetchedEpoch(partition: TopicPartition, expectedEpoch: Option[Int]): Unit = {
-      if (isTruncationOnFetchSupported) {
-        assertEquals(Some(Fetching), fetchState(partition).map(_.state))
-        assertEquals(expectedEpoch, fetchState(partition).flatMap(_.lastFetchedEpoch))
-      }
-    }
-
     private def divergingEpochAndOffset(topicPartition: TopicPartition,
                                         lastFetchedEpoch: Optional[Integer],
                                         fetchOffset: Long,
@@ -1212,8 +1190,8 @@ class AbstractFetcherThreadTest {
             .setLeaderEpoch(fetchEpoch)))(topicPartition)
 
         if (partitionState.log.isEmpty
-            || epochEndOffset.endOffset == UNDEFINED_EPOCH_OFFSET
-            || epochEndOffset.leaderEpoch == UNDEFINED_EPOCH)
+          || epochEndOffset.endOffset == UNDEFINED_EPOCH_OFFSET
+          || epochEndOffset.leaderEpoch == UNDEFINED_EPOCH)
           None
         else if (epochEndOffset.leaderEpoch < fetchEpoch || epochEndOffset.endOffset < fetchOffset) {
           Some(new FetchResponseData.EpochEndOffset()
@@ -1224,7 +1202,7 @@ class AbstractFetcherThreadTest {
       }
     }
 
-    private def lookupEndOffsetForEpoch(topicPartition: TopicPartition,
+    def lookupEndOffsetForEpoch(topicPartition: TopicPartition,
                                         epochData: EpochData,
                                         partitionState: PartitionState): EpochEndOffset = {
       checkExpectedLeaderEpoch(epochData.currentLeaderEpoch, partitionState).foreach { error =>
@@ -1256,81 +1234,156 @@ class AbstractFetcherThreadTest {
         .setPartition(topicPartition.partition)
         .setErrorCode(Errors.NONE.code)
     }
+  }
 
-    override def fetchEpochEndOffsets(partitions: Map[TopicPartition, EpochData]): Map[TopicPartition, EpochEndOffset] = {
-      val endOffsets = mutable.Map[TopicPartition, EpochEndOffset]()
-      partitions.forKeyValue { (partition, epochData) =>
-        assert(partition.partition == epochData.partition,
-          "Partition must be consistent between TopicPartition and EpochData")
-        val leaderState = leaderPartitionState(partition)
-        val epochEndOffset = lookupEndOffsetForEpoch(partition, epochData, leaderState)
-        endOffsets.put(partition, epochEndOffset)
-      }
-      endOffsets
+  class PartitionState(var log: mutable.Buffer[RecordBatch],
+                       var leaderEpoch: Int,
+                       var logStartOffset: Long,
+                       var logEndOffset: Long,
+                       var highWatermark: Long)
+
+  object PartitionState {
+    def apply(log: Seq[RecordBatch], leaderEpoch: Int, highWatermark: Long): PartitionState = {
+      val logStartOffset = log.headOption.map(_.baseOffset).getOrElse(0L)
+      val logEndOffset = log.lastOption.map(_.nextOffset).getOrElse(0L)
+      new PartitionState(log.toBuffer, leaderEpoch, logStartOffset, logEndOffset, highWatermark)
     }
 
-    override protected val isOffsetForLeaderEpochSupported: Boolean = true
+    def apply(leaderEpoch: Int): PartitionState = {
+      apply(Seq(), leaderEpoch = leaderEpoch, highWatermark = 0L)
+    }
+  }
 
-    override protected val isTruncationOnFetchSupported: Boolean = truncateOnFetch
+  class MockFetcherThread(val mockLeader: MockLeaderEndPoint, val replicaId: Int = 0, val leaderId: Int = 1, fetchBackOffMs: Int = 0)
+    extends AbstractFetcherThread("mock-fetcher",
+      clientId = "mock-fetcher",
+      leader = mockLeader,
+      failedPartitions,
+      fetchBackOffMs = fetchBackOffMs,
+      brokerTopicStats = new BrokerTopicStats) {
 
-    override def fetchFromLeader(fetchRequest: FetchRequest.Builder): Map[TopicPartition, FetchData] = {
-      fetchRequest.fetchData.asScala.map { case (partition, fetchData) =>
-        val leaderState = leaderPartitionState(partition)
-        val epochCheckError = checkExpectedLeaderEpoch(fetchData.currentLeaderEpoch, leaderState)
-        val divergingEpoch = divergingEpochAndOffset(partition, fetchData.lastFetchedEpoch, fetchData.fetchOffset, leaderState)
+    private val replicaPartitionStates = mutable.Map[TopicPartition, PartitionState]()
+    private var latestEpochDefault: Option[Int] = Some(0)
 
-        val (error, records) = if (epochCheckError.isDefined) {
-          (epochCheckError.get, MemoryRecords.EMPTY)
-        } else if (fetchData.fetchOffset > leaderState.logEndOffset || fetchData.fetchOffset < leaderState.logStartOffset) {
-          (Errors.OFFSET_OUT_OF_RANGE, MemoryRecords.EMPTY)
-        } else if (divergingEpoch.nonEmpty) {
-          (Errors.NONE, MemoryRecords.EMPTY)
-        } else {
-          // for simplicity, we fetch only one batch at a time
-          val records = leaderState.log.find(_.baseOffset >= fetchData.fetchOffset) match {
-            case Some(batch) =>
-              val buffer = ByteBuffer.allocate(batch.sizeInBytes)
-              batch.writeTo(buffer)
-              buffer.flip()
-              MemoryRecords.readableRecords(buffer)
+    def setReplicaState(topicPartition: TopicPartition, state: PartitionState): Unit = {
+      replicaPartitionStates.put(topicPartition, state)
+    }
 
-            case None =>
-              MemoryRecords.EMPTY
-          }
+    def replicaPartitionState(topicPartition: TopicPartition): PartitionState = {
+      replicaPartitionStates.getOrElse(topicPartition,
+        throw new IllegalArgumentException(s"Unknown partition $topicPartition"))
+    }
 
-          (Errors.NONE, records)
+    def addPartitions(initialFetchStates: Map[TopicPartition, InitialFetchState], forceTruncation: Boolean): Set[TopicPartition] = {
+      latestEpochDefault = if (forceTruncation) None else Some(0)
+      val partitions = super.addPartitions(initialFetchStates)
+      latestEpochDefault = Some(0)
+      partitions
+    }
+
+    override def processPartitionData(topicPartition: TopicPartition,
+                                      fetchOffset: Long,
+                                      partitionData: FetchData): Option[LogAppendInfo] = {
+      val state = replicaPartitionState(topicPartition)
+
+      if (leader.isTruncationOnFetchSupported && FetchResponse.isDivergingEpoch(partitionData)) {
+        val divergingEpoch = partitionData.divergingEpoch
+        truncateOnFetchResponse(Map(topicPartition -> new EpochEndOffset()
+          .setPartition(topicPartition.partition)
+          .setErrorCode(Errors.NONE.code)
+          .setLeaderEpoch(divergingEpoch.epoch)
+          .setEndOffset(divergingEpoch.endOffset)))
+        return None
+      }
+
+      // Throw exception if the fetchOffset does not match the fetcherThread partition state
+      if (fetchOffset != state.logEndOffset)
+        throw new RuntimeException(s"Offset mismatch for partition $topicPartition: " +
+          s"fetched offset = $fetchOffset, log end offset = ${state.logEndOffset}.")
+
+      // Now check message's crc
+      val batches = FetchResponse.recordsOrFail(partitionData).batches.asScala
+      var maxTimestamp = RecordBatch.NO_TIMESTAMP
+      var offsetOfMaxTimestamp = -1L
+      var lastOffset = state.logEndOffset
+      var lastEpoch: Option[Int] = None
+
+      for (batch <- batches) {
+        batch.ensureValid()
+        if (batch.maxTimestamp > maxTimestamp) {
+          maxTimestamp = batch.maxTimestamp
+          offsetOfMaxTimestamp = batch.baseOffset
         }
-        val partitionData = new FetchData()
-          .setPartitionIndex(partition.partition)
-          .setErrorCode(error.code)
-          .setHighWatermark(leaderState.highWatermark)
-          .setLastStableOffset(leaderState.highWatermark)
-          .setLogStartOffset(leaderState.logStartOffset)
-          .setRecords(records)
-        divergingEpoch.foreach(partitionData.setDivergingEpoch)
+        state.log.append(batch)
+        state.logEndOffset = batch.nextOffset
+        lastOffset = batch.lastOffset
+        lastEpoch = Some(batch.partitionLeaderEpoch)
+      }
 
-        (partition, partitionData)
-      }.toMap
+      state.logStartOffset = partitionData.logStartOffset
+      state.highWatermark = partitionData.highWatermark
+
+      Some(LogAppendInfo(firstOffset = Some(LogOffsetMetadata(fetchOffset)),
+        lastOffset = lastOffset,
+        lastLeaderEpoch = lastEpoch,
+        maxTimestamp = maxTimestamp,
+        offsetOfMaxTimestamp = offsetOfMaxTimestamp,
+        logAppendTime = Time.SYSTEM.milliseconds(),
+        logStartOffset = state.logStartOffset,
+        recordConversionStats = RecordConversionStats.EMPTY,
+        sourceCodec = NoCompressionCodec,
+        targetCodec = NoCompressionCodec,
+        shallowCount = batches.size,
+        validBytes = FetchResponse.recordsSize(partitionData),
+        offsetsMonotonic = true,
+        lastOffsetOfFirstBatch = batches.headOption.map(_.lastOffset).getOrElse(-1)))
     }
 
-    private def checkLeaderEpochAndThrow(expectedEpoch: Int, partitionState: PartitionState): Unit = {
-      checkExpectedLeaderEpoch(expectedEpoch, partitionState).foreach { error =>
-        throw error.exception()
+    override def truncate(topicPartition: TopicPartition, truncationState: OffsetTruncationState): Unit = {
+      val state = replicaPartitionState(topicPartition)
+      state.log = state.log.takeWhile { batch =>
+        batch.lastOffset < truncationState.offset
       }
+      state.logEndOffset = state.log.lastOption.map(_.lastOffset + 1).getOrElse(state.logStartOffset)
+      state.highWatermark = math.min(state.highWatermark, state.logEndOffset)
     }
 
-    override protected def fetchEarliestOffsetFromLeader(topicPartition: TopicPartition, leaderEpoch: Int): Long = {
-      val leaderState = leaderPartitionState(topicPartition)
-      checkLeaderEpochAndThrow(leaderEpoch, leaderState)
-      leaderState.logStartOffset
+    override def truncateFullyAndStartAt(topicPartition: TopicPartition, offset: Long): Unit = {
+      val state = replicaPartitionState(topicPartition)
+      state.log.clear()
+      state.logStartOffset = offset
+      state.logEndOffset = offset
+      state.highWatermark = offset
     }
 
-    override protected def fetchLatestOffsetFromLeader(topicPartition: TopicPartition, leaderEpoch: Int): Long = {
-      val leaderState = leaderPartitionState(topicPartition)
-      checkLeaderEpochAndThrow(leaderEpoch, leaderState)
-      leaderState.logEndOffset
+    override def latestEpoch(topicPartition: TopicPartition): Option[Int] = {
+      val state = replicaPartitionState(topicPartition)
+      state.log.lastOption.map(_.partitionLeaderEpoch).orElse(latestEpochDefault)
+    }
+
+    override def logStartOffset(topicPartition: TopicPartition): Long = replicaPartitionState(topicPartition).logStartOffset
+
+    override def logEndOffset(topicPartition: TopicPartition): Long = replicaPartitionState(topicPartition).logEndOffset
+
+    override def endOffsetForEpoch(topicPartition: TopicPartition, epoch: Int): Option[OffsetAndEpoch] = {
+      val epochData = new EpochData()
+        .setPartition(topicPartition.partition)
+        .setLeaderEpoch(epoch)
+      val result = mockLeader.lookupEndOffsetForEpoch(topicPartition, epochData, replicaPartitionState(topicPartition))
+      if (result.endOffset == UNDEFINED_EPOCH_OFFSET)
+        None
+      else
+        Some(OffsetAndEpoch(result.endOffset, result.leaderEpoch))
     }
 
+    def verifyLastFetchedEpoch(partition: TopicPartition, expectedEpoch: Option[Int]): Unit = {
+      if (leader.isTruncationOnFetchSupported) {
+        assertEquals(Some(Fetching), fetchState(partition).map(_.state))
+        assertEquals(expectedEpoch, fetchState(partition).flatMap(_.lastFetchedEpoch))
+      }
+    }
+
+    override protected val isOffsetForLeaderEpochSupported: Boolean = true
   }
 
 }
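
(For orientation, a minimal sketch of the test wiring this patch introduces -- not taken verbatim
from the diff above, and assuming it runs inside AbstractFetcherThreadTest where the helpers used
below (initialFetchState, mkBatch, topicIds, PartitionState) are defined. Leader-side state now
lives behind the MockLeaderEndPoint passed into MockFetcherThread and is reached via
fetcher.mockLeader, while replica-side state stays on the fetcher itself.)

    val partition = new TopicPartition("topic", 0)
    val fetcher = new MockFetcherThread(new MockLeaderEndPoint)

    // Follower (replica) state is still configured on the fetcher thread.
    fetcher.setReplicaState(partition, PartitionState(leaderEpoch = 0))
    fetcher.addPartitions(Map(partition -> initialFetchState(topicIds.get(partition.topic), 0L, leaderEpoch = 0)))

    // Leader state and the replica-state callback now go through the mock leader endpoint.
    val batch = mkBatch(baseOffset = 0L, leaderEpoch = 0, new SimpleRecord("a".getBytes))
    fetcher.mockLeader.setLeaderState(partition, PartitionState(Seq(batch), leaderEpoch = 0, highWatermark = 1L))
    fetcher.mockLeader.setReplicaPartitionStateCallback(fetcher.replicaPartitionState)

    fetcher.doWork() // one fetch round: the batch is copied from the mock leader into the replica state
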
diff --git a/core/src/test/scala/unit/kafka/server/ReplicaAlterLogDirsThreadTest.scala b/core/src/test/scala/unit/kafka/server/ReplicaAlterLogDirsThreadTest.scala
index 93302f31ce..09939f43fd 100644
--- a/core/src/test/scala/unit/kafka/server/ReplicaAlterLogDirsThreadTest.scala
+++ b/core/src/test/scala/unit/kafka/server/ReplicaAlterLogDirsThreadTest.scala
@@ -16,7 +16,6 @@
   */
 package kafka.server
 
-import java.util.{Collections, Optional}
 import kafka.api.Request
 import kafka.cluster.{BrokerEndPoint, Partition}
 import kafka.log.{LogManager, UnifiedLog}
@@ -39,6 +38,7 @@ import org.mockito.ArgumentMatchers.{any, anyBoolean}
 import org.mockito.Mockito.{doNothing, mock, never, times, verify, when}
 import org.mockito.{ArgumentCaptor, ArgumentMatchers, Mockito}
 
+import java.util.{Collections, Optional}
 import scala.collection.{Map, Seq}
 import scala.jdk.CollectionConverters._
 
@@ -81,14 +81,15 @@ class ReplicaAlterLogDirsThreadTest {
     when(replicaManager.futureLogExists(t1p0)).thenReturn(false)
 
     val endPoint = new BrokerEndPoint(0, "localhost", 1000)
+    val leader = new LocalLeaderEndPoint(endPoint, config, replicaManager, quotaManager)
     val thread = new ReplicaAlterLogDirsThread(
       "alter-logs-dirs-thread",
-      sourceBroker = endPoint,
-      brokerConfig = config,
-      failedPartitions = failedPartitions,
-      replicaMgr = replicaManager,
-      quota = quotaManager,
-      brokerTopicStats = new BrokerTopicStats)
+      leader,
+      failedPartitions,
+      replicaManager,
+      quotaManager,
+      new BrokerTopicStats,
+      config.replicaFetchBackoffMs)
 
     val addedPartitions = thread.addPartitions(Map(t1p0 -> initialFetchState(0L)))
     assertEquals(Set.empty, addedPartitions)
@@ -148,14 +149,15 @@ class ReplicaAlterLogDirsThreadTest {
     mockFetchFromCurrentLog(tid1p0, fencedRequestData, config, replicaManager, fencedResponseData)
 
     val endPoint = new BrokerEndPoint(0, "localhost", 1000)
+    val leader = new LocalLeaderEndPoint(endPoint, config, replicaManager, quotaManager)
     val thread = new ReplicaAlterLogDirsThread(
-      "alter-logs-dirs-thread",
-      sourceBroker = endPoint,
-      brokerConfig = config,
-      failedPartitions = failedPartitions,
-      replicaMgr = replicaManager,
-      quota = quotaManager,
-      brokerTopicStats = new BrokerTopicStats)
+      "alter-log-dirs-thread",
+      leader,
+      failedPartitions,
+      replicaManager,
+      quotaManager,
+      new BrokerTopicStats,
+      config.replicaFetchBackoffMs)
 
     // Initially we add the partition with an older epoch which results in an error
     thread.addPartitions(Map(t1p0 -> initialFetchState(fetchOffset = 0L, leaderEpoch - 1)))
@@ -246,14 +248,15 @@ class ReplicaAlterLogDirsThreadTest {
     mockFetchFromCurrentLog(tid1p0, requestData, config, replicaManager, responseData)
 
     val endPoint = new BrokerEndPoint(0, "localhost", 1000)
+    val leader = new LocalLeaderEndPoint(endPoint, config, replicaManager, quotaManager)
     val thread = new ReplicaAlterLogDirsThread(
       "alter-logs-dirs-thread",
-      sourceBroker = endPoint,
-      brokerConfig = config,
-      failedPartitions = failedPartitions,
-      replicaMgr = replicaManager,
-      quota = quotaManager,
-      brokerTopicStats = new BrokerTopicStats)
+      leader,
+      failedPartitions,
+      replicaManager,
+      quotaManager,
+      new BrokerTopicStats,
+      config.replicaFetchBackoffMs)
 
     thread.addPartitions(Map(t1p0 -> initialFetchState(fetchOffset = 0L, leaderEpoch)))
     assertTrue(thread.fetchState(t1p0).isDefined)
@@ -336,16 +339,17 @@ class ReplicaAlterLogDirsThreadTest {
         .setEndOffset(leoT1p1))
 
     val endPoint = new BrokerEndPoint(0, "localhost", 1000)
+    val leader = new LocalLeaderEndPoint(endPoint, config, replicaManager, null)
     val thread = new ReplicaAlterLogDirsThread(
       "alter-logs-dirs-thread-test1",
-      sourceBroker = endPoint,
-      brokerConfig = config,
-      failedPartitions = failedPartitions,
-      replicaMgr = replicaManager,
-      quota = null,
-      brokerTopicStats = null)
-
-    val result = thread.fetchEpochEndOffsets(Map(
+      leader,
+      failedPartitions,
+      replicaManager,
+      null,
+      null,
+      config.replicaFetchBackoffMs)
+
+    val result = thread.leader.fetchEpochEndOffsets(Map(
       t1p0 -> new OffsetForLeaderPartition()
         .setPartition(t1p0.partition)
         .setLeaderEpoch(leaderEpochT1p0),
@@ -397,16 +401,17 @@ class ReplicaAlterLogDirsThreadTest {
       .thenThrow(new KafkaStorageException)
 
     val endPoint = new BrokerEndPoint(0, "localhost", 1000)
+    val leader = new LocalLeaderEndPoint(endPoint, config, replicaManager, null)
     val thread = new ReplicaAlterLogDirsThread(
       "alter-logs-dirs-thread-test1",
-      sourceBroker = endPoint,
-      brokerConfig = config,
-      failedPartitions = failedPartitions,
-      replicaMgr = replicaManager,
-      quota = null,
-      brokerTopicStats = null)
-
-    val result = thread.fetchEpochEndOffsets(Map(
+      leader,
+      failedPartitions,
+      replicaManager,
+      null,
+      null,
+      config.replicaFetchBackoffMs)
+
+    val result = thread.leader.fetchEpochEndOffsets(Map(
       t1p0 -> new OffsetForLeaderPartition()
         .setPartition(t1p0.partition)
         .setLeaderEpoch(leaderEpoch),
@@ -498,14 +503,15 @@ class ReplicaAlterLogDirsThreadTest {
 
     //Create the thread
     val endPoint = new BrokerEndPoint(0, "localhost", 1000)
+    val leader = new LocalLeaderEndPoint(endPoint, config, replicaManager, quotaManager)
     val thread = new ReplicaAlterLogDirsThread(
       "alter-logs-dirs-thread-test1",
-      sourceBroker = endPoint,
-      brokerConfig = config,
-      failedPartitions = failedPartitions,
-      replicaMgr = replicaManager,
-      quota = quotaManager,
-      brokerTopicStats = null)
+      leader,
+      failedPartitions,
+      replicaManager,
+      quotaManager,
+      null,
+      config.replicaFetchBackoffMs)
     thread.addPartitions(Map(t1p0 -> initialFetchState(0L), t1p1 -> initialFetchState(0L)))
 
     //Run it
@@ -581,14 +587,15 @@ class ReplicaAlterLogDirsThreadTest {
 
     //Create the thread
     val endPoint = new BrokerEndPoint(0, "localhost", 1000)
+    val leader = new LocalLeaderEndPoint(endPoint, config, replicaManager, quotaManager)
     val thread = new ReplicaAlterLogDirsThread(
       "alter-logs-dirs-thread-test1",
-      sourceBroker = endPoint,
-      brokerConfig = config,
-      failedPartitions = failedPartitions,
-      replicaMgr = replicaManager,
-      quota = quotaManager,
-      brokerTopicStats = null)
+      leader,
+      failedPartitions,
+      replicaManager,
+      quotaManager,
+      null,
+      config.replicaFetchBackoffMs)
     thread.addPartitions(Map(t1p0 -> initialFetchState(0L)))
 
     // First run will result in another offset for leader epoch request
@@ -636,14 +643,15 @@ class ReplicaAlterLogDirsThreadTest {
 
     //Create the thread
     val endPoint = new BrokerEndPoint(0, "localhost", 1000)
+    val leader = new LocalLeaderEndPoint(endPoint, config, replicaManager, quotaManager)
     val thread = new ReplicaAlterLogDirsThread(
       "alter-logs-dirs-thread-test1",
-      sourceBroker = endPoint,
-      brokerConfig = config,
-      failedPartitions = failedPartitions,
-      replicaMgr = replicaManager,
-      quota = quotaManager,
-      brokerTopicStats = null)
+      leader,
+      failedPartitions,
+      replicaManager,
+      quotaManager,
+      null,
+      config.replicaFetchBackoffMs)
     thread.addPartitions(Map(t1p0 -> initialFetchState(initialFetchOffset)))
 
     //Run it
@@ -718,14 +726,15 @@ class ReplicaAlterLogDirsThreadTest {
 
     //Create the thread
     val endPoint = new BrokerEndPoint(0, "localhost", 1000)
+    val leader = new LocalLeaderEndPoint(endPoint, config, replicaManager, quotaManager)
     val thread = new ReplicaAlterLogDirsThread(
       "alter-logs-dirs-thread-test1",
-      sourceBroker = endPoint,
-      brokerConfig = config,
-      failedPartitions = failedPartitions,
-      replicaMgr = replicaManager,
-      quota = quotaManager,
-      brokerTopicStats = null)
+      leader,
+      failedPartitions,
+      replicaManager,
+      quotaManager,
+      null,
+      config.replicaFetchBackoffMs)
     thread.addPartitions(Map(t1p0 -> initialFetchState(0L)))
 
     // Run thread 3 times (exactly number of times we mock exception for getReplicaOrException)
@@ -786,14 +795,15 @@ class ReplicaAlterLogDirsThreadTest {
 
     //Create the fetcher thread
     val endPoint = new BrokerEndPoint(0, "localhost", 1000)
+    val leader = new LocalLeaderEndPoint(endPoint, config, replicaManager, quotaManager)
     val thread = new ReplicaAlterLogDirsThread(
       "alter-logs-dirs-thread-test1",
-      sourceBroker = endPoint,
-      brokerConfig = config,
-      failedPartitions = failedPartitions,
-      replicaMgr = replicaManager,
-      quota = quotaManager,
-      brokerTopicStats = null)
+      leader,
+      failedPartitions,
+      replicaManager,
+      quotaManager,
+      null,
+      config.replicaFetchBackoffMs)
     thread.addPartitions(Map(t1p0 -> initialFetchState(0L)))
 
     // loop few times
@@ -826,19 +836,20 @@ class ReplicaAlterLogDirsThreadTest {
     //Create the fetcher thread
     val endPoint = new BrokerEndPoint(0, "localhost", 1000)
     val leaderEpoch = 1
+    val leader = new LocalLeaderEndPoint(endPoint, config, replicaManager, quotaManager)
     val thread = new ReplicaAlterLogDirsThread(
       "alter-logs-dirs-thread-test1",
-      sourceBroker = endPoint,
-      brokerConfig = config,
-      failedPartitions = failedPartitions,
-      replicaMgr = replicaManager,
-      quota = quotaManager,
-      brokerTopicStats = null)
+      leader,
+      failedPartitions,
+      replicaManager,
+      quotaManager,
+      null,
+      config.replicaFetchBackoffMs)
     thread.addPartitions(Map(
       t1p0 -> initialFetchState(0L, leaderEpoch),
       t1p1 -> initialFetchState(0L, leaderEpoch)))
 
-    val ResultWithPartitions(fetchRequestOpt, partitionsWithError) = thread.buildFetch(Map(
+    val ResultWithPartitions(fetchRequestOpt, partitionsWithError) = thread.leader.buildFetch(Map(
       t1p0 -> PartitionFetchState(Some(topicId), 150, None, leaderEpoch, None, state = Fetching, lastFetchedEpoch = None),
       t1p1 -> PartitionFetchState(Some(topicId), 160, None, leaderEpoch, None, state = Fetching, lastFetchedEpoch = None)))
 
@@ -876,20 +887,21 @@ class ReplicaAlterLogDirsThreadTest {
     //Create the fetcher thread
     val endPoint = new BrokerEndPoint(0, "localhost", 1000)
     val leaderEpoch = 1
+    val leader = new LocalLeaderEndPoint(endPoint, config, replicaManager, quotaManager)
     val thread = new ReplicaAlterLogDirsThread(
       "alter-logs-dirs-thread-test1",
-      sourceBroker = endPoint,
-      brokerConfig = config,
-      failedPartitions = failedPartitions,
-      replicaMgr = replicaManager,
-      quota = quotaManager,
-      brokerTopicStats = null)
+      leader,
+      failedPartitions,
+      replicaManager,
+      quotaManager,
+      null,
+      config.replicaFetchBackoffMs)
     thread.addPartitions(Map(
       t1p0 -> initialFetchState(0L, leaderEpoch),
       t1p1 -> initialFetchState(0L, leaderEpoch)))
 
     // one partition is ready and one is truncating
-    val ResultWithPartitions(fetchRequestOpt, partitionsWithError) = thread.buildFetch(Map(
+    val ResultWithPartitions(fetchRequestOpt, partitionsWithError) = thread.leader.buildFetch(Map(
         t1p0 -> PartitionFetchState(Some(topicId), 150, None, leaderEpoch, state = Fetching, lastFetchedEpoch = None),
         t1p1 -> PartitionFetchState(Some(topicId), 160, None, leaderEpoch, state = Truncating, lastFetchedEpoch = None)))
 
@@ -903,7 +915,7 @@ class ReplicaAlterLogDirsThreadTest {
     assertEquals(150, fetchInfos.head._2.fetchOffset)
 
     // one partition is ready and one is delayed
-    val ResultWithPartitions(fetchRequest2Opt, partitionsWithError2) = thread.buildFetch(Map(
+    val ResultWithPartitions(fetchRequest2Opt, partitionsWithError2) = thread.leader.buildFetch(Map(
         t1p0 -> PartitionFetchState(Some(topicId), 140, None, leaderEpoch, state = Fetching, lastFetchedEpoch = None),
         t1p1 -> PartitionFetchState(Some(topicId), 160, None, leaderEpoch, delay = Some(new DelayedItem(5000)), state = Fetching, lastFetchedEpoch = None)))
 
@@ -917,7 +929,7 @@ class ReplicaAlterLogDirsThreadTest {
     assertEquals(140, fetchInfos2.head._2.fetchOffset)
 
     // both partitions are delayed
-    val ResultWithPartitions(fetchRequest3Opt, partitionsWithError3) = thread.buildFetch(Map(
+    val ResultWithPartitions(fetchRequest3Opt, partitionsWithError3) = thread.leader.buildFetch(Map(
         t1p0 -> PartitionFetchState(Some(topicId), 140, None, leaderEpoch, delay = Some(new DelayedItem(5000)), state = Fetching, lastFetchedEpoch = None),
         t1p1 -> PartitionFetchState(Some(topicId), 160, None, leaderEpoch, delay = Some(new DelayedItem(5000)), state = Fetching, lastFetchedEpoch = None)))
     assertTrue(fetchRequest3Opt.isEmpty, "Expected no fetch requests since all partitions are delayed")
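
(Similarly, a sketch of the new ReplicaAlterLogDirsThread construction exercised in the test
changes above, assuming the config, replicaManager, quotaManager, failedPartitions, topicId and
t1p0 fixtures of ReplicaAlterLogDirsThreadTest: the thread now takes a LocalLeaderEndPoint in
place of the old sourceBroker/brokerConfig pair, and leader-side calls such as buildFetch and
fetchEpochEndOffsets are reached through thread.leader.)

    val leaderEpoch = 1
    val endPoint = new BrokerEndPoint(0, "localhost", 1000)
    val leader = new LocalLeaderEndPoint(endPoint, config, replicaManager, quotaManager)
    val thread = new ReplicaAlterLogDirsThread(
      "alter-logs-dirs-thread",
      leader,
      failedPartitions,
      replicaManager,
      quotaManager,
      new BrokerTopicStats,
      config.replicaFetchBackoffMs)
    thread.addPartitions(Map(t1p0 -> initialFetchState(0L, leaderEpoch)))

    // Fetch requests against the future log are now built by the leader endpoint.
    val ResultWithPartitions(fetchRequestOpt, partitionsWithError) = thread.leader.buildFetch(Map(
      t1p0 -> PartitionFetchState(Some(topicId), 150, None, leaderEpoch, None, state = Fetching, lastFetchedEpoch = None)))
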
diff --git a/core/src/test/scala/unit/kafka/server/ReplicaFetcherThreadTest.scala b/core/src/test/scala/unit/kafka/server/ReplicaFetcherThreadTest.scala
index 954fe08d2e..e18dd29d69 100644
--- a/core/src/test/scala/unit/kafka/server/ReplicaFetcherThreadTest.scala
+++ b/core/src/test/scala/unit/kafka/server/ReplicaFetcherThreadTest.scala
@@ -20,20 +20,20 @@ import kafka.cluster.{BrokerEndPoint, Partition}
 import kafka.log.{LogAppendInfo, LogManager, UnifiedLog}
 import kafka.server.AbstractFetcherThread.ResultWithPartitions
 import kafka.server.QuotaFactory.UnboundedQuota
-import kafka.server.epoch.util.ReplicaFetcherMockBlockingSend
+import kafka.server.epoch.util.MockBlockingSender
 import kafka.server.metadata.ZkMetadataCache
 import kafka.utils.TestUtils
+import org.apache.kafka.clients.FetchSessionHandler
 import org.apache.kafka.common.{TopicIdPartition, TopicPartition, Uuid}
 import org.apache.kafka.common.message.{FetchResponseData, UpdateMetadataRequestData}
 import org.apache.kafka.common.message.OffsetForLeaderEpochRequestData.OffsetForLeaderPartition
 import org.apache.kafka.common.message.OffsetForLeaderEpochResponseData.EpochEndOffset
-import org.apache.kafka.common.metrics.Metrics
 import org.apache.kafka.common.protocol.Errors._
 import org.apache.kafka.common.protocol.{ApiKeys, Errors}
 import org.apache.kafka.common.record.{CompressionType, MemoryRecords, SimpleRecord}
 import org.apache.kafka.common.requests.OffsetsForLeaderEpochResponse.{UNDEFINED_EPOCH, UNDEFINED_EPOCH_OFFSET}
 import org.apache.kafka.common.requests.{FetchRequest, FetchResponse, UpdateMetadataRequest}
-import org.apache.kafka.common.utils.SystemTime
+import org.apache.kafka.common.utils.{LogContext, SystemTime}
 import org.junit.jupiter.api.Assertions._
 import org.junit.jupiter.api.{AfterEach, Test}
 import org.mockito.ArgumentCaptor
@@ -94,26 +94,36 @@ class ReplicaFetcherThreadTest {
     TestUtils.clearYammerMetrics()
   }
 
+  private def createReplicaFetcherThread(name: String,
+                                         fetcherId: Int,
+                                         brokerConfig: KafkaConfig,
+                                         failedPartitions: FailedPartitions,
+                                         replicaMgr: ReplicaManager,
+                                         quota: ReplicaQuota,
+                                         leaderEndpointBlockingSend: BlockingSend): ReplicaFetcherThread = {
+    val logContext = new LogContext(s"[ReplicaFetcher replicaId=${brokerConfig.brokerId}, leaderId=${leaderEndpointBlockingSend.brokerEndPoint().id}, fetcherId=$fetcherId] ")
+    val fetchSessionHandler = new FetchSessionHandler(logContext, leaderEndpointBlockingSend.brokerEndPoint().id)
+    val leader = new RemoteLeaderEndPoint(logContext.logPrefix, leaderEndpointBlockingSend, fetchSessionHandler, brokerConfig, replicaMgr, quota)
+    new ReplicaFetcherThread(name,
+      leader,
+      brokerConfig,
+      failedPartitions,
+      replicaMgr,
+      quota,
+      logContext.logPrefix)
+  }
+
   @Test
   def shouldSendLatestRequestVersionsByDefault(): Unit = {
     val props = TestUtils.createBrokerConfig(1, "localhost:1234")
     val config = KafkaConfig.fromProps(props)
+
     val replicaManager: ReplicaManager = mock(classOf[ReplicaManager])
     when(replicaManager.brokerTopicStats).thenReturn(mock(classOf[BrokerTopicStats]))
-    val thread = new ReplicaFetcherThread(
-      name = "bob",
-      fetcherId = 0,
-      sourceBroker = brokerEndPoint,
-      brokerConfig = config,
-      failedPartitions: FailedPartitions,
-      replicaMgr = replicaManager,
-      metrics =  new Metrics(),
-      time = new SystemTime(),
-      quota = UnboundedQuota,
-      leaderEndpointBlockingSend = None)
-    assertEquals(ApiKeys.FETCH.latestVersion, thread.fetchRequestVersion)
-    assertEquals(ApiKeys.OFFSET_FOR_LEADER_EPOCH.latestVersion, thread.offsetForLeaderEpochRequestVersion)
-    assertEquals(ApiKeys.LIST_OFFSETS.latestVersion, thread.listOffsetRequestVersion)
+
+    assertEquals(ApiKeys.FETCH.latestVersion, config.fetchRequestVersion)
+    assertEquals(ApiKeys.OFFSET_FOR_LEADER_EPOCH.latestVersion, config.offsetForLeaderEpochRequestVersion)
+    assertEquals(ApiKeys.LIST_OFFSETS.latestVersion, config.listOffsetRequestVersion)
   }
 
   @Test
@@ -152,9 +162,16 @@ class ReplicaFetcherThreadTest {
       t1p1 -> newOffsetForLeaderPartitionResult(t1p1, leaderEpoch, 1)).asJava
 
     //Create the fetcher thread
-    val mockNetwork = new ReplicaFetcherMockBlockingSend(offsets, brokerEndPoint, new SystemTime())
+    val mockNetwork = new MockBlockingSender(offsets, brokerEndPoint, new SystemTime())
 
-    val thread = new ReplicaFetcherThread("bob", 0, brokerEndPoint, config, failedPartitions, replicaManager, new Metrics(), new SystemTime(), quota, Some(mockNetwork))
+    val thread = createReplicaFetcherThread(
+      "bob",
+      0,
+      config,
+      failedPartitions,
+      replicaManager,
+      quota,
+      mockNetwork)
 
     // topic 1 supports epoch, t2 doesn't.
     thread.addPartitions(Map(
@@ -216,24 +233,23 @@ class ReplicaFetcherThreadTest {
     val props = TestUtils.createBrokerConfig(1, "localhost:1234")
     val config = KafkaConfig.fromProps(props)
     val mockBlockingSend: BlockingSend = mock(classOf[BlockingSend])
-
+    when(mockBlockingSend.brokerEndPoint()).thenReturn(brokerEndPoint)
     when(mockBlockingSend.sendRequest(any())).thenThrow(new NullPointerException)
+
     val replicaManager: ReplicaManager = mock(classOf[ReplicaManager])
     when(replicaManager.brokerTopicStats).thenReturn(mock(classOf[BrokerTopicStats]))
 
-    val thread = new ReplicaFetcherThread(
-      name = "bob",
-      fetcherId = 0,
-      sourceBroker = brokerEndPoint,
-      brokerConfig = config,
-      failedPartitions: FailedPartitions,
-      replicaMgr = replicaManager,
-      metrics =  new Metrics(),
-      time = new SystemTime(),
-      quota = null,
-      leaderEndpointBlockingSend = Some(mockBlockingSend))
-
-    val result = thread.fetchEpochEndOffsets(Map(
+    val thread = createReplicaFetcherThread(
+      "bob",
+      0,
+      config,
+      failedPartitions,
+      replicaManager,
+      null,
+      mockBlockingSend
+    )
+
+    val result = thread.leader.fetchEpochEndOffsets(Map(
       t1p0 -> new OffsetForLeaderPartition()
         .setPartition(t1p0.partition)
         .setLeaderEpoch(0),
@@ -295,9 +311,16 @@ class ReplicaFetcherThreadTest {
       t1p1 -> newOffsetForLeaderPartitionResult(t1p1, leaderEpoch, 1)).asJava
 
     //Create the fetcher thread
-    val mockNetwork = new ReplicaFetcherMockBlockingSend(offsets, brokerEndPoint, new SystemTime())
-    val thread = new ReplicaFetcherThread("bob", 0, brokerEndPoint, config, failedPartitions, replicaManager,
-      new Metrics, new SystemTime, UnboundedQuota, Some(mockNetwork))
+    val mockNetwork = new MockBlockingSender(offsets, brokerEndPoint, new SystemTime())
+    val thread = createReplicaFetcherThread(
+      "bob",
+      0,
+      config,
+      failedPartitions,
+      replicaManager,
+      UnboundedQuota,
+      mockNetwork
+    )
     thread.addPartitions(Map(t1p0 -> initialFetchState(Some(topicId1), 0L), t1p1 -> initialFetchState(Some(topicId1), 0L)))
 
     //Loop 1
@@ -354,9 +377,16 @@ class ReplicaFetcherThreadTest {
       t2p1 -> newOffsetForLeaderPartitionResult(t2p1, leaderEpoch, 172)).asJava
 
     //Create the thread
-    val mockNetwork = new ReplicaFetcherMockBlockingSend(offsetsReply, brokerEndPoint, new SystemTime())
-    val thread = new ReplicaFetcherThread("bob", 0, brokerEndPoint, config, failedPartitions, replicaManager,
-      new Metrics(), new SystemTime(), quota, Some(mockNetwork))
+    val mockNetwork = new MockBlockingSender(offsetsReply, brokerEndPoint, new SystemTime())
+    val thread = createReplicaFetcherThread(
+      "bob",
+      0,
+      config,
+      failedPartitions,
+      replicaManager,
+      quota,
+      mockNetwork
+    )
     thread.addPartitions(Map(t1p0 -> initialFetchState(Some(topicId1), 0L), t2p1 -> initialFetchState(Some(topicId2), 0L)))
 
     //Run it
@@ -407,9 +437,16 @@ class ReplicaFetcherThreadTest {
       t2p1 -> newOffsetForLeaderPartitionResult(t2p1, leaderEpochAtLeader, 202)).asJava
 
     //Create the thread
-    val mockNetwork = new ReplicaFetcherMockBlockingSend(offsetsReply, brokerEndPoint, new SystemTime())
-    val thread = new ReplicaFetcherThread("bob", 0, brokerEndPoint, config, failedPartitions,
-      replicaManager, new Metrics(), new SystemTime(), quota, Some(mockNetwork))
+    val mockNetwork = new MockBlockingSender(offsetsReply, brokerEndPoint, new SystemTime())
+    val thread = createReplicaFetcherThread(
+      "bob",
+      0,
+      config,
+      failedPartitions,
+      replicaManager,
+      quota,
+      mockNetwork
+    )
     thread.addPartitions(Map(t1p0 -> initialFetchState(Some(topicId1), 0L), t2p1 -> initialFetchState(Some(topicId2), 0L)))
 
     //Run it
@@ -463,8 +500,16 @@ class ReplicaFetcherThreadTest {
       t1p1 -> newOffsetForLeaderPartitionResult(t1p1, 4, 143)).asJava
 
     // Create the fetcher thread
-    val mockNetwork = new ReplicaFetcherMockBlockingSend(offsets, brokerEndPoint, new SystemTime())
-    val thread = new ReplicaFetcherThread("bob", 0, brokerEndPoint, config, failedPartitions, replicaManager, new Metrics(), new SystemTime(), quota, Some(mockNetwork))
+    val mockNetwork = new MockBlockingSender(offsets, brokerEndPoint, new SystemTime())
+    val thread = createReplicaFetcherThread(
+      "bob",
+      0,
+      config,
+      failedPartitions,
+      replicaManager,
+      quota,
+      mockNetwork
+    )
     thread.addPartitions(Map(t1p0 -> initialFetchState(Some(topicId1), 0L), t1p1 -> initialFetchState(Some(topicId1), 0L)))
 
     // Loop 1 -- both topic partitions will need to fetch another leader epoch
@@ -533,8 +578,11 @@ class ReplicaFetcherThreadTest {
     stub(partition, replicaManager, log)
 
     // Create the fetcher thread
-    val mockNetwork = new ReplicaFetcherMockBlockingSend(Collections.emptyMap(), brokerEndPoint, new SystemTime())
-    val thread = new ReplicaFetcherThread("bob", 0, brokerEndPoint, config, failedPartitions, replicaManager, new Metrics(), new SystemTime(), quota, Some(mockNetwork)) {
+    val mockNetwork = new MockBlockingSender(Collections.emptyMap(), brokerEndPoint, new SystemTime())
+    val logContext = new LogContext(s"[ReplicaFetcher replicaId=${config.brokerId}, leaderId=${brokerEndPoint.id}, fetcherId=0] ")
+    val fetchSessionHandler = new FetchSessionHandler(logContext, brokerEndPoint.id)
+    val leader = new RemoteLeaderEndPoint(logContext.logPrefix, mockNetwork, fetchSessionHandler, config, replicaManager, quota)
+    val thread = new ReplicaFetcherThread("bob", leader, config, failedPartitions, replicaManager, quota, logContext.logPrefix) {
       override def processPartitionData(topicPartition: TopicPartition, fetchOffset: Long, partitionData: FetchData): Option[LogAppendInfo] = None
     }
     thread.addPartitions(Map(t1p0 -> initialFetchState(Some(topicId1), initialLEO), t1p1 -> initialFetchState(Some(topicId1), initialLEO)))
@@ -648,8 +696,16 @@ class ReplicaFetcherThreadTest {
       t1p1 -> newOffsetForLeaderPartitionResult(t1p1, UNDEFINED_EPOCH, 143)).asJava
 
     // Create the fetcher thread
-    val mockNetwork = new ReplicaFetcherMockBlockingSend(offsets, brokerEndPoint, new SystemTime())
-    val thread = new ReplicaFetcherThread("bob", 0, brokerEndPoint, config, failedPartitions, replicaManager, new Metrics(), new SystemTime(), quota, Some(mockNetwork))
+    val mockNetwork = new MockBlockingSender(offsets, brokerEndPoint, new SystemTime())
+    val thread = createReplicaFetcherThread(
+      "bob",
+      0,
+      config,
+      failedPartitions,
+      replicaManager,
+      quota,
+      mockNetwork
+    )
     thread.addPartitions(Map(t1p0 -> initialFetchState(Some(topicId1), 0L), t1p1 -> initialFetchState(Some(topicId1), 0L)))
 
     // Loop 1 -- both topic partitions will truncate to leader offset even though they don't know
@@ -704,8 +760,16 @@ class ReplicaFetcherThreadTest {
       t1p0 -> newOffsetForLeaderPartitionResult(t1p0, UNDEFINED_EPOCH, UNDEFINED_EPOCH_OFFSET)).asJava
 
     //Create the thread
-    val mockNetwork = new ReplicaFetcherMockBlockingSend(offsetsReply, brokerEndPoint, new SystemTime())
-    val thread = new ReplicaFetcherThread("bob", 0, brokerEndPoint, config, failedPartitions, replicaManager, new Metrics(), new SystemTime(), quota, Some(mockNetwork))
+    val mockNetwork = new MockBlockingSender(offsetsReply, brokerEndPoint, new SystemTime())
+    val thread = createReplicaFetcherThread(
+      "bob",
+      0,
+      config,
+      failedPartitions,
+      replicaManager,
+      quota,
+      mockNetwork
+    )
     thread.addPartitions(Map(t1p0 -> initialFetchState(Some(topicId1), initialFetchOffset)))
 
     //Run it
@@ -757,8 +821,16 @@ class ReplicaFetcherThreadTest {
     ).asJava
 
     //Create the thread
-    val mockNetwork = new ReplicaFetcherMockBlockingSend(offsetsReply, brokerEndPoint, new SystemTime())
-    val thread = new ReplicaFetcherThread("bob", 0, brokerEndPoint, config, failedPartitions, replicaManager, new Metrics(), new SystemTime(), quota, Some(mockNetwork))
+    val mockNetwork = new MockBlockingSender(offsetsReply, brokerEndPoint, new SystemTime())
+    val thread = createReplicaFetcherThread(
+      "bob",
+      0,
+      config,
+      failedPartitions,
+      replicaManager,
+      quota,
+      mockNetwork
+    )
     thread.addPartitions(Map(t1p0 -> initialFetchState(Some(topicId1), 0L), t1p1 -> initialFetchState(Some(topicId1), 0L)))
 
     //Run thread 3 times
@@ -811,8 +883,16 @@ class ReplicaFetcherThreadTest {
     ).asJava
 
     //Create the fetcher thread
-    val mockNetwork = new ReplicaFetcherMockBlockingSend(offsetsReply, brokerEndPoint, new SystemTime())
-    val thread = new ReplicaFetcherThread("bob", 0, brokerEndPoint, config, failedPartitions, replicaManager, new Metrics(), new SystemTime(), quota, Some(mockNetwork))
+    val mockNetwork = new MockBlockingSender(offsetsReply, brokerEndPoint, new SystemTime())
+    val thread = createReplicaFetcherThread(
+      "bob",
+      0,
+      config,
+      failedPartitions,
+      replicaManager,
+      quota,
+      mockNetwork
+    )
 
     //When
     thread.addPartitions(Map(t1p0 -> initialFetchState(Some(topicId1), 0L), t1p1 -> initialFetchState(Some(topicId1), 0L)))
@@ -863,9 +943,16 @@ class ReplicaFetcherThreadTest {
     ).asJava
 
     //Create the fetcher thread
-    val mockNetwork = new ReplicaFetcherMockBlockingSend(offsetsReply, brokerEndPoint, new SystemTime())
-    val thread = new ReplicaFetcherThread("bob", 0, brokerEndPoint, config, failedPartitions, replicaManager, new Metrics(),
-      new SystemTime(), quota, Some(mockNetwork))
+    val mockNetwork = new MockBlockingSender(offsetsReply, brokerEndPoint, new SystemTime())
+    val thread = createReplicaFetcherThread(
+      "bob",
+      0,
+      config,
+      failedPartitions,
+      replicaManager,
+      quota,
+      mockNetwork
+    )
 
     //When
     thread.addPartitions(Map(t1p0 -> initialFetchState(Some(topicId1), 0L), t1p1 -> initialFetchState(Some(topicId1), 0L)))
@@ -888,24 +975,24 @@ class ReplicaFetcherThreadTest {
   def shouldCatchExceptionFromBlockingSendWhenShuttingDownReplicaFetcherThread(): Unit = {
     val props = TestUtils.createBrokerConfig(1, "localhost:1234")
     val config = KafkaConfig.fromProps(props)
-    val mockBlockingSend: BlockingSend = mock(classOf[BlockingSend])
 
+    val mockBlockingSend: BlockingSend = mock(classOf[BlockingSend])
+    when(mockBlockingSend.brokerEndPoint()).thenReturn(brokerEndPoint)
     when(mockBlockingSend.initiateClose()).thenThrow(new IllegalArgumentException())
     when(mockBlockingSend.close()).thenThrow(new IllegalStateException())
+
     val replicaManager: ReplicaManager = mock(classOf[ReplicaManager])
     when(replicaManager.brokerTopicStats).thenReturn(mock(classOf[BrokerTopicStats]))
 
-    val thread = new ReplicaFetcherThread(
-      name = "bob",
-      fetcherId = 0,
-      sourceBroker = brokerEndPoint,
-      brokerConfig = config,
-      failedPartitions = failedPartitions,
-      replicaMgr = replicaManager,
-      metrics =  new Metrics(),
-      time = new SystemTime(),
-      quota = null,
-      leaderEndpointBlockingSend = Some(mockBlockingSend))
+    val thread = createReplicaFetcherThread(
+      "bob",
+      0,
+      config,
+      failedPartitions,
+      replicaManager,
+      null,
+      mockBlockingSend
+    )
     thread.start()
 
     // Verify that:
@@ -941,13 +1028,22 @@ class ReplicaFetcherThreadTest {
     val replicaQuota: ReplicaQuota = mock(classOf[ReplicaQuota])
     val log: UnifiedLog = mock(classOf[UnifiedLog])
 
+    when(mockBlockingSend.brokerEndPoint()).thenReturn(brokerEndPoint)
     when(replicaManager.brokerTopicStats).thenReturn(mock(classOf[BrokerTopicStats]))
     when(replicaManager.localLogOrException(any[TopicPartition])).thenReturn(log)
     when(replicaQuota.isThrottled(any[TopicPartition])).thenReturn(false)
     when(log.logStartOffset).thenReturn(0)
 
-    val thread = new ReplicaFetcherThread("bob", 0, brokerEndPoint, config, failedPartitions,
-      replicaManager, new Metrics(), new SystemTime(), replicaQuota, Some(mockBlockingSend))
+    val logContext = new LogContext(s"[ReplicaFetcher replicaId=${config.brokerId}, leaderId=${brokerEndPoint.id}, fetcherId=0] ")
+    val fetchSessionHandler = new FetchSessionHandler(logContext, brokerEndPoint.id)
+    val leader = new RemoteLeaderEndPoint(logContext.logPrefix, mockBlockingSend, fetchSessionHandler, config, replicaManager, replicaQuota)
+    val thread = new ReplicaFetcherThread("bob",
+      leader,
+      config,
+      failedPartitions,
+      replicaManager,
+      replicaQuota,
+      logContext.logPrefix)
 
     val leaderEpoch = 1
 
@@ -956,7 +1052,7 @@ class ReplicaFetcherThreadTest {
         t1p1 -> PartitionFetchState(Some(topicId1), 155, None, leaderEpoch, None, state = Fetching, lastFetchedEpoch = None),
         t2p1 -> PartitionFetchState(Some(topicId2), 160, None, leaderEpoch, None, state = Fetching, lastFetchedEpoch = None))
 
-    val ResultWithPartitions(fetchRequestOpt, _) = thread.buildFetch(partitionMap)
+    val ResultWithPartitions(fetchRequestOpt, _) = thread.leader.buildFetch(partitionMap)
 
     assertTrue(fetchRequestOpt.isDefined)
     val fetchRequestBuilder = fetchRequestOpt.get.fetchRequest
@@ -976,14 +1072,14 @@ class ReplicaFetcherThreadTest {
     responseData.put(tid2p1, new FetchResponseData.PartitionData())
     val fetchResponse = FetchResponse.of(Errors.NONE, 0, 123, responseData)
 
-    thread.fetchSessionHandler.handleResponse(fetchResponse, ApiKeys.FETCH.latestVersion())
+    leader.fetchSessionHandler.handleResponse(fetchResponse, ApiKeys.FETCH.latestVersion())
 
     // Remove t1p0, change the ID for t2p1, and keep t1p1 the same
     val newTopicId = Uuid.randomUuid()
     val partitionMap2 = Map(
       t1p1 -> PartitionFetchState(Some(topicId1), 155, None, leaderEpoch, None, state = Fetching, lastFetchedEpoch = None),
       t2p1 -> PartitionFetchState(Some(newTopicId), 160, None, leaderEpoch, None, state = Fetching, lastFetchedEpoch = None))
-    val ResultWithPartitions(fetchRequestOpt2, _) = thread.buildFetch(partitionMap2)
+    val ResultWithPartitions(fetchRequestOpt2, _) = thread.leader.buildFetch(partitionMap2)
 
     // Since t1p1 didn't change, we drop that one
     val partitionDataMap2 = partitionMap2.drop(1).map { case (tp, state) =>
@@ -1024,6 +1120,7 @@ class ReplicaFetcherThreadTest {
     val config = KafkaConfig.fromProps(props)
 
     val mockBlockingSend: BlockingSend = mock(classOf[BlockingSend])
+    when(mockBlockingSend.brokerEndPoint()).thenReturn(brokerEndPoint)
 
     val log: UnifiedLog = mock(classOf[UnifiedLog])
 
@@ -1039,17 +1136,15 @@ class ReplicaFetcherThreadTest {
 
     val replicaQuota: ReplicaQuota = mock(classOf[ReplicaQuota])
 
-    val thread = new ReplicaFetcherThread(
-      name = "bob",
-      fetcherId = 0,
-      sourceBroker = brokerEndPoint,
-      brokerConfig = config,
-      failedPartitions = failedPartitions,
-      replicaMgr = replicaManager,
-      metrics =  new Metrics(),
-      time = new SystemTime(),
-      quota = replicaQuota,
-      leaderEndpointBlockingSend = Some(mockBlockingSend))
+    val thread = createReplicaFetcherThread(
+      "bob",
+      0,
+      config,
+      failedPartitions,
+      replicaManager,
+      replicaQuota,
+      mockBlockingSend
+    )
 
     val records = MemoryRecords.withRecords(CompressionType.NONE,
       new SimpleRecord(1000, "foo".getBytes(StandardCharsets.UTF_8)))
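
    The call sites above all go through a createReplicaFetcherThread test helper added elsewhere in this file's diff. Purely as an illustration of the wiring this patch introduces (parameter names and order are inferred from the call sites, so treat this as a sketch rather than the committed helper; it assumes the LogContext/FetchSessionHandler imports already present in this test file), it would look roughly like:

        private def createReplicaFetcherThread(name: String,
                                               fetcherId: Int,
                                               brokerConfig: KafkaConfig,
                                               failedPartitions: FailedPartitions,
                                               replicaMgr: ReplicaManager,
                                               quota: ReplicaQuota,
                                               leaderEndpointBlockingSend: BlockingSend): ReplicaFetcherThread = {
          // Build the log prefix the same way the direct constructions in this patch do.
          val logContext = new LogContext(s"[ReplicaFetcher replicaId=${brokerConfig.brokerId}, " +
            s"leaderId=${leaderEndpointBlockingSend.brokerEndPoint().id}, fetcherId=$fetcherId] ")
          val fetchSessionHandler = new FetchSessionHandler(logContext, leaderEndpointBlockingSend.brokerEndPoint().id)
          // All leader access now goes through a LeaderEndPoint; for follower fetching that is RemoteLeaderEndPoint.
          val leader = new RemoteLeaderEndPoint(logContext.logPrefix, leaderEndpointBlockingSend,
            fetchSessionHandler, brokerConfig, replicaMgr, quota)
          new ReplicaFetcherThread(name, leader, brokerConfig, failedPartitions, replicaMgr, quota, logContext.logPrefix)
        }
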
diff --git a/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala b/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala
index 623042d612..e5eef65d0a 100644
--- a/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala
+++ b/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala
@@ -30,9 +30,10 @@ import kafka.cluster.{BrokerEndPoint, Partition}
 import kafka.log._
 import kafka.server.QuotaFactory.{QuotaManagers, UnboundedQuota}
 import kafka.server.checkpoints.{LazyOffsetCheckpoints, OffsetCheckpointFile}
-import kafka.server.epoch.util.ReplicaFetcherMockBlockingSend
+import kafka.server.epoch.util.MockBlockingSender
 import kafka.utils.timer.MockTimer
 import kafka.utils.{MockScheduler, MockTime, TestUtils}
+import org.apache.kafka.clients.FetchSessionHandler
 import org.apache.kafka.common.errors.KafkaStorageException
 import org.apache.kafka.common.message.FetchResponseData
 import org.apache.kafka.common.message.LeaderAndIsrRequestData
@@ -50,7 +51,7 @@ import org.apache.kafka.common.requests.FetchRequest.PartitionData
 import org.apache.kafka.common.requests.ProduceResponse.PartitionResponse
 import org.apache.kafka.common.requests._
 import org.apache.kafka.common.security.auth.KafkaPrincipal
-import org.apache.kafka.common.utils.{Time, Utils}
+import org.apache.kafka.common.utils.{LogContext, Time, Utils}
 import org.apache.kafka.common.{IsolationLevel, Node, TopicIdPartition, TopicPartition, Uuid}
 import org.apache.kafka.image.{AclsImage, ClientQuotasImage, ClusterImageTest, ConfigurationsImage, FeaturesImage, MetadataImage, ProducerIdsImage, TopicsDelta, TopicsImage}
 import org.apache.kafka.metadata.LeaderConstants.NO_LEADER
@@ -1974,7 +1975,7 @@ class ReplicaManagerTest {
       purgatoryName = "ElectLeader", timer, reaperEnabled = false)
 
     // Mock network client to show leader offset of 5
-    val blockingSend = new ReplicaFetcherMockBlockingSend(
+    val blockingSend = new MockBlockingSender(
       Map(topicPartitionObj -> new EpochEndOffset()
         .setPartition(topicPartitionObj.partition)
         .setErrorCode(Errors.NONE.code)
@@ -2005,9 +2006,12 @@ class ReplicaManagerTest {
         new ReplicaFetcherManager(config, this, metrics, time, threadNamePrefix, replicationQuotaManager) {
 
           override def createFetcherThread(fetcherId: Int, sourceBroker: BrokerEndPoint): ReplicaFetcherThread = {
-            new ReplicaFetcherThread(s"ReplicaFetcherThread-$fetcherId", fetcherId,
-              sourceBroker, config, failedPartitions, replicaManager, metrics, time, quotaManager.follower, Some(blockingSend)) {
-
+            val logContext = new LogContext(s"[ReplicaFetcher replicaId=${config.brokerId}, leaderId=${sourceBroker.id}, " +
+              s"fetcherId=$fetcherId] ")
+            val fetchSessionHandler = new FetchSessionHandler(logContext, sourceBroker.id)
+            val leader = new RemoteLeaderEndPoint(logContext.logPrefix, blockingSend, fetchSessionHandler, config, replicaManager, quotaManager.follower)
+            new ReplicaFetcherThread(s"ReplicaFetcherThread-$fetcherId", leader, config, failedPartitions, replicaManager,
+              quotaManager.follower, logContext.logPrefix) {
               override def doWork(): Unit = {
                 // In case the thread starts before the partition is added by AbstractFetcherManager,
                 // add it here (it's a no-op if already added)
@@ -3197,7 +3201,7 @@ class ReplicaManagerTest {
       assertEquals(1, followerPartition.getLeaderEpoch)
 
       val fetcher = replicaManager.replicaFetcherManager.getFetcher(topicPartition)
-      assertEquals(Some(BrokerEndPoint(otherId, "localhost", 9093)), fetcher.map(_.sourceBroker))
+      assertEquals(Some(BrokerEndPoint(otherId, "localhost", 9093)), fetcher.map(_.leader.brokerEndPoint()))
     } finally {
       replicaManager.shutdown()
     }
@@ -3225,7 +3229,7 @@ class ReplicaManagerTest {
       assertEquals(0, followerPartition.getLeaderEpoch)
 
       val fetcher = replicaManager.replicaFetcherManager.getFetcher(topicPartition)
-      assertEquals(Some(BrokerEndPoint(otherId, "localhost", 9093)), fetcher.map(_.sourceBroker))
+      assertEquals(Some(BrokerEndPoint(otherId, "localhost", 9093)), fetcher.map(_.leader.brokerEndPoint()))
 
       // Append on a follower should fail
       val followerResponse = sendProducerAppend(replicaManager, topicPartition, numOfRecords)
@@ -3280,7 +3284,7 @@ class ReplicaManagerTest {
       assertEquals(0, followerPartition.getLeaderEpoch)
 
       val fetcher = replicaManager.replicaFetcherManager.getFetcher(topicPartition)
-      assertEquals(Some(BrokerEndPoint(otherId, "localhost", 9093)), fetcher.map(_.sourceBroker))
+      assertEquals(Some(BrokerEndPoint(otherId, "localhost", 9093)), fetcher.map(_.leader.brokerEndPoint()))
 
       // Apply the same delta again
       replicaManager.applyDelta(followerTopicsDelta, followerMetadataImage)
@@ -3291,7 +3295,7 @@ class ReplicaManagerTest {
       assertEquals(0, noChangePartition.getLeaderEpoch)
 
       val noChangeFetcher = replicaManager.replicaFetcherManager.getFetcher(topicPartition)
-      assertEquals(Some(BrokerEndPoint(otherId, "localhost", 9093)), noChangeFetcher.map(_.sourceBroker))
+      assertEquals(Some(BrokerEndPoint(otherId, "localhost", 9093)), noChangeFetcher.map(_.leader.brokerEndPoint()))
     } finally {
       replicaManager.shutdown()
     }
@@ -3318,7 +3322,7 @@ class ReplicaManagerTest {
       assertEquals(0, followerPartition.getLeaderEpoch)
 
       val fetcher = replicaManager.replicaFetcherManager.getFetcher(topicPartition)
-      assertEquals(Some(BrokerEndPoint(otherId, "localhost", 9093)), fetcher.map(_.sourceBroker))
+      assertEquals(Some(BrokerEndPoint(otherId, "localhost", 9093)), fetcher.map(_.leader.brokerEndPoint()))
 
       // Apply changes that remove replica
       val notReplicaTopicsDelta = topicsChangeDelta(followerMetadataImage.topics(), otherId, true)
@@ -3355,7 +3359,7 @@ class ReplicaManagerTest {
       assertEquals(0, followerPartition.getLeaderEpoch)
 
       val fetcher = replicaManager.replicaFetcherManager.getFetcher(topicPartition)
-      assertEquals(Some(BrokerEndPoint(otherId, "localhost", 9093)), fetcher.map(_.sourceBroker))
+      assertEquals(Some(BrokerEndPoint(otherId, "localhost", 9093)), fetcher.map(_.leader.brokerEndPoint()))
 
       // Apply changes that remove topic and replica
       val removeTopicsDelta = topicsDeleteDelta(followerMetadataImage.topics())
diff --git a/core/src/test/scala/unit/kafka/server/epoch/LeaderEpochIntegrationTest.scala b/core/src/test/scala/unit/kafka/server/epoch/LeaderEpochIntegrationTest.scala
index 3205606c81..ef2e820010 100644
--- a/core/src/test/scala/unit/kafka/server/epoch/LeaderEpochIntegrationTest.scala
+++ b/core/src/test/scala/unit/kafka/server/epoch/LeaderEpochIntegrationTest.scala
@@ -18,7 +18,7 @@ package kafka.server.epoch
 
 import kafka.cluster.BrokerEndPoint
 import kafka.server.KafkaConfig._
-import kafka.server.{BlockingSend, KafkaServer, ReplicaFetcherBlockingSend}
+import kafka.server.{BlockingSend, KafkaServer, BrokerBlockingSender}
 import kafka.utils.Implicits._
 import kafka.utils.TestUtils._
 import kafka.utils.{Logging, TestUtils}
@@ -231,7 +231,7 @@ class LeaderEpochIntegrationTest extends QuorumTestHarness with Logging {
     val node = from.metadataCache.getAliveBrokerNode(to.config.brokerId,
       from.config.interBrokerListenerName).get
     val endPoint = new BrokerEndPoint(node.id(), node.host(), node.port())
-    new ReplicaFetcherBlockingSend(endPoint, from.config, new Metrics(), new SystemTime(), 42, "TestFetcher", new LogContext())
+    new BrokerBlockingSender(endPoint, from.config, new Metrics(), new SystemTime(), 42, "TestFetcher", new LogContext())
   }
 
   private def waitForEpochChangeTo(topic: String, partition: Int, epoch: Int): Unit = {
diff --git a/core/src/test/scala/unit/kafka/server/epoch/util/ReplicaFetcherMockBlockingSend.scala b/core/src/test/scala/unit/kafka/server/epoch/util/MockBlockingSender.scala
similarity index 95%
rename from core/src/test/scala/unit/kafka/server/epoch/util/ReplicaFetcherMockBlockingSend.scala
rename to core/src/test/scala/unit/kafka/server/epoch/util/MockBlockingSender.scala
index 8f3fcff371..ac1d8b5754 100644
--- a/core/src/test/scala/unit/kafka/server/epoch/util/ReplicaFetcherMockBlockingSend.scala
+++ b/core/src/test/scala/unit/kafka/server/epoch/util/MockBlockingSender.scala
@@ -16,19 +16,19 @@
   */
 package kafka.server.epoch.util
 
-import java.net.SocketTimeoutException
-import java.util
 import kafka.cluster.BrokerEndPoint
 import kafka.server.BlockingSend
 import org.apache.kafka.clients.{ClientRequest, ClientResponse, MockClient, NetworkClientUtils}
-import org.apache.kafka.common.message.{FetchResponseData, OffsetForLeaderEpochResponseData}
 import org.apache.kafka.common.message.OffsetForLeaderEpochResponseData.{EpochEndOffset, OffsetForLeaderTopicResult}
+import org.apache.kafka.common.message.{FetchResponseData, OffsetForLeaderEpochResponseData}
 import org.apache.kafka.common.protocol.{ApiKeys, Errors}
 import org.apache.kafka.common.requests.AbstractRequest.Builder
 import org.apache.kafka.common.requests.{AbstractRequest, FetchResponse, OffsetsForLeaderEpochResponse, FetchMetadata => JFetchMetadata}
 import org.apache.kafka.common.utils.{SystemTime, Time}
 import org.apache.kafka.common.{Node, TopicIdPartition, TopicPartition, Uuid}
 
+import java.net.SocketTimeoutException
+import java.util
 import scala.collection.Map
 
 /**
@@ -39,9 +39,9 @@ import scala.collection.Map
   * OFFSET_FOR_LEADER_EPOCH with different offsets in response, it should update offsets using
   * setOffsetsForNextResponse
   */
-class ReplicaFetcherMockBlockingSend(offsets: java.util.Map[TopicPartition, EpochEndOffset],
-                                     sourceBroker: BrokerEndPoint,
-                                     time: Time)
+class MockBlockingSender(offsets: java.util.Map[TopicPartition, EpochEndOffset],
+                         sourceBroker: BrokerEndPoint,
+                         time: Time)
   extends BlockingSend {
 
   private val client = new MockClient(new SystemTime)
@@ -70,6 +70,8 @@ class ReplicaFetcherMockBlockingSend(offsets: java.util.Map[TopicPartition, Epoc
     this.topicIds = topicIds
   }
 
+  override def brokerEndPoint(): BrokerEndPoint = sourceBroker
+
   override def sendRequest(requestBuilder: Builder[_ <: AbstractRequest]): ClientResponse = {
     if (!NetworkClientUtils.awaitReady(client, sourceNode, time, 500))
       throw new SocketTimeoutException(s"Failed to connect within 500 ms")
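
    Besides the rename, the mock now implements the one member this refactor adds to BlockingSend: brokerEndPoint(), so callers can ask the sender which leader it points at instead of tracking a separate sourceBroker field. A minimal no-op stand-in, assuming the trait carries only the members visible in this patch (brokerEndPoint, sendRequest, initiateClose, close), might look like:

        import kafka.cluster.BrokerEndPoint
        import kafka.server.BlockingSend
        import org.apache.kafka.clients.ClientResponse
        import org.apache.kafka.common.requests.AbstractRequest
        import org.apache.kafka.common.requests.AbstractRequest.Builder

        // Hypothetical stub for tests that only need the endpoint identity and never hit the network.
        class NoOpBlockingSender(endPoint: BrokerEndPoint) extends BlockingSend {
          override def brokerEndPoint(): BrokerEndPoint = endPoint
          override def sendRequest(requestBuilder: Builder[_ <: AbstractRequest]): ClientResponse =
            throw new UnsupportedOperationException("this stub never sends requests")
          override def initiateClose(): Unit = ()
          override def close(): Unit = ()
        }
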
diff --git a/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/fetcher/ReplicaFetcherThreadBenchmark.java b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/fetcher/ReplicaFetcherThreadBenchmark.java
index 9ee9f704f5..86400ba1e9 100644
--- a/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/fetcher/ReplicaFetcherThreadBenchmark.java
+++ b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/fetcher/ReplicaFetcherThreadBenchmark.java
@@ -37,6 +37,8 @@ import kafka.server.MetadataCache;
 import kafka.server.OffsetAndEpoch;
 import kafka.server.OffsetTruncationState;
 import kafka.server.QuotaFactory;
+import kafka.server.RemoteLeaderEndPoint;
+import kafka.server.BrokerBlockingSender;
 import kafka.server.ReplicaFetcherThread;
 import kafka.server.ReplicaManager;
 import kafka.server.ReplicaQuota;
@@ -50,6 +52,7 @@ import kafka.utils.MockTime;
 import kafka.utils.Pool;
 import kafka.utils.TestUtils;
 import kafka.zk.KafkaZkClient;
+import org.apache.kafka.clients.FetchSessionHandler;
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.TopicIdPartition;
 import org.apache.kafka.common.Uuid;
@@ -66,6 +69,7 @@ import org.apache.kafka.common.record.RecordsSend;
 import org.apache.kafka.common.requests.FetchRequest;
 import org.apache.kafka.common.requests.FetchResponse;
 import org.apache.kafka.common.requests.UpdateMetadataRequest;
+import org.apache.kafka.common.utils.LogContext;
 import org.apache.kafka.common.utils.Time;
 import org.apache.kafka.common.utils.Utils;
 import org.apache.kafka.server.common.MetadataVersion;
@@ -117,6 +121,7 @@ public class ReplicaFetcherThreadBenchmark {
     private Pool<TopicPartition, Partition> pool = new Pool<TopicPartition, Partition>(Option.empty());
     private Metrics metrics = new Metrics();
     private ReplicaManager replicaManager;
+    private ReplicaQuota replicaQuota;
     private Option<Uuid> topicId = Option.apply(Uuid.randomUuid());
 
     @Setup(Level.Trial)
@@ -229,13 +234,28 @@ public class ReplicaFetcherThreadBenchmark {
             setLogDirFailureChannel(new LogDirFailureChannel(logDirs.size())).
             setAlterPartitionManager(TestUtils.createAlterIsrManager()).
             build();
-        fetcher = new ReplicaFetcherBenchThread(config, replicaManager, pool);
+        replicaQuota = new ReplicaQuota() {
+            @Override
+            public boolean isQuotaExceeded() {
+                return false;
+            }
+
+            @Override
+            public void record(long value) {
+            }
+
+            @Override
+            public boolean isThrottled(TopicPartition topicPartition) {
+                return false;
+            }
+        };
+        fetcher = new ReplicaFetcherBenchThread(config, replicaManager, replicaQuota, pool);
         fetcher.addPartitions(initialFetchStates);
         // force a pass to move partitions to fetching state. We do this in the setup phase
         // so that we do not measure this time as part of the steady state work
         fetcher.doWork();
         // handle response to engage the incremental fetch session handler
-        fetcher.fetchSessionHandler().handleResponse(FetchResponse.of(Errors.NONE, 0, 999, initialFetched), ApiKeys.FETCH.latestVersion());
+        ((RemoteLeaderEndPoint) fetcher.leader()).fetchSessionHandler().handleResponse(FetchResponse.of(Errors.NONE, 0, 999, initialFetched), ApiKeys.FETCH.latestVersion());
     }
 
     @TearDown(Level.Trial)
@@ -286,33 +306,59 @@ public class ReplicaFetcherThreadBenchmark {
 
         ReplicaFetcherBenchThread(KafkaConfig config,
                                   ReplicaManager replicaManager,
+                                  ReplicaQuota replicaQuota,
                                   Pool<TopicPartition,
                                   Partition> partitions) {
             super("name",
-                    3,
-                    new BrokerEndPoint(3, "host", 3000),
-                    config,
-                    new FailedPartitions(),
-                    replicaManager,
-                    new Metrics(),
-                    Time.SYSTEM,
-                    new ReplicaQuota() {
+                    new RemoteLeaderEndPoint(
+                            String.format("[ReplicaFetcher replicaId=%d, leaderId=%d, fetcherId=%d", config.brokerId(), 3, 3),
+                            new BrokerBlockingSender(
+                                    new BrokerEndPoint(3, "host", 3000),
+                                    config,
+                                    new Metrics(),
+                                    Time.SYSTEM,
+                                    3,
+                                    String.format("broker-%d-fetcher-%d", 3, 3),
+                                    new LogContext(String.format("[ReplicaFetcher replicaId=%d, leaderId=%d, fetcherId=%d", config.brokerId(), 3, 3))
+                            ),
+                            new FetchSessionHandler(
+                                    new LogContext(String.format("[ReplicaFetcher replicaId=%d, leaderId=%d, fetcherId=%d", config.brokerId(), 3, 3)), 3),
+                            config,
+                            replicaManager,
+                            replicaQuota
+                    ) {
                         @Override
-                        public boolean isQuotaExceeded() {
-                            return false;
+                        public long fetchEarliestOffset(TopicPartition topicPartition, int currentLeaderEpoch) {
+                            return 0;
                         }
 
                         @Override
-                        public void record(long value) {
+                        public Map<TopicPartition, EpochEndOffset> fetchEpochEndOffsets(Map<TopicPartition, OffsetForLeaderPartition> partitions) {
+                            scala.collection.mutable.Map<TopicPartition, EpochEndOffset> endOffsets = new scala.collection.mutable.HashMap<>();
+                            Iterator<TopicPartition> iterator = partitions.keys().iterator();
+                            while (iterator.hasNext()) {
+                                TopicPartition tp = iterator.next();
+                                endOffsets.put(tp, new EpochEndOffset()
+                                        .setPartition(tp.partition())
+                                        .setErrorCode(Errors.NONE.code())
+                                        .setLeaderEpoch(0)
+                                        .setEndOffset(100));
+                            }
+                            return endOffsets;
                         }
 
                         @Override
-                        public boolean isThrottled(TopicPartition topicPartition) {
-                            return false;
+                        public Map<TopicPartition, FetchResponseData.PartitionData> fetch(FetchRequest.Builder fetchRequest) {
+                            return new scala.collection.mutable.HashMap<>();
                         }
                     },
-                    Option.empty());
-            
+                    config,
+                    new FailedPartitions(),
+                    replicaManager,
+                    replicaQuota,
+                    String.format("[ReplicaFetcher replicaId=%d, leaderId=%d, fetcherId=%d", config.brokerId(), 3, 3)
+            );
+
             pool = partitions;
         }
 
@@ -346,30 +392,5 @@ public class ReplicaFetcherThreadBenchmark {
                                                           FetchResponseData.PartitionData partitionData) {
             return Option.empty();
         }
-
-        @Override
-        public long fetchEarliestOffsetFromLeader(TopicPartition topicPartition, int currentLeaderEpoch) {
-            return 0;
-        }
-
-        @Override
-        public Map<TopicPartition, EpochEndOffset> fetchEpochEndOffsets(Map<TopicPartition, OffsetForLeaderPartition> partitions) {
-            scala.collection.mutable.Map<TopicPartition, EpochEndOffset> endOffsets = new scala.collection.mutable.HashMap<>();
-            Iterator<TopicPartition> iterator = partitions.keys().iterator();
-            while (iterator.hasNext()) {
-                TopicPartition tp = iterator.next();
-                endOffsets.put(tp, new EpochEndOffset()
-                    .setPartition(tp.partition())
-                    .setErrorCode(Errors.NONE.code())
-                    .setLeaderEpoch(0)
-                    .setEndOffset(100));
-            }
-            return endOffsets;
-        }
-
-        @Override
-        public Map<TopicPartition, FetchResponseData.PartitionData> fetchFromLeader(FetchRequest.Builder fetchRequest) {
-            return new scala.collection.mutable.HashMap<>();
-        }
     }
 }
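
    For comparison with the Java benchmark above: after this refactor, a benchmark or test stubs leader behaviour by overriding the RemoteLeaderEndPoint handed to the fetcher thread, rather than overriding fetch methods on the thread itself. Below is a Scala sketch of that same pattern; the surrounding values (logContext, blockingSender, fetchSessionHandler, config, replicaManager, replicaQuota) are assumed to be built as in the benchmark, and the collection and import choices follow the signatures used above.

        import scala.collection.Map
        import kafka.server.{FailedPartitions, RemoteLeaderEndPoint, ReplicaFetcherThread}
        import org.apache.kafka.common.TopicPartition
        import org.apache.kafka.common.message.FetchResponseData
        import org.apache.kafka.common.message.OffsetForLeaderEpochRequestData.OffsetForLeaderPartition
        import org.apache.kafka.common.message.OffsetForLeaderEpochResponseData.EpochEndOffset
        import org.apache.kafka.common.protocol.Errors
        import org.apache.kafka.common.requests.FetchRequest

        // Stub the leader, not the thread: the fetcher only sees the LeaderEndPoint interface.
        val stubbedLeader = new RemoteLeaderEndPoint(logContext.logPrefix, blockingSender,
            fetchSessionHandler, config, replicaManager, replicaQuota) {

          // Pretend the leader's log starts at offset 0.
          override def fetchEarliestOffset(topicPartition: TopicPartition, currentLeaderEpoch: Int): Long = 0L

          // Answer every OffsetForLeaderEpoch request with epoch 0 ending at offset 100.
          override def fetchEpochEndOffsets(partitions: Map[TopicPartition, OffsetForLeaderPartition]): Map[TopicPartition, EpochEndOffset] =
            partitions.map { case (tp, _) =>
              tp -> new EpochEndOffset()
                .setPartition(tp.partition)
                .setErrorCode(Errors.NONE.code)
                .setLeaderEpoch(0)
                .setEndOffset(100)
            }

          // Return no partition data so the steady-state fetch path runs without any I/O.
          override def fetch(fetchRequest: FetchRequest.Builder): Map[TopicPartition, FetchResponseData.PartitionData] =
            Map.empty
        }

        val thread = new ReplicaFetcherThread("bench-fetcher", stubbedLeader, config,
          new FailedPartitions(), replicaManager, replicaQuota, logContext.logPrefix)
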