Posted to commits@kafka.apache.org by jg...@apache.org on 2022/05/10 20:24:47 UTC

[kafka] branch trunk updated: MINOR: Create case class to encapsulate fetch parameters and simplify handling (#12082)

This is an automated email from the ASF dual-hosted git repository.

jgus pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 7730476603 MINOR: Create case class to encapsulate fetch parameters and simplify handling (#12082)
7730476603 is described below

commit 773047660359bf6b551d06763eeff80bc551b58a
Author: Jason Gustafson <ja...@confluent.io>
AuthorDate: Tue May 10 13:24:23 2022 -0700

    MINOR: Create case class to encapsulate fetch parameters and simplify handling (#12082)
    
    This patch adds a new case class `FetchParams` which encapsulates the parameters of the fetch request. It then uses this class in `DelayedFetch` directly instead of `FetchMetadata`. The intent is to reduce the number of things we need to change whenever we need to pass through new parameters. The patch also cleans up `ReplicaManagerTest` for more consistent usage.
    
    Reviewers: David Jacot <dj...@confluent.io>
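    
    For readers skimming the patch, the following is a minimal, illustrative sketch
    (not part of the commit) of how a caller might construct the new `FetchParams`
    for an ordinary consumer fetch and what its derived flags evaluate to. The
    object name `FetchParamsExample` and the literal values are made up for the
    example; it assumes kafka-core built from this patch is on the classpath.
    
        import kafka.server.{FetchHighWatermark, FetchParams}
        import org.apache.kafka.common.protocol.ApiKeys
        import org.apache.kafka.common.requests.FetchRequest
    
        // Hypothetical example, not part of this commit.
        object FetchParamsExample extends App {
          val params = FetchParams(
            requestVersion = ApiKeys.FETCH.latestVersion,
            replicaId = FetchRequest.CONSUMER_REPLICA_ID, // -1: an ordinary consumer
            maxWaitMs = 500,
            minBytes = 1,
            maxBytes = 1024 * 1024,
            isolation = FetchHighWatermark,               // READ_UNCOMMITTED consumer fetch
            clientMetadata = None                         // pre-v11 requests carry no client metadata
          )
    
          // Flags that were previously passed (or stored in FetchMetadata) individually
          // are now derived from the one case class:
          println(params.isFromConsumer)    // true:  replicaId < 0 and not the future-local replica id
          println(params.isFromFollower)    // false: only non-negative broker ids are followers
          println(params.fetchOnlyLeader)   // true:  consumers without ClientMetadata fetch only from the leader
          println(params.hardMaxBytesLimit) // false: only request versions <= 2 enforce the hard limit
        }
    
    The same object is what `DelayedFetch` now receives in place of the old
    `FetchMetadata`, and `ReplicaManager.fetchMessages` takes it as its first
    argument, as shown in the KafkaApis.scala and ReplicaManager.scala hunks below.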
---
 core/src/main/scala/kafka/api/Request.scala        |   4 +
 .../src/main/scala/kafka/server/DelayedFetch.scala |  86 ++--
 .../main/scala/kafka/server/FetchDataInfo.scala    |  60 ++-
 core/src/main/scala/kafka/server/KafkaApis.scala   |  74 ++--
 .../kafka/server/ReplicaAlterLogDirsThread.scala   |  25 +-
 .../main/scala/kafka/server/ReplicaManager.scala   |  56 +--
 .../kafka/server/DelayedFetchTest.scala            |  55 +--
 .../scala/unit/kafka/server/KafkaApisTest.scala    |  28 +-
 .../server/ReplicaAlterLogDirsThreadTest.scala     |  54 ++-
 .../server/ReplicaManagerConcurrencyTest.scala     |  20 +-
 .../kafka/server/ReplicaManagerQuotasTest.scala    |  50 ++-
 .../unit/kafka/server/ReplicaManagerTest.scala     | 487 +++++++++++----------
 12 files changed, 538 insertions(+), 461 deletions(-)

diff --git a/core/src/main/scala/kafka/api/Request.scala b/core/src/main/scala/kafka/api/Request.scala
index 653b5f653a..6c405a45b0 100644
--- a/core/src/main/scala/kafka/api/Request.scala
+++ b/core/src/main/scala/kafka/api/Request.scala
@@ -25,6 +25,10 @@ object Request {
   // Broker ids are non-negative int.
   def isValidBrokerId(brokerId: Int): Boolean = brokerId >= 0
 
+  def isConsumer(replicaId: Int): Boolean = {
+    replicaId < 0 && replicaId != FutureLocalReplicaId
+  }
+
   def describeReplicaId(replicaId: Int): String = {
     replicaId match {
       case OrdinaryConsumerId => "consumer"
diff --git a/core/src/main/scala/kafka/server/DelayedFetch.scala b/core/src/main/scala/kafka/server/DelayedFetch.scala
index 8d38ef8b6d..3eb8eedf4c 100644
--- a/core/src/main/scala/kafka/server/DelayedFetch.scala
+++ b/core/src/main/scala/kafka/server/DelayedFetch.scala
@@ -23,7 +23,6 @@ import kafka.metrics.KafkaMetricsGroup
 import org.apache.kafka.common.TopicIdPartition
 import org.apache.kafka.common.errors._
 import org.apache.kafka.common.protocol.Errors
-import org.apache.kafka.common.replica.ClientMetadata
 import org.apache.kafka.common.requests.FetchRequest.PartitionData
 import org.apache.kafka.common.requests.OffsetsForLeaderEpochResponse.{UNDEFINED_EPOCH, UNDEFINED_EPOCH_OFFSET}
 
@@ -38,36 +37,23 @@ case class FetchPartitionStatus(startOffsetMetadata: LogOffsetMetadata, fetchInf
   }
 }
 
-/**
- * The fetch metadata maintained by the delayed fetch operation
- */
-case class FetchMetadata(fetchMinBytes: Int,
-                         fetchMaxBytes: Int,
-                         hardMaxBytesLimit: Boolean,
-                         fetchOnlyLeader: Boolean,
-                         fetchIsolation: FetchIsolation,
-                         isFromFollower: Boolean,
-                         replicaId: Int,
-                         fetchPartitionStatus: Seq[(TopicIdPartition, FetchPartitionStatus)]) {
-
-  override def toString = "FetchMetadata(minBytes=" + fetchMinBytes + ", " +
-    "maxBytes=" + fetchMaxBytes + ", " +
-    "onlyLeader=" + fetchOnlyLeader + ", " +
-    "fetchIsolation=" + fetchIsolation + ", " +
-    "replicaId=" + replicaId + ", " +
-    "partitionStatus=" + fetchPartitionStatus + ")"
-}
 /**
  * A delayed fetch operation that can be created by the replica manager and watched
  * in the fetch operation purgatory
  */
-class DelayedFetch(delayMs: Long,
-                   fetchMetadata: FetchMetadata,
-                   replicaManager: ReplicaManager,
-                   quota: ReplicaQuota,
-                   clientMetadata: Option[ClientMetadata],
-                   responseCallback: Seq[(TopicIdPartition, FetchPartitionData)] => Unit)
-  extends DelayedOperation(delayMs) {
+class DelayedFetch(
+  params: FetchParams,
+  fetchPartitionStatus: Seq[(TopicIdPartition, FetchPartitionStatus)],
+  replicaManager: ReplicaManager,
+  quota: ReplicaQuota,
+  responseCallback: Seq[(TopicIdPartition, FetchPartitionData)] => Unit
+) extends DelayedOperation(params.maxWaitMs) {
+
+  override def toString: String = {
+    s"DelayedFetch(params=$params" +
+      s", numPartitions=${fetchPartitionStatus.size}" +
+      ")"
+  }
 
   /**
    * The operation can be completed if:
@@ -84,16 +70,16 @@ class DelayedFetch(delayMs: Long,
    */
   override def tryComplete(): Boolean = {
     var accumulatedSize = 0
-    fetchMetadata.fetchPartitionStatus.foreach {
+    fetchPartitionStatus.foreach {
       case (topicIdPartition, fetchStatus) =>
         val fetchOffset = fetchStatus.startOffsetMetadata
         val fetchLeaderEpoch = fetchStatus.fetchInfo.currentLeaderEpoch
         try {
           if (fetchOffset != LogOffsetMetadata.UnknownOffsetMetadata) {
             val partition = replicaManager.getPartitionOrException(topicIdPartition.topicPartition)
-            val offsetSnapshot = partition.fetchOffsetSnapshot(fetchLeaderEpoch, fetchMetadata.fetchOnlyLeader)
+            val offsetSnapshot = partition.fetchOffsetSnapshot(fetchLeaderEpoch, params.fetchOnlyLeader)
 
-            val endOffset = fetchMetadata.fetchIsolation match {
+            val endOffset = params.isolation match {
               case FetchLogEnd => offsetSnapshot.logEndOffset
               case FetchHighWatermark => offsetSnapshot.highWatermark
               case FetchTxnCommitted => offsetSnapshot.lastStableOffset
@@ -105,19 +91,19 @@ class DelayedFetch(delayMs: Long,
             if (endOffset.messageOffset != fetchOffset.messageOffset) {
               if (endOffset.onOlderSegment(fetchOffset)) {
                 // Case F, this can happen when the new fetch operation is on a truncated leader
-                debug(s"Satisfying fetch $fetchMetadata since it is fetching later segments of partition $topicIdPartition.")
+                debug(s"Satisfying fetch $this since it is fetching later segments of partition $topicIdPartition.")
                 return forceComplete()
               } else if (fetchOffset.onOlderSegment(endOffset)) {
                 // Case F, this can happen when the fetch operation is falling behind the current segment
                 // or the partition has just rolled a new segment
-                debug(s"Satisfying fetch $fetchMetadata immediately since it is fetching older segments.")
+                debug(s"Satisfying fetch $this immediately since it is fetching older segments.")
                 // We will not force complete the fetch request if a replica should be throttled.
-                if (!fetchMetadata.isFromFollower || !replicaManager.shouldLeaderThrottle(quota, partition, fetchMetadata.replicaId))
+                if (!params.isFromFollower || !replicaManager.shouldLeaderThrottle(quota, partition, params.replicaId))
                   return forceComplete()
               } else if (fetchOffset.messageOffset < endOffset.messageOffset) {
                 // we take the partition fetch size as upper bound when accumulating the bytes (skip if a throttled partition)
                 val bytesAvailable = math.min(endOffset.positionDiff(fetchOffset), fetchStatus.fetchInfo.maxBytes)
-                if (!fetchMetadata.isFromFollower || !replicaManager.shouldLeaderThrottle(quota, partition, fetchMetadata.replicaId))
+                if (!params.isFromFollower || !replicaManager.shouldLeaderThrottle(quota, partition, params.replicaId))
                   accumulatedSize += bytesAvailable
               }
             }
@@ -131,7 +117,7 @@ class DelayedFetch(delayMs: Long,
                 debug(s"Could not obtain last offset for leader epoch for partition $topicIdPartition, epochEndOffset=$epochEndOffset.")
                 return forceComplete()
               } else if (epochEndOffset.leaderEpoch < fetchEpoch || epochEndOffset.endOffset < fetchStatus.fetchInfo.fetchOffset) {
-                debug(s"Satisfying fetch $fetchMetadata since it has diverging epoch requiring truncation for partition " +
+                debug(s"Satisfying fetch $this since it has diverging epoch requiring truncation for partition " +
                   s"$topicIdPartition epochEndOffset=$epochEndOffset fetchEpoch=$fetchEpoch fetchOffset=${fetchStatus.fetchInfo.fetchOffset}.")
                 return forceComplete()
               }
@@ -139,30 +125,30 @@ class DelayedFetch(delayMs: Long,
           }
         } catch {
           case _: NotLeaderOrFollowerException =>  // Case A or Case B
-            debug(s"Broker is no longer the leader or follower of $topicIdPartition, satisfy $fetchMetadata immediately")
+            debug(s"Broker is no longer the leader or follower of $topicIdPartition, satisfy $this immediately")
             return forceComplete()
           case _: UnknownTopicOrPartitionException => // Case C
-            debug(s"Broker no longer knows of partition $topicIdPartition, satisfy $fetchMetadata immediately")
+            debug(s"Broker no longer knows of partition $topicIdPartition, satisfy $this immediately")
             return forceComplete()
           case _: KafkaStorageException => // Case D
-            debug(s"Partition $topicIdPartition is in an offline log directory, satisfy $fetchMetadata immediately")
+            debug(s"Partition $topicIdPartition is in an offline log directory, satisfy $this immediately")
             return forceComplete()
           case _: FencedLeaderEpochException => // Case E
             debug(s"Broker is the leader of partition $topicIdPartition, but the requested epoch " +
-              s"$fetchLeaderEpoch is fenced by the latest leader epoch, satisfy $fetchMetadata immediately")
+              s"$fetchLeaderEpoch is fenced by the latest leader epoch, satisfy $this immediately")
             return forceComplete()
         }
     }
 
     // Case G
-    if (accumulatedSize >= fetchMetadata.fetchMinBytes)
+    if (accumulatedSize >= params.minBytes)
        forceComplete()
     else
       false
   }
 
   override def onExpiration(): Unit = {
-    if (fetchMetadata.isFromFollower)
+    if (params.isFromFollower)
       DelayedFetchMetrics.followerExpiredRequestMeter.mark()
     else
       DelayedFetchMetrics.consumerExpiredRequestMeter.mark()
@@ -173,18 +159,18 @@ class DelayedFetch(delayMs: Long,
    */
   override def onComplete(): Unit = {
     val logReadResults = replicaManager.readFromLocalLog(
-      replicaId = fetchMetadata.replicaId,
-      fetchOnlyFromLeader = fetchMetadata.fetchOnlyLeader,
-      fetchIsolation = fetchMetadata.fetchIsolation,
-      fetchMaxBytes = fetchMetadata.fetchMaxBytes,
-      hardMaxBytesLimit = fetchMetadata.hardMaxBytesLimit,
-      readPartitionInfo = fetchMetadata.fetchPartitionStatus.map { case (tp, status) => tp -> status.fetchInfo },
-      clientMetadata = clientMetadata,
+      replicaId = params.replicaId,
+      fetchOnlyFromLeader = params.fetchOnlyLeader,
+      fetchIsolation = params.isolation,
+      fetchMaxBytes = params.maxBytes,
+      hardMaxBytesLimit = params.hardMaxBytesLimit,
+      readPartitionInfo = fetchPartitionStatus.map { case (tp, status) => tp -> status.fetchInfo },
+      clientMetadata = params.clientMetadata,
       quota = quota)
 
     val fetchPartitionData = logReadResults.map { case (tp, result) =>
-      val isReassignmentFetch = fetchMetadata.isFromFollower &&
-        replicaManager.isAddingReplica(tp.topicPartition, fetchMetadata.replicaId)
+      val isReassignmentFetch = params.isFromFollower &&
+        replicaManager.isAddingReplica(tp.topicPartition, params.replicaId)
 
       tp -> result.toFetchPartitionData(isReassignmentFetch)
     }
diff --git a/core/src/main/scala/kafka/server/FetchDataInfo.scala b/core/src/main/scala/kafka/server/FetchDataInfo.scala
index f6cf725843..82e8092c10 100644
--- a/core/src/main/scala/kafka/server/FetchDataInfo.scala
+++ b/core/src/main/scala/kafka/server/FetchDataInfo.scala
@@ -17,15 +17,67 @@
 
 package kafka.server
 
+import kafka.api.Request
+import org.apache.kafka.common.IsolationLevel
 import org.apache.kafka.common.message.FetchResponseData
 import org.apache.kafka.common.record.Records
+import org.apache.kafka.common.replica.ClientMetadata
+import org.apache.kafka.common.requests.FetchRequest
 
 sealed trait FetchIsolation
 case object FetchLogEnd extends FetchIsolation
 case object FetchHighWatermark extends FetchIsolation
 case object FetchTxnCommitted extends FetchIsolation
 
-case class FetchDataInfo(fetchOffsetMetadata: LogOffsetMetadata,
-                         records: Records,
-                         firstEntryIncomplete: Boolean = false,
-                         abortedTransactions: Option[List[FetchResponseData.AbortedTransaction]] = None)
+object FetchIsolation {
+  def apply(
+    request: FetchRequest
+  ): FetchIsolation = {
+    apply(request.replicaId, request.isolationLevel)
+  }
+
+  def apply(
+    replicaId: Int,
+    isolationLevel: IsolationLevel
+  ): FetchIsolation = {
+    if (!Request.isConsumer(replicaId))
+      FetchLogEnd
+    else if (isolationLevel == IsolationLevel.READ_COMMITTED)
+      FetchTxnCommitted
+    else
+      FetchHighWatermark
+  }
+}
+
+case class FetchParams(
+  requestVersion: Short,
+  replicaId: Int,
+  maxWaitMs: Long,
+  minBytes: Int,
+  maxBytes: Int,
+  isolation: FetchIsolation,
+  clientMetadata: Option[ClientMetadata]
+) {
+  def isFromFollower: Boolean = Request.isValidBrokerId(replicaId)
+  def isFromConsumer: Boolean = Request.isConsumer(replicaId)
+  def fetchOnlyLeader: Boolean = isFromFollower || (isFromConsumer && clientMetadata.isEmpty)
+  def hardMaxBytesLimit: Boolean = requestVersion <= 2
+
+  override def toString: String = {
+    s"FetchParams(requestVersion=$requestVersion" +
+      s", replicaId=$replicaId" +
+      s", maxWaitMs=$maxWaitMs" +
+      s", minBytes=$minBytes" +
+      s", maxBytes=$maxBytes" +
+      s", isolation=$isolation" +
+      s", clientMetadata= $clientMetadata" +
+      ")"
+  }
+}
+
+case class FetchDataInfo(
+  fetchOffsetMetadata: LogOffsetMetadata,
+  records: Records,
+  firstEntryIncomplete: Boolean = false,
+  abortedTransactions: Option[List[FetchResponseData.AbortedTransaction]] = None
+)
diff --git a/core/src/main/scala/kafka/server/KafkaApis.scala b/core/src/main/scala/kafka/server/KafkaApis.scala
index 84c14e069c..dd3fb2dfea 100644
--- a/core/src/main/scala/kafka/server/KafkaApis.scala
+++ b/core/src/main/scala/kafka/server/KafkaApis.scala
@@ -695,18 +695,6 @@ class KafkaApis(val requestChannel: RequestChannel,
       forgottenTopics,
       topicNames)
 
-    val clientMetadata: Option[ClientMetadata] = if (versionId >= 11) {
-      // Fetch API version 11 added preferred replica logic
-      Some(new DefaultClientMetadata(
-        fetchRequest.rackId,
-        clientId,
-        request.context.clientAddress,
-        request.context.principal,
-        request.context.listenerName.value))
-    } else {
-      None
-    }
-
     val erroneous = mutable.ArrayBuffer[(TopicIdPartition, FetchResponseData.PartitionData)]()
     val interesting = mutable.ArrayBuffer[(TopicIdPartition, FetchRequest.PartitionData)]()
     if (fetchRequest.isFromFollower) {
@@ -943,31 +931,49 @@ class KafkaApis(val requestChannel: RequestChannel,
       }
     }
 
-    // for fetch from consumer, cap fetchMaxBytes to the maximum bytes that could be fetched without being throttled given
-    // no bytes were recorded in the recent quota window
-    // trying to fetch more bytes would result in a guaranteed throttling potentially blocking consumer progress
-    val maxQuotaWindowBytes = if (fetchRequest.isFromFollower)
-      Int.MaxValue
-    else
-      quotas.fetch.getMaxValueInQuotaWindow(request.session, clientId).toInt
-
-    val fetchMaxBytes = Math.min(Math.min(fetchRequest.maxBytes, config.fetchMaxBytes), maxQuotaWindowBytes)
-    val fetchMinBytes = Math.min(fetchRequest.minBytes, fetchMaxBytes)
-    if (interesting.isEmpty)
+    if (interesting.isEmpty) {
       processResponseCallback(Seq.empty)
-    else {
+    } else {
+      // for fetch from consumer, cap fetchMaxBytes to the maximum bytes that could be fetched without being throttled given
+      // no bytes were recorded in the recent quota window
+      // trying to fetch more bytes would result in a guaranteed throttling potentially blocking consumer progress
+      val maxQuotaWindowBytes = if (fetchRequest.isFromFollower)
+        Int.MaxValue
+      else
+        quotas.fetch.getMaxValueInQuotaWindow(request.session, clientId).toInt
+
+      val fetchMaxBytes = Math.min(Math.min(fetchRequest.maxBytes, config.fetchMaxBytes), maxQuotaWindowBytes)
+      val fetchMinBytes = Math.min(fetchRequest.minBytes, fetchMaxBytes)
+
+      val clientMetadata: Option[ClientMetadata] = if (versionId >= 11) {
+        // Fetch API version 11 added preferred replica logic
+        Some(new DefaultClientMetadata(
+          fetchRequest.rackId,
+          clientId,
+          request.context.clientAddress,
+          request.context.principal,
+          request.context.listenerName.value))
+      } else {
+        None
+      }
+
+      val params = FetchParams(
+        requestVersion = versionId,
+        replicaId = fetchRequest.replicaId,
+        maxWaitMs = fetchRequest.maxWait,
+        minBytes = fetchMinBytes,
+        maxBytes = fetchMaxBytes,
+        isolation = FetchIsolation(fetchRequest),
+        clientMetadata = clientMetadata
+      )
+
       // call the replica manager to fetch messages from the local replica
       replicaManager.fetchMessages(
-        fetchRequest.maxWait.toLong,
-        fetchRequest.replicaId,
-        fetchMinBytes,
-        fetchMaxBytes,
-        versionId <= 2,
-        interesting,
-        replicationQuota(fetchRequest),
-        processResponseCallback,
-        fetchRequest.isolationLevel,
-        clientMetadata)
+        params = params,
+        fetchInfos = interesting,
+        quota = replicationQuota(fetchRequest),
+        responseCallback = processResponseCallback,
+      )
     }
   }
 
diff --git a/core/src/main/scala/kafka/server/ReplicaAlterLogDirsThread.scala b/core/src/main/scala/kafka/server/ReplicaAlterLogDirsThread.scala
index 2ce33c838a..4a6a6e070c 100644
--- a/core/src/main/scala/kafka/server/ReplicaAlterLogDirsThread.scala
+++ b/core/src/main/scala/kafka/server/ReplicaAlterLogDirsThread.scala
@@ -100,17 +100,22 @@ class ReplicaAlterLogDirsThread(name: String,
 
     val fetchData = request.fetchData(topicNames.asJava)
 
+    val fetchParams = FetchParams(
+      requestVersion = request.version,
+      maxWaitMs = 0L, // timeout is 0 so that the callback will be executed immediately
+      replicaId = Request.FutureLocalReplicaId,
+      minBytes = request.minBytes,
+      maxBytes = request.maxBytes,
+      isolation = FetchLogEnd,
+      clientMetadata = None
+    )
+
     replicaMgr.fetchMessages(
-      0L, // timeout is 0 so that the callback will be executed immediately
-      Request.FutureLocalReplicaId,
-      request.minBytes,
-      request.maxBytes,
-      false,
-      fetchData.asScala.toSeq,
-      UnboundedQuota,
-      processResponseCallback,
-      request.isolationLevel,
-      None)
+      params = fetchParams,
+      fetchInfos = fetchData.asScala.toSeq,
+      quota = UnboundedQuota,
+      responseCallback = processResponseCallback
+    )
 
     if (partitionData == null)
       throw new IllegalStateException(s"Failed to fetch data for partitions ${fetchData.keySet().toArray.mkString(",")}")
diff --git a/core/src/main/scala/kafka/server/ReplicaManager.scala b/core/src/main/scala/kafka/server/ReplicaManager.scala
index 1c97671c26..e84abbe5f4 100644
--- a/core/src/main/scala/kafka/server/ReplicaManager.scala
+++ b/core/src/main/scala/kafka/server/ReplicaManager.scala
@@ -29,7 +29,6 @@ import kafka.common.RecordValidationException
 import kafka.controller.{KafkaController, StateChangeLogger}
 import kafka.log._
 import kafka.metrics.KafkaMetricsGroup
-import kafka.server.{FetchMetadata => SFetchMetadata}
 import kafka.server.HostedPartition.Online
 import kafka.server.QuotaFactory.QuotaManagers
 import kafka.server.checkpoints.{LazyOffsetCheckpoints, OffsetCheckpointFile, OffsetCheckpoints}
@@ -989,38 +988,24 @@ class ReplicaManager(val config: KafkaConfig,
    * the callback function will be triggered either when timeout or required fetch info is satisfied.
    * Consumers may fetch from any replica, but followers can only fetch from the leader.
    */
-  def fetchMessages(timeout: Long,
-                    replicaId: Int,
-                    fetchMinBytes: Int,
-                    fetchMaxBytes: Int,
-                    hardMaxBytesLimit: Boolean,
-                    fetchInfos: Seq[(TopicIdPartition, PartitionData)],
-                    quota: ReplicaQuota,
-                    responseCallback: Seq[(TopicIdPartition, FetchPartitionData)] => Unit,
-                    isolationLevel: IsolationLevel,
-                    clientMetadata: Option[ClientMetadata]): Unit = {
-    val isFromFollower = Request.isValidBrokerId(replicaId)
-    val isFromConsumer = !(isFromFollower || replicaId == Request.FutureLocalReplicaId)
-    val fetchIsolation = if (!isFromConsumer)
-      FetchLogEnd
-    else if (isolationLevel == IsolationLevel.READ_COMMITTED)
-      FetchTxnCommitted
-    else
-      FetchHighWatermark
-
+  def fetchMessages(
+    params: FetchParams,
+    fetchInfos: Seq[(TopicIdPartition, PartitionData)],
+    quota: ReplicaQuota,
+    responseCallback: Seq[(TopicIdPartition, FetchPartitionData)] => Unit
+  ): Unit = {
     // Restrict fetching to leader if request is from follower or from a client with older version (no ClientMetadata)
-    val fetchOnlyFromLeader = isFromFollower || (isFromConsumer && clientMetadata.isEmpty)
     def readFromLog(): Seq[(TopicIdPartition, LogReadResult)] = {
       val result = readFromLocalLog(
-        replicaId = replicaId,
-        fetchOnlyFromLeader = fetchOnlyFromLeader,
-        fetchIsolation = fetchIsolation,
-        fetchMaxBytes = fetchMaxBytes,
-        hardMaxBytesLimit = hardMaxBytesLimit,
+        replicaId = params.replicaId,
+        fetchOnlyFromLeader = params.fetchOnlyLeader,
+        fetchIsolation = params.isolation,
+        fetchMaxBytes = params.maxBytes,
+        hardMaxBytesLimit = params.hardMaxBytesLimit,
         readPartitionInfo = fetchInfos,
         quota = quota,
-        clientMetadata = clientMetadata)
-      if (isFromFollower) updateFollowerFetchState(replicaId, result)
+        clientMetadata = params.clientMetadata)
+      if (params.isFromFollower) updateFollowerFetchState(params.replicaId, result)
       else result
     }
 
@@ -1051,10 +1036,10 @@ class ReplicaManager(val config: KafkaConfig,
     //                        4) some error happens while reading data
     //                        5) we found a diverging epoch
     //                        6) has a preferred read replica
-    if (timeout <= 0 || fetchInfos.isEmpty || bytesReadable >= fetchMinBytes || errorReadingData ||
+    if (params.maxWaitMs <= 0 || fetchInfos.isEmpty || bytesReadable >= params.minBytes || errorReadingData ||
       hasDivergingEpoch || hasPreferredReadReplica) {
       val fetchPartitionData = logReadResults.map { case (tp, result) =>
-        val isReassignmentFetch = isFromFollower && isAddingReplica(tp.topicPartition, replicaId)
+        val isReassignmentFetch = params.isFromFollower && isAddingReplica(tp.topicPartition, params.replicaId)
         tp -> result.toFetchPartitionData(isReassignmentFetch)
       }
       responseCallback(fetchPartitionData)
@@ -1067,10 +1052,13 @@ class ReplicaManager(val config: KafkaConfig,
           fetchPartitionStatus += (topicIdPartition -> FetchPartitionStatus(logOffsetMetadata, partitionData))
         })
       }
-      val fetchMetadata: SFetchMetadata = SFetchMetadata(fetchMinBytes, fetchMaxBytes, hardMaxBytesLimit,
-        fetchOnlyFromLeader, fetchIsolation, isFromFollower, replicaId, fetchPartitionStatus)
-      val delayedFetch = new DelayedFetch(timeout, fetchMetadata, this, quota, clientMetadata,
-        responseCallback)
+      val delayedFetch = new DelayedFetch(
+        params = params,
+        fetchPartitionStatus = fetchPartitionStatus,
+        replicaManager = this,
+        quota = quota,
+        responseCallback = responseCallback
+      )
 
       // create a list of (topic, partition) pairs to use as keys for this delayed fetch operation
       val delayedFetchKeys = fetchPartitionStatus.map { case (tp, _) => TopicPartitionOperationKey(tp) }
diff --git a/core/src/test/scala/integration/kafka/server/DelayedFetchTest.scala b/core/src/test/scala/integration/kafka/server/DelayedFetchTest.scala
index 581af29bec..940968f411 100644
--- a/core/src/test/scala/integration/kafka/server/DelayedFetchTest.scala
+++ b/core/src/test/scala/integration/kafka/server/DelayedFetchTest.scala
@@ -17,13 +17,14 @@
 package kafka.server
 
 import java.util.Optional
+
 import scala.collection.Seq
 import kafka.cluster.Partition
 import kafka.log.LogOffsetSnapshot
 import org.apache.kafka.common.{TopicIdPartition, Uuid}
 import org.apache.kafka.common.errors.{FencedLeaderEpochException, NotLeaderOrFollowerException}
 import org.apache.kafka.common.message.OffsetForLeaderEpochResponseData.EpochEndOffset
-import org.apache.kafka.common.protocol.Errors
+import org.apache.kafka.common.protocol.{ApiKeys, Errors}
 import org.apache.kafka.common.record.MemoryRecords
 import org.apache.kafka.common.requests.FetchRequest
 import org.junit.jupiter.api.Test
@@ -47,7 +48,7 @@ class DelayedFetchTest {
     val fetchStatus = FetchPartitionStatus(
       startOffsetMetadata = LogOffsetMetadata(fetchOffset),
       fetchInfo = new FetchRequest.PartitionData(Uuid.ZERO_UUID, fetchOffset, logStartOffset, maxBytes, currentLeaderEpoch))
-    val fetchMetadata = buildFetchMetadata(replicaId, topicIdPartition, fetchStatus)
+    val fetchParams = buildFollowerFetchParams(replicaId, maxWaitMs = 500)
 
     var fetchResultOpt: Option[FetchPartitionData] = None
     def callback(responses: Seq[(TopicIdPartition, FetchPartitionData)]): Unit = {
@@ -55,12 +56,12 @@ class DelayedFetchTest {
     }
 
     val delayedFetch = new DelayedFetch(
-      delayMs = 500,
-      fetchMetadata = fetchMetadata,
+      params = fetchParams,
+      fetchPartitionStatus = Seq(topicIdPartition -> fetchStatus),
       replicaManager = replicaManager,
       quota = replicaQuota,
-      clientMetadata = None,
-      responseCallback = callback)
+      responseCallback = callback
+    )
 
     val partition: Partition = mock(classOf[Partition])
 
@@ -93,7 +94,7 @@ class DelayedFetchTest {
     val fetchStatus = FetchPartitionStatus(
       startOffsetMetadata = LogOffsetMetadata(fetchOffset),
       fetchInfo = new FetchRequest.PartitionData(Uuid.ZERO_UUID, fetchOffset, logStartOffset, maxBytes, currentLeaderEpoch))
-    val fetchMetadata = buildFetchMetadata(replicaId, topicIdPartition, fetchStatus)
+    val fetchParams = buildFollowerFetchParams(replicaId, maxWaitMs = 500)
 
     var fetchResultOpt: Option[FetchPartitionData] = None
     def callback(responses: Seq[(TopicIdPartition, FetchPartitionData)]): Unit = {
@@ -101,12 +102,12 @@ class DelayedFetchTest {
     }
 
     val delayedFetch = new DelayedFetch(
-      delayMs = 500,
-      fetchMetadata = fetchMetadata,
+      params = fetchParams,
+      fetchPartitionStatus = Seq(topicIdPartition -> fetchStatus),
       replicaManager = replicaManager,
       quota = replicaQuota,
-      clientMetadata = None,
-      responseCallback = callback)
+      responseCallback = callback
+    )
 
     when(replicaManager.getPartitionOrException(topicIdPartition.topicPartition))
       .thenThrow(new NotLeaderOrFollowerException(s"Replica for $topicIdPartition not available"))
@@ -130,7 +131,7 @@ class DelayedFetchTest {
     val fetchStatus = FetchPartitionStatus(
       startOffsetMetadata = LogOffsetMetadata(fetchOffset),
       fetchInfo = new FetchRequest.PartitionData(topicIdPartition.topicId, fetchOffset, logStartOffset, maxBytes, currentLeaderEpoch, lastFetchedEpoch))
-    val fetchMetadata = buildFetchMetadata(replicaId, topicIdPartition, fetchStatus)
+    val fetchParams = buildFollowerFetchParams(replicaId, maxWaitMs = 500)
 
     var fetchResultOpt: Option[FetchPartitionData] = None
     def callback(responses: Seq[(TopicIdPartition, FetchPartitionData)]): Unit = {
@@ -138,12 +139,12 @@ class DelayedFetchTest {
     }
 
     val delayedFetch = new DelayedFetch(
-      delayMs = 500,
-      fetchMetadata = fetchMetadata,
+      params = fetchParams,
+      fetchPartitionStatus = Seq(topicIdPartition -> fetchStatus),
       replicaManager = replicaManager,
       quota = replicaQuota,
-      clientMetadata = None,
-      responseCallback = callback)
+      responseCallback = callback
+    )
 
     val partition: Partition = mock(classOf[Partition])
     when(replicaManager.getPartitionOrException(topicIdPartition.topicPartition)).thenReturn(partition)
@@ -166,17 +167,19 @@ class DelayedFetchTest {
     assertTrue(fetchResultOpt.isDefined)
   }
 
-  private def buildFetchMetadata(replicaId: Int,
-                                 topicIdPartition: TopicIdPartition,
-                                 fetchStatus: FetchPartitionStatus): FetchMetadata = {
-    FetchMetadata(fetchMinBytes = 1,
-      fetchMaxBytes = maxBytes,
-      hardMaxBytesLimit = false,
-      fetchOnlyLeader = true,
-      fetchIsolation = FetchLogEnd,
-      isFromFollower = true,
+  private def buildFollowerFetchParams(
+    replicaId: Int,
+    maxWaitMs: Int
+  ): FetchParams = {
+    FetchParams(
+      requestVersion = ApiKeys.FETCH.latestVersion,
       replicaId = replicaId,
-      fetchPartitionStatus = Seq((topicIdPartition, fetchStatus)))
+      maxWaitMs = maxWaitMs,
+      minBytes = 1,
+      maxBytes = maxBytes,
+      isolation = FetchLogEnd,
+      clientMetadata = None
+    )
   }
 
   private def expectReadFromReplica(replicaId: Int,
diff --git a/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala b/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala
index 972fc21795..de8bfabc25 100644
--- a/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala
+++ b/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala
@@ -71,7 +71,6 @@ import org.apache.kafka.common.protocol.{ApiKeys, Errors}
 import org.apache.kafka.common.quota.{ClientQuotaAlteration, ClientQuotaEntity}
 import org.apache.kafka.common.record.FileRecords.TimestampAndOffset
 import org.apache.kafka.common.record._
-import org.apache.kafka.common.replica.ClientMetadata
 import org.apache.kafka.common.requests.FindCoordinatorRequest.CoordinatorType
 import org.apache.kafka.common.requests.MetadataResponse.TopicMetadata
 import org.apache.kafka.common.requests.ProduceResponse.PartitionResponse
@@ -2358,12 +2357,13 @@ class KafkaApisTest {
 
     when(replicaManager.getLogConfig(ArgumentMatchers.eq(tp))).thenReturn(None)
 
-    when(replicaManager.fetchMessages(anyLong, anyInt, anyInt, anyInt, anyBoolean,
-      any[Seq[(TopicIdPartition, FetchRequest.PartitionData)]], any[ReplicaQuota],
-      any[Seq[(TopicIdPartition, FetchPartitionData)] => Unit](), any[IsolationLevel],
-      any[Option[ClientMetadata]])
-    ).thenAnswer(invocation => {
-      val callback = invocation.getArgument(7).asInstanceOf[Seq[(TopicIdPartition, FetchPartitionData)] => Unit]
+    when(replicaManager.fetchMessages(
+      any[FetchParams],
+      any[Seq[(TopicIdPartition, FetchRequest.PartitionData)]],
+      any[ReplicaQuota],
+      any[Seq[(TopicIdPartition, FetchPartitionData)] => Unit]()
+    )).thenAnswer(invocation => {
+      val callback = invocation.getArgument(3).asInstanceOf[Seq[(TopicIdPartition, FetchPartitionData)] => Unit]
       val records = MemoryRecords.withRecords(CompressionType.NONE,
         new SimpleRecord(timestamp, "foo".getBytes(StandardCharsets.UTF_8)))
       callback(Seq(tidp -> FetchPartitionData(Errors.NONE, hw, 0, records,
@@ -2946,12 +2946,13 @@ class KafkaApisTest {
 
     val records = MemoryRecords.withRecords(CompressionType.NONE,
       new SimpleRecord(1000, "foo".getBytes(StandardCharsets.UTF_8)))
-    when(replicaManager.fetchMessages(anyLong, anyInt, anyInt, anyInt, anyBoolean,
-      any[Seq[(TopicIdPartition, FetchRequest.PartitionData)]], any[ReplicaQuota],
-      any[Seq[(TopicIdPartition, FetchPartitionData)] => Unit](), any[IsolationLevel],
-      any[Option[ClientMetadata]])
-    ).thenAnswer(invocation => {
-      val callback = invocation.getArgument(7).asInstanceOf[Seq[(TopicIdPartition, FetchPartitionData)] => Unit]
+    when(replicaManager.fetchMessages(
+      any[FetchParams],
+      any[Seq[(TopicIdPartition, FetchRequest.PartitionData)]],
+      any[ReplicaQuota],
+      any[Seq[(TopicIdPartition, FetchPartitionData)] => Unit]()
+    )).thenAnswer(invocation => {
+      val callback = invocation.getArgument(3).asInstanceOf[Seq[(TopicIdPartition, FetchPartitionData)] => Unit]
       callback(Seq(tidp0 -> FetchPartitionData(Errors.NONE, hw, 0, records,
         None, None, None, Option.empty, isReassignmentFetch = isReassigning)))
     })
@@ -2978,7 +2979,6 @@ class KafkaApisTest {
     else
       assertEquals(0, brokerTopicStats.allTopicsStats.reassignmentBytesOutPerSec.get.count())
     assertEquals(records.sizeInBytes(), brokerTopicStats.allTopicsStats.replicationBytesOutRate.get.count())
-
   }
 
   @Test
diff --git a/core/src/test/scala/unit/kafka/server/ReplicaAlterLogDirsThreadTest.scala b/core/src/test/scala/unit/kafka/server/ReplicaAlterLogDirsThreadTest.scala
index bf2671a16e..2c93a2c2af 100644
--- a/core/src/test/scala/unit/kafka/server/ReplicaAlterLogDirsThreadTest.scala
+++ b/core/src/test/scala/unit/kafka/server/ReplicaAlterLogDirsThreadTest.scala
@@ -31,10 +31,10 @@ import org.apache.kafka.common.message.UpdateMetadataRequestData
 import org.apache.kafka.common.protocol.{ApiKeys, Errors}
 import org.apache.kafka.common.record.MemoryRecords
 import org.apache.kafka.common.requests.{FetchRequest, UpdateMetadataRequest}
-import org.apache.kafka.common.{IsolationLevel, TopicIdPartition, TopicPartition, Uuid}
+import org.apache.kafka.common.{TopicIdPartition, TopicPartition, Uuid}
 import org.junit.jupiter.api.Assertions._
 import org.junit.jupiter.api.Test
-import org.mockito.ArgumentMatchers.{any, anyBoolean, anyInt, anyLong}
+import org.mockito.ArgumentMatchers.{any, anyBoolean}
 import org.mockito.Mockito.{doNothing, mock, never, times, verify, when}
 import org.mockito.{ArgumentCaptor, ArgumentMatchers, Mockito}
 
@@ -271,18 +271,26 @@ class ReplicaAlterLogDirsThreadTest {
                                       responseData: FetchPartitionData): Unit = {
     val callbackCaptor: ArgumentCaptor[Seq[(TopicIdPartition, FetchPartitionData)] => Unit] =
       ArgumentCaptor.forClass(classOf[Seq[(TopicIdPartition, FetchPartitionData)] => Unit])
+
+    val expectedFetchParams = FetchParams(
+      requestVersion = ApiKeys.FETCH.latestVersion,
+      replicaId = Request.FutureLocalReplicaId,
+      maxWaitMs = 0L,
+      minBytes = 0,
+      maxBytes = config.replicaFetchResponseMaxBytes,
+      isolation = FetchLogEnd,
+      clientMetadata = None
+    )
+
+    println(expectedFetchParams)
+
     when(replicaManager.fetchMessages(
-      timeout = ArgumentMatchers.eq(0L),
-      replicaId = ArgumentMatchers.eq(Request.FutureLocalReplicaId),
-      fetchMinBytes = ArgumentMatchers.eq(0),
-      fetchMaxBytes = ArgumentMatchers.eq(config.replicaFetchResponseMaxBytes),
-      hardMaxBytesLimit = ArgumentMatchers.eq(false),
+      params = ArgumentMatchers.eq(expectedFetchParams),
       fetchInfos = ArgumentMatchers.eq(Seq(topicIdPartition -> requestData)),
       quota = ArgumentMatchers.eq(UnboundedQuota),
       responseCallback = callbackCaptor.capture(),
-      isolationLevel = ArgumentMatchers.eq(IsolationLevel.READ_UNCOMMITTED),
-      clientMetadata = ArgumentMatchers.eq(None)
     )).thenAnswer(_ => {
+      println("Did we get the callback?")
       callbackCaptor.getValue.apply(Seq((topicIdPartition, responseData)))
     })
   }
@@ -701,16 +709,10 @@ class ReplicaAlterLogDirsThreadTest {
 
     when(replicaManager.logManager).thenReturn(logManager)
     when(replicaManager.fetchMessages(
-      anyLong(),
-      anyInt(),
-      anyInt(),
-      anyInt(),
-      any(),
-      any(),
-      any(),
+      any[FetchParams],
+      any[Seq[(TopicIdPartition, FetchRequest.PartitionData)]],
+      any[ReplicaQuota],
       responseCallback.capture(),
-      any(),
-      any(),
     )).thenAnswer(_ => responseCallback.getValue.apply(Seq.empty[(TopicIdPartition, FetchPartitionData)]))
 
     //Create the thread
@@ -939,16 +941,10 @@ class ReplicaAlterLogDirsThreadTest {
                             responseCallback: ArgumentCaptor[Seq[(TopicIdPartition, FetchPartitionData)] => Unit]): Unit = {
     stub(logT1p0, logT1p1, futureLog, partition, replicaManager)
     when(replicaManager.fetchMessages(
-      anyLong(),
-      anyInt(),
-      anyInt(),
-      anyInt(),
-      any(),
-      any(),
-      any(),
-      responseCallback.capture(),
-      any(),
-      any())
-    ).thenAnswer(_ => responseCallback.getValue.apply(Seq.empty[(TopicIdPartition, FetchPartitionData)]))
+      any[FetchParams],
+      any[Seq[(TopicIdPartition, FetchRequest.PartitionData)]],
+      any[ReplicaQuota],
+      responseCallback.capture()
+    )).thenAnswer(_ => responseCallback.getValue.apply(Seq.empty[(TopicIdPartition, FetchPartitionData)]))
   }
 }
diff --git a/core/src/test/scala/unit/kafka/server/ReplicaManagerConcurrencyTest.scala b/core/src/test/scala/unit/kafka/server/ReplicaManagerConcurrencyTest.scala
index d281788ab8..df95f701c5 100644
--- a/core/src/test/scala/unit/kafka/server/ReplicaManagerConcurrencyTest.scala
+++ b/core/src/test/scala/unit/kafka/server/ReplicaManagerConcurrencyTest.scala
@@ -28,7 +28,7 @@ import kafka.utils.TestUtils.waitUntilTrue
 import kafka.utils.{MockTime, ShutdownableThread, TestUtils}
 import org.apache.kafka.common.metadata.{PartitionChangeRecord, PartitionRecord, TopicRecord}
 import org.apache.kafka.common.metrics.Metrics
-import org.apache.kafka.common.protocol.Errors
+import org.apache.kafka.common.protocol.{ApiKeys, Errors}
 import org.apache.kafka.common.record.SimpleRecord
 import org.apache.kafka.common.replica.ClientMetadata.DefaultClientMetadata
 import org.apache.kafka.common.requests.{FetchRequest, ProduceResponse}
@@ -224,17 +224,21 @@ class ReplicaManagerConcurrencyTest {
         }
       }
 
-      replicaManager.fetchMessages(
-        timeout = random.nextInt(100),
+      val fetchParams = FetchParams(
+        requestVersion = ApiKeys.FETCH.latestVersion,
         replicaId = replicaId,
-        fetchMinBytes = 1,
-        fetchMaxBytes = 1024 * 1024,
-        hardMaxBytesLimit = false,
+        maxWaitMs = random.nextInt(100),
+        minBytes = 1,
+        maxBytes = 1024 * 1024,
+        isolation = FetchIsolation(replicaId, IsolationLevel.READ_UNCOMMITTED),
+        clientMetadata = Some(clientMetadata)
+      )
+
+      replicaManager.fetchMessages(
+        params = fetchParams,
         fetchInfos = Seq(topicIdPartition -> partitionData),
         quota = QuotaFactory.UnboundedQuota,
         responseCallback = fetchCallback,
-        isolationLevel = IsolationLevel.READ_UNCOMMITTED,
-        clientMetadata = Some(clientMetadata)
       )
 
       val fetchResult = future.get()
diff --git a/core/src/test/scala/unit/kafka/server/ReplicaManagerQuotasTest.scala b/core/src/test/scala/unit/kafka/server/ReplicaManagerQuotasTest.scala
index d7b69bc61d..ea03b8d2fe 100644
--- a/core/src/test/scala/unit/kafka/server/ReplicaManagerQuotasTest.scala
+++ b/core/src/test/scala/unit/kafka/server/ReplicaManagerQuotasTest.scala
@@ -18,11 +18,13 @@ package kafka.server
 
 import java.io.File
 import java.util.{Collections, Optional, Properties}
+
 import kafka.cluster.Partition
 import kafka.log.{LogManager, LogOffsetSnapshot, UnifiedLog}
 import kafka.server.QuotaFactory.QuotaManagers
 import kafka.utils._
 import org.apache.kafka.common.metrics.Metrics
+import org.apache.kafka.common.protocol.ApiKeys
 import org.apache.kafka.common.record.{CompressionType, MemoryRecords, SimpleRecord}
 import org.apache.kafka.common.requests.FetchRequest
 import org.apache.kafka.common.requests.FetchRequest.PartitionData
@@ -205,17 +207,23 @@ class ReplicaManagerQuotasTest {
       val tp = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("t1", 0))
       val fetchPartitionStatus = FetchPartitionStatus(LogOffsetMetadata(messageOffset = 50L, segmentBaseOffset = 0L,
          relativePositionInSegment = 250), new PartitionData(Uuid.ZERO_UUID, 50, 0, 1, Optional.empty()))
-      val fetchMetadata = FetchMetadata(fetchMinBytes = 1,
-        fetchMaxBytes = 1000,
-        hardMaxBytesLimit = true,
-        fetchOnlyLeader = true,
-        fetchIsolation = FetchLogEnd,
-        isFromFollower = true,
+      val fetchParams = FetchParams(
+        requestVersion = ApiKeys.FETCH.latestVersion,
         replicaId = 1,
-        fetchPartitionStatus = List((tp, fetchPartitionStatus))
+        maxWaitMs = 600,
+        minBytes = 1,
+        maxBytes = 1000,
+        isolation = FetchLogEnd,
+        clientMetadata = None
       )
-      new DelayedFetch(delayMs = 600, fetchMetadata = fetchMetadata, replicaManager = replicaManager,
-        quota = null, clientMetadata = None, responseCallback = null) {
+
+      new DelayedFetch(
+        params = fetchParams,
+        fetchPartitionStatus = Seq(tp -> fetchPartitionStatus),
+        replicaManager = replicaManager,
+        quota = null,
+        responseCallback = null
+      ) {
         override def forceComplete(): Boolean = true
       }
     }
@@ -249,17 +257,23 @@ class ReplicaManagerQuotasTest {
       val tidp = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("t1", 0))
       val fetchPartitionStatus = FetchPartitionStatus(LogOffsetMetadata(messageOffset = 50L, segmentBaseOffset = 0L,
         relativePositionInSegment = 250), new PartitionData(Uuid.ZERO_UUID, 50, 0, 1, Optional.empty()))
-      val fetchMetadata = FetchMetadata(fetchMinBytes = 1,
-        fetchMaxBytes = 1000,
-        hardMaxBytesLimit = true,
-        fetchOnlyLeader = true,
-        fetchIsolation = FetchLogEnd,
-        isFromFollower = false,
+      val fetchParams = FetchParams(
+        requestVersion = ApiKeys.FETCH.latestVersion,
         replicaId = FetchRequest.CONSUMER_REPLICA_ID,
-        fetchPartitionStatus = List((tidp, fetchPartitionStatus))
+        maxWaitMs = 600,
+        minBytes = 1,
+        maxBytes = 1000,
+        isolation = FetchHighWatermark,
+        clientMetadata = None
       )
-      new DelayedFetch(delayMs = 600, fetchMetadata = fetchMetadata, replicaManager = replicaManager,
-        quota = null, clientMetadata = None, responseCallback = null) {
+
+      new DelayedFetch(
+        params = fetchParams,
+        fetchPartitionStatus = Seq(tidp -> fetchPartitionStatus),
+        replicaManager = replicaManager,
+        quota = null,
+        responseCallback = null
+      ) {
         override def forceComplete(): Boolean = true
       }
     }
diff --git a/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala b/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala
index ac2fc9926d..aa28ce7269 100644
--- a/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala
+++ b/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala
@@ -269,7 +269,7 @@ class ReplicaManagerTest {
         Set(new Node(0, "host1", 0), new Node(1, "host2", 1)).asJava).build()
       rm.becomeLeaderOrFollower(1, leaderAndIsrRequest2, (_, _) => ())
 
-      assertTrue(appendResult.isFired)
+      assertTrue(appendResult.hasFired)
     } finally {
       rm.shutdown(checkpointHW = false)
     }
@@ -508,12 +508,15 @@ class ReplicaManagerTest {
       }
 
       // fetch as follower to advance the high watermark
-      fetchAsFollower(replicaManager, new TopicIdPartition(topicId, new TopicPartition(topic, 0)),
+      fetchPartitionAsFollower(
+        replicaManager,
+        new TopicIdPartition(topicId, new TopicPartition(topic, 0)),
         new PartitionData(Uuid.ZERO_UUID, numRecords, 0, 100000, Optional.empty()),
-        isolationLevel = IsolationLevel.READ_UNCOMMITTED)
+        replicaId = 1
+      )
 
       // fetch should return empty since LSO should be stuck at 0
-      var consumerFetchResult = fetchAsConsumer(replicaManager, new TopicIdPartition(topicId, new TopicPartition(topic, 0)),
+      var consumerFetchResult = fetchPartitionAsConsumer(replicaManager, new TopicIdPartition(topicId, new TopicPartition(topic, 0)),
         new PartitionData(Uuid.ZERO_UUID, 0, 0, 100000, Optional.empty()),
         isolationLevel = IsolationLevel.READ_COMMITTED)
       var fetchData = consumerFetchResult.assertFired
@@ -523,10 +526,15 @@ class ReplicaManagerTest {
       assertEquals(Some(List.empty[FetchResponseData.AbortedTransaction]), fetchData.abortedTransactions)
 
       // delayed fetch should timeout and return nothing
-      consumerFetchResult = fetchAsConsumer(replicaManager, new TopicIdPartition(topicId, new TopicPartition(topic, 0)),
+      consumerFetchResult = fetchPartitionAsConsumer(
+        replicaManager,
+        new TopicIdPartition(topicId, new TopicPartition(topic, 0)),
         new PartitionData(Uuid.ZERO_UUID, 0, 0, 100000, Optional.empty()),
-        isolationLevel = IsolationLevel.READ_COMMITTED, minBytes = 1000)
-      assertFalse(consumerFetchResult.isFired)
+        isolationLevel = IsolationLevel.READ_COMMITTED,
+        minBytes = 1000,
+        maxWaitMs = 1000
+      )
+      assertFalse(consumerFetchResult.hasFired)
       timer.advanceClock(1001)
 
       fetchData = consumerFetchResult.assertFired
@@ -544,21 +552,27 @@ class ReplicaManagerTest {
 
       // the LSO has advanced, but the appended commit marker has not been replicated, so
       // none of the data from the transaction should be visible yet
-      consumerFetchResult = fetchAsConsumer(replicaManager, new TopicIdPartition(topicId, new TopicPartition(topic, 0)),
+      consumerFetchResult = fetchPartitionAsConsumer(
+        replicaManager,
+        new TopicIdPartition(topicId, new TopicPartition(topic, 0)),
         new PartitionData(Uuid.ZERO_UUID, 0, 0, 100000, Optional.empty()),
-        isolationLevel = IsolationLevel.READ_COMMITTED)
+        isolationLevel = IsolationLevel.READ_COMMITTED
+      )
 
       fetchData = consumerFetchResult.assertFired
       assertEquals(Errors.NONE, fetchData.error)
       assertTrue(fetchData.records.batches.asScala.isEmpty)
 
       // fetch as follower to advance the high watermark
-      fetchAsFollower(replicaManager, new TopicIdPartition(topicId, new TopicPartition(topic, 0)),
+      fetchPartitionAsFollower(
+        replicaManager,
+        new TopicIdPartition(topicId, new TopicPartition(topic, 0)),
         new PartitionData(Uuid.ZERO_UUID, numRecords + 1, 0, 100000, Optional.empty()),
-        isolationLevel = IsolationLevel.READ_UNCOMMITTED)
+        replicaId = 1
+      )
 
       // now all of the records should be fetchable
-      consumerFetchResult = fetchAsConsumer(replicaManager, new TopicIdPartition(topicId, new TopicPartition(topic, 0)),
+      consumerFetchResult = fetchPartitionAsConsumer(replicaManager, new TopicIdPartition(topicId, new TopicPartition(topic, 0)),
         new PartitionData(Uuid.ZERO_UUID, 0, 0, 100000, Optional.empty()),
         isolationLevel = IsolationLevel.READ_COMMITTED)
 
@@ -622,16 +636,24 @@ class ReplicaManagerTest {
         .onFire { response => assertEquals(Errors.NONE, response.error) }
 
       // fetch as follower to advance the high watermark
-      fetchAsFollower(replicaManager, new TopicIdPartition(topicId, new TopicPartition(topic, 0)),
+      fetchPartitionAsFollower(
+        replicaManager,
+        new TopicIdPartition(topicId, new TopicPartition(topic, 0)),
         new PartitionData(Uuid.ZERO_UUID, numRecords + 1, 0, 100000, Optional.empty()),
-        isolationLevel = IsolationLevel.READ_UNCOMMITTED)
+        replicaId = 1
+      )
 
       // Set the minBytes in order force this request to enter purgatory. When it returns, we should still
       // see the newly aborted transaction.
-      val fetchResult = fetchAsConsumer(replicaManager, new TopicIdPartition(topicId, new TopicPartition(topic, 0)),
+      val fetchResult = fetchPartitionAsConsumer(
+        replicaManager,
+        new TopicIdPartition(topicId, new TopicPartition(topic, 0)),
         new PartitionData(Uuid.ZERO_UUID, 0, 0, 100000, Optional.empty()),
-        isolationLevel = IsolationLevel.READ_COMMITTED, minBytes = 10000)
-      assertFalse(fetchResult.isFired)
+        isolationLevel = IsolationLevel.READ_COMMITTED,
+        minBytes = 10000,
+        maxWaitMs = 1000
+      )
+      assertFalse(fetchResult.hasFired)
 
       timer.advanceClock(1001)
       val fetchData = fetchResult.assertFired
@@ -687,8 +709,12 @@ class ReplicaManagerTest {
       }
 
       // Followers are always allowed to fetch above the high watermark
-      val followerFetchResult = fetchAsFollower(rm, new TopicIdPartition(topicId, new TopicPartition(topic, 0)),
-        new PartitionData(Uuid.ZERO_UUID, 1, 0, 100000, Optional.empty()))
+      val followerFetchResult = fetchPartitionAsFollower(
+        rm,
+        new TopicIdPartition(topicId, new TopicPartition(topic, 0)),
+        new PartitionData(Uuid.ZERO_UUID, 1, 0, 100000, Optional.empty()),
+        replicaId = 1
+      )
       val followerFetchData = followerFetchResult.assertFired
       assertEquals(Errors.NONE, followerFetchData.error, "Should not give an exception")
       assertTrue(followerFetchData.records.batches.iterator.hasNext, "Should return some data")
@@ -696,7 +722,7 @@ class ReplicaManagerTest {
       // Consumers are not allowed to consume above the high watermark. However, since the
       // high watermark could be stale at the time of the request, we do not return an out of
       // range error and instead return an empty record set.
-      val consumerFetchResult = fetchAsConsumer(rm, new TopicIdPartition(topicId, new TopicPartition(topic, 0)),
+      val consumerFetchResult = fetchPartitionAsConsumer(rm, new TopicIdPartition(topicId, new TopicPartition(topic, 0)),
         new PartitionData(Uuid.ZERO_UUID, 1, 0, 100000, Optional.empty()))
       val consumerFetchData = consumerFetchResult.assertFired
       assertEquals(Errors.NONE, consumerFetchData.error, "Should not give an exception")
@@ -753,51 +779,34 @@ class ReplicaManagerTest {
       }
 
       // We receive one valid request from the follower and replica state is updated
-      var successfulFetch: Option[FetchPartitionData] = None
-      def callback(response: Seq[(TopicIdPartition, FetchPartitionData)]): Unit = {
-        successfulFetch = response.headOption.filter(_._1 == tidp).map(_._2)
-      }
-
       val validFetchPartitionData = new FetchRequest.PartitionData(Uuid.ZERO_UUID, 0L, 0L, maxFetchBytes,
         Optional.of(leaderEpoch))
 
-      replicaManager.fetchMessages(
-        timeout = 0L,
-        replicaId = 1,
-        fetchMinBytes = 1,
-        fetchMaxBytes = maxFetchBytes,
-        hardMaxBytesLimit = false,
-        fetchInfos = Seq(tidp -> validFetchPartitionData),
-        quota = UnboundedQuota,
-        isolationLevel = IsolationLevel.READ_UNCOMMITTED,
-        responseCallback = callback,
-        clientMetadata = None
+      val validFetchResult = fetchPartitionAsFollower(
+        replicaManager,
+        tidp,
+        validFetchPartitionData,
+        replicaId = 1
       )
 
-      assertTrue(successfulFetch.isDefined)
+      assertEquals(Errors.NONE, validFetchResult.assertFired.error)
       assertEquals(0L, followerReplica.stateSnapshot.logStartOffset)
       assertEquals(0L, followerReplica.stateSnapshot.logEndOffset)
 
-
       // Next we receive an invalid request with a higher fetch offset, but an old epoch.
       // We expect that the replica state does not get updated.
       val invalidFetchPartitionData = new FetchRequest.PartitionData(Uuid.ZERO_UUID, 3L, 0L, maxFetchBytes,
         Optional.of(leaderEpoch - 1))
 
-      replicaManager.fetchMessages(
-        timeout = 0L,
-        replicaId = 1,
-        fetchMinBytes = 1,
-        fetchMaxBytes = maxFetchBytes,
-        hardMaxBytesLimit = false,
-        fetchInfos = Seq(tidp -> invalidFetchPartitionData),
-        quota = UnboundedQuota,
-        isolationLevel = IsolationLevel.READ_UNCOMMITTED,
-        responseCallback = callback,
-        clientMetadata = None
+
+      val invalidFetchResult = fetchPartitionAsFollower(
+        replicaManager,
+        tidp,
+        invalidFetchPartitionData,
+        replicaId = 1
       )
 
-      assertTrue(successfulFetch.isDefined)
+      assertEquals(Errors.FENCED_LEADER_EPOCH, invalidFetchResult.assertFired.error)
       assertEquals(0L, followerReplica.stateSnapshot.logStartOffset)
       assertEquals(0L, followerReplica.stateSnapshot.logEndOffset)
 
@@ -806,23 +815,17 @@ class ReplicaManagerTest {
       val divergingFetchPartitionData = new FetchRequest.PartitionData(tidp.topicId, 3L, 0L, maxFetchBytes,
         Optional.of(leaderEpoch), Optional.of(leaderEpoch - 1))
 
-      replicaManager.fetchMessages(
-        timeout = 0L,
-        replicaId = 1,
-        fetchMinBytes = 1,
-        fetchMaxBytes = maxFetchBytes,
-        hardMaxBytesLimit = false,
-        fetchInfos = Seq(tidp -> divergingFetchPartitionData),
-        quota = UnboundedQuota,
-        isolationLevel = IsolationLevel.READ_UNCOMMITTED,
-        responseCallback = callback,
-        clientMetadata = None
+      val divergingEpochResult = fetchPartitionAsFollower(
+        replicaManager,
+        tidp,
+        divergingFetchPartitionData,
+        replicaId = 1
       )
 
-      assertTrue(successfulFetch.isDefined)
+      assertEquals(Errors.NONE, divergingEpochResult.assertFired.error)
+      assertTrue(divergingEpochResult.assertFired.divergingEpoch.isDefined)
       assertEquals(0L, followerReplica.stateSnapshot.logStartOffset)
       assertEquals(0L, followerReplica.stateSnapshot.logEndOffset)
-
     } finally {
       replicaManager.shutdown(checkpointHW = false)
     }
@@ -871,18 +874,14 @@ class ReplicaManagerTest {
       def callback(response: Seq[(TopicIdPartition, FetchPartitionData)]): Unit = {
         successfulFetch = response
       }
-      replicaManager.fetchMessages(
-        timeout = 0L,
+
+      fetchPartitions(
+        replicaManager,
         replicaId = 1,
-        fetchMinBytes = 1,
-        fetchMaxBytes = maxFetchBytes,
-        hardMaxBytesLimit = false,
         fetchInfos = Seq(inconsistentTidp -> validFetchPartitionData),
-        quota = UnboundedQuota,
-        isolationLevel = IsolationLevel.READ_UNCOMMITTED,
-        responseCallback = callback,
-        clientMetadata = None
+        responseCallback = callback
       )
+
       val fetch1 = successfulFetch.headOption.filter(_._1 == inconsistentTidp).map(_._2)
       assertTrue(fetch1.isDefined)
       assertEquals(Errors.INCONSISTENT_TOPIC_ID, fetch1.get.error)
@@ -891,17 +890,11 @@ class ReplicaManagerTest {
       // Fetch messages simulating an ID in the log.
       // We should not see topic ID errors.
       val zeroTidp = new TopicIdPartition(Uuid.ZERO_UUID, tidp.topicPartition)
-      replicaManager.fetchMessages(
-        timeout = 0L,
+      fetchPartitions(
+        replicaManager,
         replicaId = 1,
-        fetchMinBytes = 1,
-        fetchMaxBytes = maxFetchBytes,
-        hardMaxBytesLimit = false,
         fetchInfos = Seq(zeroTidp -> validFetchPartitionData),
-        quota = UnboundedQuota,
-        isolationLevel = IsolationLevel.READ_UNCOMMITTED,
-        responseCallback = callback,
-        clientMetadata = None
+        responseCallback = callback
       )
       val fetch2 = successfulFetch.headOption.filter(_._1 == zeroTidp).map(_._2)
       assertTrue(fetch2.isDefined)
@@ -932,17 +925,11 @@ class ReplicaManagerTest {
       assertEquals(None, replicaManager.getPartitionOrException(tp2).topicId)
 
       // Fetch messages simulating a request that contains a topic ID while the log has none. We should not see an error.
-      replicaManager.fetchMessages(
-        timeout = 0L,
+      fetchPartitions(
+        replicaManager,
         replicaId = 1,
-        fetchMinBytes = 1,
-        fetchMaxBytes = maxFetchBytes,
-        hardMaxBytesLimit = false,
         fetchInfos = Seq(tidp2 -> validFetchPartitionData),
-        quota = UnboundedQuota,
-        isolationLevel = IsolationLevel.READ_UNCOMMITTED,
-        responseCallback = callback,
-        clientMetadata = None
+        responseCallback = callback
       )
       val fetch3 = successfulFetch.headOption.filter(_._1 == tidp2).map(_._2)
       assertTrue(fetch3.isDefined)
@@ -950,17 +937,11 @@ class ReplicaManagerTest {
 
       // Fetch messages simulating a request that does not contain a topic ID. We should not see an error.
       val zeroTidp2 = new TopicIdPartition(Uuid.ZERO_UUID, tidp2.topicPartition)
-      replicaManager.fetchMessages(
-        timeout = 0L,
+      fetchPartitions(
+        replicaManager,
         replicaId = 1,
-        fetchMinBytes = 1,
-        fetchMaxBytes = maxFetchBytes,
-        hardMaxBytesLimit = false,
         fetchInfos = Seq(zeroTidp2 -> validFetchPartitionData),
-        quota = UnboundedQuota,
-        isolationLevel = IsolationLevel.READ_UNCOMMITTED,
-        responseCallback = callback,
-        clientMetadata = None
+        responseCallback = callback
       )
       val fetch4 = successfulFetch.headOption.filter(_._1 == zeroTidp2).map(_._2)
       assertTrue(fetch4.isDefined)
@@ -1051,20 +1032,19 @@ class ReplicaManagerTest {
         assertFalse(tp1Status.get.records.batches.iterator.hasNext)
       }
 
-      replicaManager.fetchMessages(
-        timeout = 1000,
+      fetchPartitions(
+        replicaManager,
         replicaId = 1,
-        fetchMinBytes = 0,
-        fetchMaxBytes = Int.MaxValue,
-        hardMaxBytesLimit = false,
         fetchInfos = Seq(
           tidp0 -> new PartitionData(Uuid.ZERO_UUID, 1, 0, 100000, Optional.empty()),
-          tidp1 -> new PartitionData(Uuid.ZERO_UUID, 1, 0, 100000, Optional.empty())),
-        quota = UnboundedQuota,
+          tidp1 -> new PartitionData(Uuid.ZERO_UUID, 1, 0, 100000, Optional.empty())
+        ),
         responseCallback = fetchCallback,
-        isolationLevel = IsolationLevel.READ_UNCOMMITTED,
-        clientMetadata = None
+        maxWaitMs = 1000,
+        minBytes = 0,
+        maxBytes = Int.MaxValue
       )
+
       val tp0Log = replicaManager.localLog(tp0)
       assertTrue(tp0Log.isDefined)
       assertEquals(1, tp0Log.get.highWatermark, "hw should be incremented")
@@ -1228,12 +1208,12 @@ class ReplicaManagerTest {
       val metadata: ClientMetadata = new DefaultClientMetadata("rack-a", "client-id",
         InetAddress.getByName("localhost"), KafkaPrincipal.ANONYMOUS, "default")
 
-      val consumerResult = fetchAsConsumer(replicaManager, tidp0,
+      val consumerResult = fetchPartitionAsConsumer(replicaManager, tidp0,
         new PartitionData(Uuid.ZERO_UUID, 0, 0, 100000, Optional.empty()),
         clientMetadata = Some(metadata))
 
       // Fetch from follower succeeds
-      assertTrue(consumerResult.isFired)
+      assertTrue(consumerResult.hasFired)
 
       // But only the leader will compute the preferred replica
       assertTrue(consumerResult.assertFired.preferredReadReplica.isEmpty)
@@ -1286,12 +1266,12 @@ class ReplicaManagerTest {
       val metadata = new DefaultClientMetadata("rack-a", "client-id",
         InetAddress.getByName("localhost"), KafkaPrincipal.ANONYMOUS, "default")
 
-      val consumerResult = fetchAsConsumer(replicaManager, tidp0,
+      val consumerResult = fetchPartitionAsConsumer(replicaManager, tidp0,
         new PartitionData(Uuid.ZERO_UUID, 0, 0, 100000, Optional.empty()),
         clientMetadata = Some(metadata))
 
       // Fetch from leader succeeds
-      assertTrue(consumerResult.isFired)
+      assertTrue(consumerResult.hasFired)
 
       // Returns the preferred replica (the leader itself, which is represented as None)
       assertFalse(consumerResult.assertFired.preferredReadReplica.isDefined)
@@ -1334,12 +1314,12 @@ class ReplicaManagerTest {
       val metadata = new DefaultClientMetadata("rack-a", "client-id",
         InetAddress.getLocalHost, KafkaPrincipal.ANONYMOUS, "default")
 
-      val consumerResult = fetchAsConsumer(replicaManager, tidp0,
+      val consumerResult = fetchPartitionAsConsumer(replicaManager, tidp0,
         new PartitionData(Uuid.ZERO_UUID, 0, 0, 100000,
           Optional.empty()), clientMetadata = Some(metadata))
 
       // Fetch from follower succeeds
-      assertTrue(consumerResult.isFired)
+      assertTrue(consumerResult.hasFired)
 
       // Expect the preferred read replica selection not to run
       assertEquals(0, replicaManager.replicaSelectorOpt.get.asInstanceOf[MockReplicaSelector].getSelectionCount)
@@ -1395,12 +1375,12 @@ class ReplicaManagerTest {
         InetAddress.getLocalHost, KafkaPrincipal.ANONYMOUS, "default")
 
       // If a preferred read replica is selected, the fetch response returns immediately, even if min bytes and timeout conditions are not met.
-      val consumerResult = fetchAsConsumer(replicaManager, tidp0,
+      val consumerResult = fetchPartitionAsConsumer(replicaManager, tidp0,
         new PartitionData(Uuid.ZERO_UUID, 0, 0, 100000, Optional.empty()),
-        minBytes = 1, clientMetadata = Some(metadata), timeout = 5000)
+        minBytes = 1, clientMetadata = Some(metadata), maxWaitMs = 5000)
 
       // Fetch from leader succeeds
-      assertTrue(consumerResult.isFired)
+      assertTrue(consumerResult.hasFired)
 
       // No delayed fetch was inserted
       assertEquals(0, replicaManager.delayedFetchPurgatory.watched)
@@ -1454,24 +1434,33 @@ class ReplicaManagerTest {
 
     // Increment the hw in the leader by fetching from the last offset
     val fetchOffset = simpleRecords.size
-    var followerResult = fetchAsFollower(replicaManager, tidp0,
+    var followerResult = fetchPartitionAsFollower(
+      replicaManager,
+      tidp0,
       new PartitionData(Uuid.ZERO_UUID, fetchOffset, 0, 100000, Optional.empty()),
-      clientMetadata = None)
-    assertTrue(followerResult.isFired)
+      replicaId = 1,
+      minBytes = 0
+    )
+    assertTrue(followerResult.hasFired)
     assertEquals(0, followerResult.assertFired.highWatermark)
 
-    assertTrue(appendResult.isFired, "Expected producer request to be acked")
+    assertTrue(appendResult.hasFired, "Expected producer request to be acked")
 
     // Fetch from the same offset; no new data is expected, so the fetch request should
     // go to the purgatory
-    followerResult = fetchAsFollower(replicaManager, tidp0,
+    followerResult = fetchPartitionAsFollower(
+      replicaManager,
+      tidp0,
       new PartitionData(Uuid.ZERO_UUID, fetchOffset, 0, 100000, Optional.empty()),
-      clientMetadata = None, minBytes = 1000)
-    assertFalse(followerResult.isFired, "Request completed immediately unexpectedly")
+      replicaId = 1,
+      minBytes = 1000,
+      maxWaitMs = 1000
+    )
+    assertFalse(followerResult.hasFired, "Request completed immediately unexpectedly")
 
     // Complete the request in the purgatory by advancing the clock
     timer.advanceClock(1001)
-    assertTrue(followerResult.isFired)
+    assertTrue(followerResult.hasFired)
 
     assertEquals(fetchOffset, followerResult.assertFired.highWatermark)
   }
@@ -1545,16 +1534,15 @@ class ReplicaManagerTest {
       val clientMetadata = new DefaultClientMetadata("", "", null, KafkaPrincipal.ANONYMOUS, "")
       var partitionData = new FetchRequest.PartitionData(Uuid.ZERO_UUID, 0L, 0L, 100,
         Optional.of(0))
-      var fetchResult = sendConsumerFetch(replicaManager, tidp0, partitionData, Some(clientMetadata))
-      assertNotNull(fetchResult.get)
-      assertEquals(Errors.NONE, fetchResult.get.error)
+      var fetchResult = fetchPartitionAsConsumer(replicaManager, tidp0, partitionData,
+        clientMetadata = Some(clientMetadata))
+      assertEquals(Errors.NONE, fetchResult.assertFired.error)
 
       // Fetch from follower, with empty ClientMetadata (which implies an older version)
       partitionData = new FetchRequest.PartitionData(Uuid.ZERO_UUID, 0L, 0L, 100,
         Optional.of(0))
-      fetchResult = sendConsumerFetch(replicaManager, tidp0, partitionData, None)
-      assertNotNull(fetchResult.get)
-      assertEquals(Errors.NOT_LEADER_OR_FOLLOWER, fetchResult.get.error)
+      fetchResult = fetchPartitionAsConsumer(replicaManager, tidp0, partitionData)
+      assertEquals(Errors.NOT_LEADER_OR_FOLLOWER, fetchResult.assertFired.error)
     } finally {
       replicaManager.shutdown()
     }
@@ -1596,16 +1584,14 @@ class ReplicaManagerTest {
     val partitionData = new FetchRequest.PartitionData(Uuid.ZERO_UUID, 0L, 0L, 100,
       Optional.empty())
 
-    val nonPurgatoryFetchResult = sendConsumerFetch(replicaManager, tidp0, partitionData, None)
-    assertNotNull(nonPurgatoryFetchResult.get)
-    assertEquals(Errors.NONE, nonPurgatoryFetchResult.get.error)
+    val nonPurgatoryFetchResult = fetchPartitionAsConsumer(replicaManager, tidp0, partitionData)
+    assertEquals(Errors.NONE, nonPurgatoryFetchResult.assertFired.error)
     assertMetricCount(1)
 
-    val purgatoryFetchResult = sendConsumerFetch(replicaManager, tidp0, partitionData, None, timeout = 10)
-    assertNull(purgatoryFetchResult.get)
+    val purgatoryFetchResult = fetchPartitionAsConsumer(replicaManager, tidp0, partitionData, maxWaitMs = 10)
+    assertFalse(purgatoryFetchResult.hasFired)
     mockTimer.advanceClock(11)
-    assertNotNull(purgatoryFetchResult.get)
-    assertEquals(Errors.NONE, purgatoryFetchResult.get.error)
+    assertEquals(Errors.NONE, purgatoryFetchResult.assertFired.error)
     assertMetricCount(2)
   }
 
@@ -1638,8 +1624,8 @@ class ReplicaManagerTest {
 
       val partitionData = new FetchRequest.PartitionData(Uuid.ZERO_UUID, 0L, 0L, 100,
         Optional.empty())
-      val fetchResult = sendConsumerFetch(replicaManager, tidp0, partitionData, None, timeout = 10)
-      assertNull(fetchResult.get)
+      val fetchResult = fetchPartitionAsConsumer(replicaManager, tidp0, partitionData, maxWaitMs = 10)
+      assertFalse(fetchResult.hasFired)
 
       // Become a follower and ensure that the delayed fetch returns immediately
       val becomeFollowerRequest = new LeaderAndIsrRequest.Builder(ApiKeys.LEADER_AND_ISR.latestVersion, 0, 0, brokerEpoch,
@@ -1656,9 +1642,7 @@ class ReplicaManagerTest {
         topicIds.asJava,
         Set(new Node(0, "host1", 0), new Node(1, "host2", 1)).asJava).build()
       replicaManager.becomeLeaderOrFollower(0, becomeFollowerRequest, (_, _) => ())
-
-      assertNotNull(fetchResult.get)
-      assertEquals(Errors.NOT_LEADER_OR_FOLLOWER, fetchResult.get.error)
+      assertEquals(Errors.NOT_LEADER_OR_FOLLOWER, fetchResult.assertFired.error)
     } finally {
       replicaManager.shutdown()
     }
@@ -1696,8 +1680,14 @@ class ReplicaManagerTest {
       val clientMetadata = new DefaultClientMetadata("", "", null, KafkaPrincipal.ANONYMOUS, "")
       val partitionData = new FetchRequest.PartitionData(Uuid.ZERO_UUID, 0L, 0L, 100,
         Optional.of(1))
-      val fetchResult = sendConsumerFetch(replicaManager, tidp0, partitionData, Some(clientMetadata), timeout = 10)
-      assertNull(fetchResult.get)
+      val fetchResult = fetchPartitionAsConsumer(
+        replicaManager,
+        tidp0,
+        partitionData,
+        clientMetadata = Some(clientMetadata),
+        maxWaitMs = 10
+      )
+      assertFalse(fetchResult.hasFired)
 
       // Become a follower and ensure that the delayed fetch returns immediately
       val becomeFollowerRequest = new LeaderAndIsrRequest.Builder(ApiKeys.LEADER_AND_ISR.latestVersion, 0, 0, brokerEpoch,
@@ -1714,9 +1704,7 @@ class ReplicaManagerTest {
         topicIds.asJava,
         Set(new Node(0, "host1", 0), new Node(1, "host2", 1)).asJava).build()
       replicaManager.becomeLeaderOrFollower(0, becomeFollowerRequest, (_, _) => ())
-
-      assertNotNull(fetchResult.get)
-      assertEquals(Errors.FENCED_LEADER_EPOCH, fetchResult.get.error)
+      assertEquals(Errors.FENCED_LEADER_EPOCH, fetchResult.assertFired.error)
     } finally {
       replicaManager.shutdown()
     }
@@ -1752,15 +1740,13 @@ class ReplicaManagerTest {
     val clientMetadata = new DefaultClientMetadata("", "", null, KafkaPrincipal.ANONYMOUS, "")
     var partitionData = new FetchRequest.PartitionData(Uuid.ZERO_UUID, 0L, 0L, 100,
       Optional.of(1))
-    var fetchResult = sendConsumerFetch(replicaManager, tidp0, partitionData, Some(clientMetadata))
-    assertNotNull(fetchResult.get)
-    assertEquals(Errors.NONE, fetchResult.get.error)
+    var fetchResult = fetchPartitionAsConsumer(replicaManager, tidp0, partitionData, clientMetadata = Some(clientMetadata))
+    assertEquals(Errors.NONE, fetchResult.assertFired.error)
 
     partitionData = new FetchRequest.PartitionData(Uuid.ZERO_UUID, 0L, 0L, 100,
       Optional.empty())
-    fetchResult = sendConsumerFetch(replicaManager, tidp0, partitionData, Some(clientMetadata))
-    assertNotNull(fetchResult.get)
-    assertEquals(Errors.NONE, fetchResult.get.error)
+    fetchResult = fetchPartitionAsConsumer(replicaManager, tidp0, partitionData, clientMetadata = Some(clientMetadata))
+    assertEquals(Errors.NONE, fetchResult.assertFired.error)
   }
 
   @Test
@@ -1795,8 +1781,8 @@ class ReplicaManagerTest {
 
     val partitionData = new FetchRequest.PartitionData(Uuid.ZERO_UUID, 0L, 0L, 100,
       Optional.of(1))
-    val fetchResult = sendConsumerFetch(replicaManager, tidp0, partitionData, None, timeout = 10)
-    assertNull(fetchResult.get)
+    val fetchResult = fetchPartitionAsConsumer(replicaManager, tidp0, partitionData, maxWaitMs = 10)
+    assertFalse(fetchResult.hasFired)
     when(replicaManager.metadataCache.contains(tp0)).thenReturn(true)
 
     // We have a fetch in purgatory, now receive a stop replica request and
@@ -1807,8 +1793,7 @@ class ReplicaManagerTest {
         .setDeletePartition(true)
         .setLeaderEpoch(LeaderAndIsr.EpochDuringDelete)))
 
-    assertNotNull(fetchResult.get)
-    assertEquals(Errors.NOT_LEADER_OR_FOLLOWER, fetchResult.get.error)
+    assertEquals(Errors.NOT_LEADER_OR_FOLLOWER, fetchResult.assertFired.error)
   }
 
   @Test
@@ -1880,30 +1865,6 @@ class ReplicaManagerTest {
     produceResult
   }
 
-  private def sendConsumerFetch(replicaManager: ReplicaManager,
-                                topicIdPartition: TopicIdPartition,
-                                partitionData: FetchRequest.PartitionData,
-                                clientMetadataOpt: Option[ClientMetadata],
-                                timeout: Long = 0L): AtomicReference[FetchPartitionData] = {
-    val fetchResult = new AtomicReference[FetchPartitionData]()
-    def callback(response: Seq[(TopicIdPartition, FetchPartitionData)]): Unit = {
-      fetchResult.set(response.toMap.apply(topicIdPartition))
-    }
-    replicaManager.fetchMessages(
-      timeout = timeout,
-      replicaId = Request.OrdinaryConsumerId,
-      fetchMinBytes = 1,
-      fetchMaxBytes = 100,
-      hardMaxBytesLimit = false,
-      fetchInfos = Seq(topicIdPartition -> partitionData),
-      quota = UnboundedQuota,
-      isolationLevel = IsolationLevel.READ_UNCOMMITTED,
-      responseCallback = callback,
-      clientMetadata = clientMetadataOpt
-    )
-    fetchResult
-  }
-
   /**
    * This method assumes that the test using the created ReplicaManager calls
    * ReplicaManager.becomeLeaderOrFollower() once with LeaderAndIsrRequest containing
@@ -2085,11 +2046,11 @@ class ReplicaManagerTest {
     private var fun: Option[T => Unit] = None
 
     def assertFired: T = {
-      assertTrue(isFired, "Callback has not been fired")
+      assertTrue(hasFired, "Callback has not been fired")
       value.get
     }
 
-    def isFired: Boolean = {
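+    // True once the callback has delivered a value to this result.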
+    def hasFired: Boolean = {
       value.isDefined
     }
 
@@ -2100,7 +2061,7 @@ class ReplicaManagerTest {
 
     def onFire(fun: T => Unit): CallbackResult[T] = {
       this.fun = Some(fun)
-      if (this.isFired) fire(value.get)
+      if (this.hasFired) fire(value.get)
       this
     }
   }
@@ -2128,33 +2089,67 @@ class ReplicaManagerTest {
     result
   }
 
-  private def fetchAsConsumer(replicaManager: ReplicaManager,
-                              partition: TopicIdPartition,
-                              partitionData: PartitionData,
-                              minBytes: Int = 0,
-                              isolationLevel: IsolationLevel = IsolationLevel.READ_UNCOMMITTED,
-                              clientMetadata: Option[ClientMetadata] = None,
-                              timeout: Long = 1000): CallbackResult[FetchPartitionData] = {
-    fetchMessages(replicaManager, replicaId = -1, partition, partitionData, minBytes, isolationLevel, clientMetadata, timeout)
-  }
-
-  private def fetchAsFollower(replicaManager: ReplicaManager,
-                              partition: TopicIdPartition,
-                              partitionData: PartitionData,
-                              minBytes: Int = 0,
-                              isolationLevel: IsolationLevel = IsolationLevel.READ_UNCOMMITTED,
-                              clientMetadata: Option[ClientMetadata] = None): CallbackResult[FetchPartitionData] = {
-    fetchMessages(replicaManager, replicaId = 1, partition, partitionData, minBytes, isolationLevel, clientMetadata)
-  }
-
-  private def fetchMessages(replicaManager: ReplicaManager,
-                            replicaId: Int,
-                            partition: TopicIdPartition,
-                            partitionData: PartitionData,
-                            minBytes: Int,
-                            isolationLevel: IsolationLevel,
-                            clientMetadata: Option[ClientMetadata],
-                            timeout: Long = 1000): CallbackResult[FetchPartitionData] = {
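+  // Issues a single-partition fetch on behalf of a consumer, mapping the consumer
+  // isolation level to the corresponding server-side fetch isolation.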
+  private def fetchPartitionAsConsumer(
+    replicaManager: ReplicaManager,
+    partition: TopicIdPartition,
+    partitionData: PartitionData,
+    maxWaitMs: Long = 0,
+    minBytes: Int = 1,
+    maxBytes: Int = 1024 * 1024,
+    isolationLevel: IsolationLevel = IsolationLevel.READ_UNCOMMITTED,
+    clientMetadata: Option[ClientMetadata] = None,
+  ): CallbackResult[FetchPartitionData] = {
+    val isolation = isolationLevel match {
+      case IsolationLevel.READ_COMMITTED => FetchTxnCommitted
+      case IsolationLevel.READ_UNCOMMITTED => FetchHighWatermark
+    }
+
+    fetchPartition(
+      replicaManager,
+      replicaId = Request.OrdinaryConsumerId,
+      partition,
+      partitionData,
+      minBytes,
+      maxBytes,
+      isolation,
+      clientMetadata,
+      maxWaitMs
+    )
+  }
+
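+  // Issues a single-partition fetch on behalf of a follower replica. Follower fetches
+  // read up to the log end offset and carry no client metadata.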
+  private def fetchPartitionAsFollower(
+    replicaManager: ReplicaManager,
+    partition: TopicIdPartition,
+    partitionData: PartitionData,
+    replicaId: Int,
+    maxWaitMs: Long = 0,
+    minBytes: Int = 1,
+    maxBytes: Int = 1024 * 1024,
+  ): CallbackResult[FetchPartitionData] = {
+    fetchPartition(
+      replicaManager,
+      replicaId = replicaId,
+      partition,
+      partitionData,
+      minBytes = minBytes,
+      maxBytes = maxBytes,
+      isolation = FetchLogEnd,
+      clientMetadata = None,
+      maxWaitMs = maxWaitMs
+    )
+  }
+
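+  // Fetches a single partition and captures the response in a CallbackResult so tests
+  // can assert on it immediately or after the delayed fetch completes.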
+  private def fetchPartition(
+    replicaManager: ReplicaManager,
+    replicaId: Int,
+    partition: TopicIdPartition,
+    partitionData: PartitionData,
+    minBytes: Int,
+    maxBytes: Int,
+    isolation: FetchIsolation,
+    clientMetadata: Option[ClientMetadata],
+    maxWaitMs: Long
+  ): CallbackResult[FetchPartitionData] = {
     val result = new CallbackResult[FetchPartitionData]()
     def fetchCallback(responseStatus: Seq[(TopicIdPartition, FetchPartitionData)]): Unit = {
       assertEquals(1, responseStatus.size)
@@ -2163,22 +2158,52 @@ class ReplicaManagerTest {
       result.fire(fetchData)
     }
 
-    replicaManager.fetchMessages(
-      timeout = timeout,
+    fetchPartitions(
+      replicaManager,
       replicaId = replicaId,
-      fetchMinBytes = minBytes,
-      fetchMaxBytes = Int.MaxValue,
-      hardMaxBytesLimit = false,
       fetchInfos = Seq(partition -> partitionData),
-      quota = UnboundedQuota,
       responseCallback = fetchCallback,
-      isolationLevel = isolationLevel,
+      maxWaitMs = maxWaitMs,
+      minBytes = minBytes,
+      maxBytes = maxBytes,
+      isolation = isolation,
       clientMetadata = clientMetadata
     )
 
     result
   }
 
+  private def fetchPartitions(
+    replicaManager: ReplicaManager,
+    replicaId: Int,
+    fetchInfos: Seq[(TopicIdPartition, PartitionData)],
+    responseCallback: Seq[(TopicIdPartition, FetchPartitionData)] => Unit,
+    requestVersion: Short = ApiKeys.FETCH.latestVersion,
+    maxWaitMs: Long = 0,
+    minBytes: Int = 1,
+    maxBytes: Int = 1024 * 1024,
+    quota: ReplicaQuota = UnboundedQuota,
+    isolation: FetchIsolation = FetchLogEnd,
+    clientMetadata: Option[ClientMetadata] = None
+  ): Unit = {
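+    // Bundle the request-level fetch arguments into FetchParams before delegating to fetchMessages.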
+    val params = FetchParams(
+      requestVersion = requestVersion,
+      replicaId = replicaId,
+      maxWaitMs = maxWaitMs,
+      minBytes = minBytes,
+      maxBytes = maxBytes,
+      isolation = isolation,
+      clientMetadata = clientMetadata
+    )
+
+    replicaManager.fetchMessages(
+      params,
+      fetchInfos,
+      quota,
+      responseCallback
+    )
+  }
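+  // Illustrative usage only (assumes a leader for tidp0 is set up as in the tests above):
+  //   val result = fetchPartitionAsConsumer(replicaManager, tidp0,
+  //     new PartitionData(Uuid.ZERO_UUID, 0L, 0L, 100000, Optional.empty()))
+  //   assertEquals(Errors.NONE, result.assertFired.error)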
+
   private def setupReplicaManagerWithMockedPurgatories(
     timer: MockTimer,
     brokerId: Int = 0,
@@ -3142,14 +3167,11 @@ class ReplicaManagerTest {
 
       // Send a produce request and advance the high watermark
       val leaderResponse = sendProducerAppend(replicaManager, topicPartition, numOfRecords)
-      fetchMessages(
+      fetchPartitionAsFollower(
         replicaManager,
-        otherId,
         topicIdPartition,
         new PartitionData(Uuid.ZERO_UUID, numOfRecords, 0, Int.MaxValue, Optional.empty()),
-        Int.MaxValue,
-        IsolationLevel.READ_UNCOMMITTED,
-        None
+        replicaId = otherId
       )
       assertEquals(Errors.NONE, leaderResponse.get.error)
 
@@ -3211,14 +3233,11 @@ class ReplicaManagerTest {
 
       // Send a produce request and advance the high watermark
       val leaderResponse = sendProducerAppend(replicaManager, topicPartition, numOfRecords)
-      fetchMessages(
+      fetchPartitionAsFollower(
         replicaManager,
-        otherId,
         topicIdPartition,
         new PartitionData(Uuid.ZERO_UUID, numOfRecords, 0, Int.MaxValue, Optional.empty()),
-        Int.MaxValue,
-        IsolationLevel.READ_UNCOMMITTED,
-        None
+        replicaId = otherId
       )
       assertEquals(Errors.NONE, leaderResponse.get.error)
 
@@ -3484,15 +3503,15 @@ class ReplicaManagerTest {
       assertEquals(None, replicaManager.replicaFetcherManager.getFetcher(topicPartition))
 
       // Send a fetch request
-      val fetchCallback = fetchMessages(
+      val fetchCallback = fetchPartitionAsFollower(
         replicaManager,
-        otherId,
         topicIdPartition,
         new PartitionData(Uuid.ZERO_UUID, 0, 0, Int.MaxValue, Optional.empty()),
-        Int.MaxValue,
-        IsolationLevel.READ_UNCOMMITTED,
-        None
+        replicaId = otherId,
+        minBytes = Int.MaxValue,
+        maxWaitMs = 1000
       )
+      assertFalse(fetchCallback.hasFired)
 
       // Change the local replica to follower
       val followerTopicsDelta = topicsChangeDelta(leaderMetadataImage.topics(), localId, false)