You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@kafka.apache.org by jg...@apache.org on 2019/06/05 21:36:56 UTC

[kafka] branch trunk updated: KAFKA-8400; Do not update follower replica state if the log read failed (#6814)

This is an automated email from the ASF dual-hosted git repository.

jgus pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 0e95c9f  KAFKA-8400; Do not update follower replica state if the log read failed (#6814)
0e95c9f is described below

commit 0e95c9f3a829110f0cd8c3695f40ba47f146fef7
Author: Jason Gustafson <ja...@confluent.io>
AuthorDate: Wed Jun 5 14:36:43 2019 -0700

    KAFKA-8400; Do not update follower replica state if the log read failed (#6814)
    
    This patch checks for errors when handling a fetch request before updating follower state. Previously, we unsafely passed the failed `LogReadResult`, with most fields set to -1, into `Replica` to update follower state. Additionally, this patch attempts to improve the test coverage for ISR shrinking and expansion logic in `Partition`.
    
    Reviewers: Guozhang Wang <wa...@gmail.com>
---
 core/src/main/scala/kafka/cluster/Partition.scala  | 135 ++++---
 core/src/main/scala/kafka/cluster/Replica.scala    |  74 ++--
 core/src/main/scala/kafka/log/Log.scala            |   2 +-
 .../scala/kafka/server/LogOffsetMetadata.scala     |   6 +-
 .../kafka/server/ReplicaAlterLogDirsThread.scala   |   2 +-
 .../scala/kafka/server/ReplicaFetcherThread.scala  |   6 +-
 .../main/scala/kafka/server/ReplicaManager.scala   |  98 +++--
 .../integration/kafka/api/ConsumerBounceTest.scala |   2 +-
 .../admin/ReassignPartitionsClusterTest.scala      |   4 +-
 .../scala/unit/kafka/cluster/PartitionTest.scala   | 443 ++++++++++++++++++---
 .../scala/unit/kafka/cluster/ReplicaTest.scala     |  10 +-
 .../server/HighwatermarkPersistenceTest.scala      |  22 +-
 .../unit/kafka/server/ISRExpirationTest.scala      |  81 ++--
 .../scala/unit/kafka/server/LogRecoveryTest.scala  |  10 +-
 .../kafka/server/ReplicaFetcherThreadTest.scala    |  20 +-
 .../kafka/server/ReplicaManagerQuotasTest.scala    |   6 +-
 .../unit/kafka/server/ReplicaManagerTest.scala     |  92 ++++-
 .../scala/unit/kafka/server/SimpleFetchTest.scala  |  15 +-
 18 files changed, 726 insertions(+), 302 deletions(-)

diff --git a/core/src/main/scala/kafka/cluster/Partition.scala b/core/src/main/scala/kafka/cluster/Partition.scala
index a6cce32..22d9dff 100755
--- a/core/src/main/scala/kafka/cluster/Partition.scala
+++ b/core/src/main/scala/kafka/cluster/Partition.scala
@@ -243,7 +243,7 @@ class Partition(val topicPartition: TopicPartition,
     new Gauge[Long] {
       def value = {
         leaderReplicaIfLocal.map { replica =>
-          replica.highWatermark.messageOffset - replica.lastStableOffset.messageOffset
+          replica.highWatermark - replica.lastStableOffset
         }.getOrElse(0)
       }
     },
@@ -527,11 +527,18 @@ class Partition(val topicPartition: TopicPartition,
 
       if (isNewLeader) {
         // construct the high watermark metadata for the new leader replica
-        leaderReplica.convertHWToLocalOffsetMetadata()
+        leaderReplica.maybeFetchHighWatermarkOffsetMetadata()
         // mark local replica as the leader after converting hw
         leaderReplicaIdOpt = Some(localBrokerId)
         // reset log end offset for remote replicas
-        assignedReplicas.filter(_.brokerId != localBrokerId).foreach(_.updateLogReadResult(LogReadResult.UnknownLogReadResult))
+        assignedReplicas.filter(_.brokerId != localBrokerId).foreach { replica =>
+          replica.updateFetchState(
+            followerFetchOffsetMetadata = LogOffsetMetadata.UnknownOffsetMetadata,
+            followerStartOffset = Log.UnknownOffset,
+            followerFetchTimeMs = 0L,
+            leaderEndOffset = Log.UnknownOffset
+          )
+        }
       }
       // we may need to increment high watermark since ISR could be down to 1
       (maybeIncrementLeaderHW(leaderReplica), isNewLeader)
@@ -581,28 +588,43 @@ class Partition(val topicPartition: TopicPartition,
    * Update the follower's state in the leader based on the last fetch request. See
    * [[kafka.cluster.Replica#updateLogReadResult]] for details.
    *
-   * @return true if the leader's log start offset or high watermark have been updated
+   * @return true if the follower's fetch state was updated, false if the followerId is not recognized
    */
-  def updateReplicaLogReadResult(replica: Replica, logReadResult: LogReadResult): Boolean = {
-    val replicaId = replica.brokerId
-    // No need to calculate low watermark if there is no delayed DeleteRecordsRequest
-    val oldLeaderLW = if (delayedOperations.numDelayedDelete > 0) lowWatermarkIfLeader else -1L
-    replica.updateLogReadResult(logReadResult)
-    val newLeaderLW = if (delayedOperations.numDelayedDelete > 0) lowWatermarkIfLeader else -1L
-    // check if the LW of the partition has incremented
-    // since the replica's logStartOffset may have incremented
-    val leaderLWIncremented = newLeaderLW > oldLeaderLW
-    // check if we need to expand ISR to include this replica
-    // if it is not in the ISR yet
-    val leaderHWIncremented = maybeExpandIsr(replicaId, logReadResult)
-
-    val result = leaderLWIncremented || leaderHWIncremented
-    // some delayed operations may be unblocked after HW or LW changed
-    if (result)
-      tryCompleteDelayedRequests()
+  def updateFollowerFetchState(followerId: Int,
+                               followerFetchOffsetMetadata: LogOffsetMetadata,
+                               followerStartOffset: Long,
+                               followerFetchTimeMs: Long,
+                               leaderEndOffset: Long): Boolean = {
+
+    getReplica(followerId) match {
+      case Some(followerReplica) =>
+        // No need to calculate low watermark if there is no delayed DeleteRecordsRequest
+        val oldLeaderLW = if (delayedOperations.numDelayedDelete > 0) lowWatermarkIfLeader else -1L
+        followerReplica.updateFetchState(
+          followerFetchOffsetMetadata,
+          followerStartOffset,
+          followerFetchTimeMs,
+          leaderEndOffset)
+        val newLeaderLW = if (delayedOperations.numDelayedDelete > 0) lowWatermarkIfLeader else -1L
+        // check if the LW of the partition has incremented
+        // since the replica's logStartOffset may have incremented
+        val leaderLWIncremented = newLeaderLW > oldLeaderLW
+        // check if we need to expand ISR to include this replica
+        // if it is not in the ISR yet
+        val followerFetchOffset = followerFetchOffsetMetadata.messageOffset
+        val leaderHWIncremented = maybeExpandIsr(followerReplica, followerFetchTimeMs)
+
+        // some delayed operations may be unblocked after HW or LW changed
+        if (leaderLWIncremented || leaderHWIncremented)
+          tryCompleteDelayedRequests()
+
+        debug(s"Recorded replica $followerId log end offset (LEO) position " +
+          s"$followerFetchOffset and log start offset $followerStartOffset.")
+        true
 
-    debug(s"Recorded replica $replicaId log end offset (LEO) position ${logReadResult.info.fetchOffsetMetadata.messageOffset}.")
-    result
+      case None =>
+        false
+    }
   }
 
   /**
@@ -621,19 +643,15 @@ class Partition(val topicPartition: TopicPartition,
    *
    * @return true if the high watermark has been updated
    */
-  def maybeExpandIsr(replicaId: Int, logReadResult: LogReadResult): Boolean = {
+  private def maybeExpandIsr(followerReplica: Replica, followerFetchTimeMs: Long): Boolean = {
     inWriteLock(leaderIsrUpdateLock) {
       // check if this replica needs to be added to the ISR
       leaderReplicaIfLocal match {
         case Some(leaderReplica) =>
-          val replica = getReplica(replicaId).get
-          val leaderHW = leaderReplica.highWatermark
-          val fetchOffset = logReadResult.info.fetchOffsetMetadata.messageOffset
-          if (!inSyncReplicas.contains(replica) &&
-             assignedReplicas.map(_.brokerId).contains(replicaId) &&
-             replica.logEndOffsetMetadata.offsetDiff(leaderHW) >= 0 &&
-             leaderEpochStartOffsetOpt.exists(fetchOffset >= _)) {
-            val newInSyncReplicas = inSyncReplicas + replica
+          val leaderHighwatermark = leaderReplica.highWatermark
+          if (!inSyncReplicas.contains(followerReplica) && isFollowerInSync(followerReplica, leaderHighwatermark)) {
+            val newInSyncReplicas = inSyncReplicas + followerReplica
+
             info(s"Expanding ISR from ${inSyncReplicas.map(_.brokerId).mkString(",")} " +
               s"to ${newInSyncReplicas.map(_.brokerId).mkString(",")}")
 
@@ -642,12 +660,17 @@ class Partition(val topicPartition: TopicPartition,
           }
           // check if the HW of the partition can now be incremented
           // since the replica may already be in the ISR and its LEO has just incremented
-          maybeIncrementLeaderHW(leaderReplica, logReadResult.fetchTimeMs)
+          maybeIncrementLeaderHW(leaderReplica, followerFetchTimeMs)
         case None => false // nothing to do if no longer leader
       }
     }
   }
 
+  private def isFollowerInSync(followerReplica: Replica, highWatermark: Long): Boolean = {
+    val followerEndOffset = followerReplica.logEndOffset
+    followerEndOffset >= highWatermark && leaderEpochStartOffsetOpt.exists(followerEndOffset >= _)
+  }
+
   /*
    * Returns a tuple where the first element is a boolean indicating whether enough replicas reached `requiredOffset`
    * and the second element is an error (which would be `Errors.NONE` for no error).
@@ -672,7 +695,7 @@ class Partition(val topicPartition: TopicPartition,
         }
 
         val minIsr = leaderReplica.log.get.config.minInSyncReplicas
-        if (leaderReplica.highWatermark.messageOffset >= requiredOffset) {
+        if (leaderReplica.highWatermark >= requiredOffset) {
           /*
            * The topic may be configured not to accept messages if there are not enough replicas in ISR
            * in this scenario the request was already appended locally and then added to the purgatory before the ISR was shrunk
@@ -711,18 +734,18 @@ class Partition(val topicPartition: TopicPartition,
       curTime - replica.lastCaughtUpTimeMs <= replicaLagTimeMaxMs || inSyncReplicas.contains(replica)
     }.map(_.logEndOffsetMetadata)
     val newHighWatermark = allLogEndOffsets.min(new LogOffsetMetadata.OffsetOrdering)
-    val oldHighWatermark = leaderReplica.highWatermark
+    val oldHighWatermark = leaderReplica.highWatermarkMetadata
 
     // Ensure that the high watermark increases monotonically. We also update the high watermark when the new
     // offset metadata is on a newer segment, which occurs whenever the log is rolled to a new segment.
     if (oldHighWatermark.messageOffset < newHighWatermark.messageOffset ||
       (oldHighWatermark.messageOffset == newHighWatermark.messageOffset && oldHighWatermark.onOlderSegment(newHighWatermark))) {
-      leaderReplica.highWatermark = newHighWatermark
+      leaderReplica.highWatermarkMetadata = newHighWatermark
       debug(s"High watermark updated to $newHighWatermark")
       true
     } else {
       def logEndOffsetString(r: Replica) = s"replica ${r.brokerId}: ${r.logEndOffsetMetadata}"
-      debug(s"Skipping update high watermark since new hw $newHighWatermark is not larger than old hw $oldHighWatermark. " +
+      trace(s"Skipping update high watermark since new hw $newHighWatermark is not larger than old hw $oldHighWatermark. " +
         s"All current LEOs are ${assignedReplicas.map(logEndOffsetString)}")
       false
     }
@@ -747,7 +770,7 @@ class Partition(val topicPartition: TopicPartition,
    */
   private def tryCompleteDelayedRequests(): Unit = delayedOperations.checkAndCompleteAll()
 
-  def maybeShrinkIsr(replicaMaxLagTimeMs: Long) {
+  def maybeShrinkIsr(replicaMaxLagTimeMs: Long): Unit = {
     val leaderHWIncremented = inWriteLock(leaderIsrUpdateLock) {
       leaderReplicaIfLocal match {
         case Some(leaderReplica) =>
@@ -758,7 +781,7 @@ class Partition(val topicPartition: TopicPartition,
             info("Shrinking ISR from %s to %s. Leader: (highWatermark: %d, endOffset: %d). Out of sync replicas: %s."
               .format(inSyncReplicas.map(_.brokerId).mkString(","),
                 newInSyncReplicas.map(_.brokerId).mkString(","),
-                leaderReplica.highWatermark.messageOffset,
+                leaderReplica.highWatermarkMetadata.messageOffset,
                 leaderReplica.logEndOffset,
                 outOfSyncReplicas.map { replica =>
                   s"(brokerId: ${replica.brokerId}, endOffset: ${replica.logEndOffset})"
@@ -784,6 +807,14 @@ class Partition(val topicPartition: TopicPartition,
       tryCompleteDelayedRequests()
   }
 
+  private def isFollowerOutOfSync(followerReplica: Replica,
+                                  leaderEndOffset: Long,
+                                  currentTimeMs: Long,
+                                  maxLagMs: Long): Boolean = {
+    followerReplica.logEndOffset != leaderEndOffset &&
+      (currentTimeMs - followerReplica.lastCaughtUpTimeMs) > maxLagMs
+  }
+
   def getOutOfSyncReplicas(leaderReplica: Replica, maxLagMs: Long): Set[Replica] = {
     /**
      * If the follower already has the same leo as the leader, it will not be considered as out-of-sync,
@@ -798,13 +829,9 @@ class Partition(val topicPartition: TopicPartition,
      *
      **/
     val candidateReplicas = inSyncReplicas - leaderReplica
-
-    val laggingReplicas = candidateReplicas.filter(r =>
-      r.logEndOffset != leaderReplica.logEndOffset && (time.milliseconds - r.lastCaughtUpTimeMs) > maxLagMs)
-    if (laggingReplicas.nonEmpty)
-      debug("Lagging replicas are %s".format(laggingReplicas.map(_.brokerId).mkString(",")))
-
-    laggingReplicas
+    val currentTimeMs = time.milliseconds()
+    val leaderEndOffset = leaderReplica.logEndOffset
+    candidateReplicas.filter(r => isFollowerOutOfSync(r, leaderEndOffset, currentTimeMs, maxLagMs))
   }
 
   private def doAppendRecordsToFollowerOrFutureReplica(records: MemoryRecords, isFuture: Boolean): Option[LogAppendInfo] = {
@@ -903,10 +930,10 @@ class Partition(val topicPartition: TopicPartition,
      * where data gets appended to the log immediately after the replica has consumed from it
      * This can cause a replica to always be out of sync.
      */
-    val initialHighWatermark = localReplica.highWatermark.messageOffset
+    val initialHighWatermark = localReplica.highWatermark
     val initialLogStartOffset = localReplica.logStartOffset
     val initialLogEndOffset = localReplica.logEndOffset
-    val initialLastStableOffset = localReplica.lastStableOffset.messageOffset
+    val initialLastStableOffset = localReplica.lastStableOffset
 
     val maxOffsetOpt = fetchIsolation match {
       case FetchLogEnd => None
@@ -940,8 +967,8 @@ class Partition(val topicPartition: TopicPartition,
     val localReplica = localReplicaWithEpochOrException(currentLeaderEpoch, fetchOnlyFromLeader)
 
     val lastFetchableOffset = isolationLevel match {
-      case Some(IsolationLevel.READ_COMMITTED) => localReplica.lastStableOffset.messageOffset
-      case Some(IsolationLevel.READ_UNCOMMITTED) => localReplica.highWatermark.messageOffset
+      case Some(IsolationLevel.READ_COMMITTED) => localReplica.lastStableOffset
+      case Some(IsolationLevel.READ_UNCOMMITTED) => localReplica.highWatermark
       case None => localReplica.logEndOffset
     }
 
@@ -954,10 +981,10 @@ class Partition(val topicPartition: TopicPartition,
     // Only consider throwing an error if we get a client request (isolationLevel is defined) and the start offset
     // is lagging behind the high watermark
     val maybeOffsetsError: Option[ApiException] = leaderEpochStartOffsetOpt
-      .filter(epochStart => isolationLevel.isDefined && epochStart > localReplica.highWatermark.messageOffset)
+      .filter(epochStart => isolationLevel.isDefined && epochStart > localReplica.highWatermark)
       .map(epochStart => Errors.OFFSET_NOT_AVAILABLE.exception(s"Failed to fetch offsets for " +
         s"partition $topicPartition with leader $epochLogString as this partition's " +
-        s"high watermark (${localReplica.highWatermark.messageOffset}) is lagging behind the " +
+        s"high watermark (${localReplica.highWatermark}) is lagging behind the " +
         s"start offset from the beginning of this epoch ($epochStart)."))
 
     def getOffsetByTimestamp: Option[TimestampAndOffset] = {
@@ -1011,7 +1038,7 @@ class Partition(val topicPartition: TopicPartition,
     if (!isFromConsumer) {
       allOffsets
     } else {
-      val hw = localReplica.highWatermark.messageOffset
+      val hw = localReplica.highWatermark
       if (allOffsets.exists(_ > hw))
         hw +: allOffsets.dropWhile(_ > hw)
       else
@@ -1038,7 +1065,7 @@ class Partition(val topicPartition: TopicPartition,
           throw new PolicyViolationException(s"Records of partition $topicPartition can not be deleted due to the configured policy")
 
         val convertedOffset = if (offset == DeleteRecordsRequest.HIGH_WATERMARK)
-          leaderReplica.highWatermark.messageOffset
+          leaderReplica.highWatermark
         else
           offset
 
diff --git a/core/src/main/scala/kafka/cluster/Replica.scala b/core/src/main/scala/kafka/cluster/Replica.scala
index 32055d7..831233e 100644
--- a/core/src/main/scala/kafka/cluster/Replica.scala
+++ b/core/src/main/scala/kafka/cluster/Replica.scala
@@ -19,7 +19,7 @@ package kafka.cluster
 
 import kafka.log.{Log, LogOffsetSnapshot}
 import kafka.utils.Logging
-import kafka.server.{LogOffsetMetadata, LogReadResult, OffsetAndEpoch}
+import kafka.server.{LogOffsetMetadata, OffsetAndEpoch}
 import org.apache.kafka.common.{KafkaException, TopicPartition}
 import org.apache.kafka.common.errors.OffsetOutOfRangeException
 import org.apache.kafka.common.utils.Time
@@ -30,13 +30,13 @@ class Replica(val brokerId: Int,
               initialHighWatermarkValue: Long = 0L,
               @volatile var log: Option[Log] = None) extends Logging {
   // the high watermark offset value, in non-leader replicas only its message offsets are kept
-  @volatile private[this] var highWatermarkMetadata = new LogOffsetMetadata(initialHighWatermarkValue)
+  @volatile private[this] var _highWatermarkMetadata = new LogOffsetMetadata(initialHighWatermarkValue)
   // the log end offset value, kept in all replicas;
   // for local replica it is the log's end offset, for remote replicas its value is only updated by follower fetch
   @volatile private[this] var _logEndOffsetMetadata = LogOffsetMetadata.UnknownOffsetMetadata
   // the log start offset value, kept in all replicas;
   // for local replica it is the log's start offset, for remote replicas its value is only updated by follower fetch
-  @volatile private[this] var _logStartOffset = Log.UnknownLogStartOffset
+  @volatile private[this] var _logStartOffset = Log.UnknownOffset
 
   // The log end offset value at the time the leader received the last FetchRequest from this follower
   // This is used to determine the lastCaughtUpTimeMs of the follower
@@ -69,16 +69,19 @@ class Replica(val brokerId: Int,
    * fetch request is always smaller than the leader's LEO, which can happen if small produce requests are received at
    * high frequency.
    */
-  def updateLogReadResult(logReadResult: LogReadResult) {
-    if (logReadResult.info.fetchOffsetMetadata.messageOffset >= logReadResult.leaderLogEndOffset)
-      _lastCaughtUpTimeMs = math.max(_lastCaughtUpTimeMs, logReadResult.fetchTimeMs)
-    else if (logReadResult.info.fetchOffsetMetadata.messageOffset >= lastFetchLeaderLogEndOffset)
+  def updateFetchState(followerFetchOffsetMetadata: LogOffsetMetadata,
+                       followerStartOffset: Long,
+                       followerFetchTimeMs: Long,
+                       leaderEndOffset: Long): Unit = {
+    if (followerFetchOffsetMetadata.messageOffset >= leaderEndOffset)
+      _lastCaughtUpTimeMs = math.max(_lastCaughtUpTimeMs, followerFetchTimeMs)
+    else if (followerFetchOffsetMetadata.messageOffset >= lastFetchLeaderLogEndOffset)
       _lastCaughtUpTimeMs = math.max(_lastCaughtUpTimeMs, lastFetchTimeMs)
 
-    logStartOffset = logReadResult.followerLogStartOffset
-    logEndOffsetMetadata = logReadResult.info.fetchOffsetMetadata
-    lastFetchLeaderLogEndOffset = logReadResult.leaderLogEndOffset
-    lastFetchTimeMs = logReadResult.fetchTimeMs
+    logStartOffset = followerStartOffset
+    logEndOffsetMetadata = followerFetchOffsetMetadata
+    lastFetchLeaderLogEndOffset = leaderEndOffset
+    lastFetchTimeMs = followerFetchTimeMs
   }
 
   def resetLastCaughtUpTime(curLeaderLogEndOffset: Long, curTimeMs: Long, lastCaughtUpTimeMs: Long) {
@@ -127,9 +130,9 @@ class Replica(val brokerId: Int,
    */
   def maybeIncrementLogStartOffset(newLogStartOffset: Long) {
     if (isLocal) {
-      if (newLogStartOffset > highWatermark.messageOffset)
+      if (newLogStartOffset > highWatermark)
         throw new OffsetOutOfRangeException(s"Cannot increment the log start offset to $newLogStartOffset of partition $topicPartition " +
-          s"since it is larger than the high watermark ${highWatermark.messageOffset}")
+          s"since it is larger than the high watermark $highWatermark")
       log.get.maybeIncrementLogStartOffset(newLogStartOffset)
     } else {
       throw new KafkaException(s"Should not try to delete records on partition $topicPartition's non-local replica $brokerId")
@@ -152,20 +155,26 @@ class Replica(val brokerId: Int,
     else
       _logStartOffset
 
-  def highWatermark_=(newHighWatermark: LogOffsetMetadata) {
+  def highWatermarkMetadata_=(newHighWatermarkMetadata: LogOffsetMetadata) {
     if (isLocal) {
-      if (newHighWatermark.messageOffset < 0)
+      if (newHighWatermarkMetadata.messageOffset < 0)
         throw new IllegalArgumentException("High watermark offset should be non-negative")
 
-      highWatermarkMetadata = newHighWatermark
-      log.foreach(_.onHighWatermarkIncremented(newHighWatermark.messageOffset))
-      trace(s"Setting high watermark for replica $brokerId partition $topicPartition to [$newHighWatermark]")
+      _highWatermarkMetadata = newHighWatermarkMetadata
+      log.foreach(_.onHighWatermarkIncremented(newHighWatermarkMetadata.messageOffset))
+      trace(s"Setting high watermark for replica $brokerId partition $topicPartition to [$newHighWatermarkMetadata]")
     } else {
       throw new KafkaException(s"Should not set high watermark on partition $topicPartition's non-local replica $brokerId")
     }
   }
 
-  def highWatermark: LogOffsetMetadata = highWatermarkMetadata
+  def highWatermark_=(newHighWatermark: Long): Unit = {
+    highWatermarkMetadata = LogOffsetMetadata(newHighWatermark)
+  }
+
+  def highWatermarkMetadata: LogOffsetMetadata = _highWatermarkMetadata
+
+  def highWatermark: Long = _highWatermarkMetadata.messageOffset
 
   /**
    * The last stable offset (LSO) is defined as the first offset such that all lower offsets have been "decided."
@@ -174,30 +183,33 @@ class Replica(val brokerId: Int,
    * to the high watermark if there are no transactional messages in the log. Note also that the LSO cannot advance
    * beyond the high watermark.
    */
-  def lastStableOffset: LogOffsetMetadata = {
+  def lastStableOffsetMetadata: LogOffsetMetadata = {
     log.map { log =>
       log.firstUnstableOffset match {
-        case Some(offsetMetadata) if offsetMetadata.messageOffset < highWatermark.messageOffset => offsetMetadata
-        case _ => highWatermark
+        case Some(offsetMetadata) if offsetMetadata.messageOffset < highWatermark => offsetMetadata
+        case _ => highWatermarkMetadata
       }
     }.getOrElse(throw new KafkaException(s"Cannot fetch last stable offset on partition $topicPartition's " +
       s"non-local replica $brokerId"))
   }
 
+  def lastStableOffset: Long = lastStableOffsetMetadata.messageOffset
+
   /*
    * Convert hw to local offset metadata by reading the log at the hw offset.
    * If the hw offset is out of range, return the first offset of the first log segment as the offset metadata.
    */
-  def convertHWToLocalOffsetMetadata() {
-    if (isLocal) {
-      highWatermarkMetadata = log.get.convertToOffsetMetadata(highWatermarkMetadata.messageOffset).getOrElse {
+  def maybeFetchHighWatermarkOffsetMetadata(): Unit = {
+    if (!isLocal)
+      throw new KafkaException(s"Should not construct complete high watermark on partition $topicPartition's non-local replica $brokerId")
+
+    if (highWatermarkMetadata.messageOffsetOnly) {
+      highWatermarkMetadata = log.get.convertToOffsetMetadata(highWatermark).getOrElse {
         log.get.convertToOffsetMetadata(logStartOffset).getOrElse {
           val firstSegmentOffset = log.get.logSegments.head.baseOffset
           new LogOffsetMetadata(firstSegmentOffset, firstSegmentOffset, 0)
         }
       }
-    } else {
-      throw new KafkaException(s"Should not construct complete high watermark on partition $topicPartition's non-local replica $brokerId")
     }
   }
 
@@ -205,8 +217,8 @@ class Replica(val brokerId: Int,
     LogOffsetSnapshot(
       logStartOffset = logStartOffset,
       logEndOffset = logEndOffsetMetadata,
-      highWatermark =  highWatermark,
-      lastStableOffset = lastStableOffset)
+      highWatermark =  highWatermarkMetadata,
+      lastStableOffset = lastStableOffsetMetadata)
   }
 
   override def equals(that: Any): Boolean = that match {
@@ -224,8 +236,8 @@ class Replica(val brokerId: Int,
     replicaString.append(s", isLocal=$isLocal")
     replicaString.append(s", lastCaughtUpTimeMs=$lastCaughtUpTimeMs")
     if (isLocal) {
-      replicaString.append(s", highWatermark=$highWatermark")
-      replicaString.append(s", lastStableOffset=$lastStableOffset")
+      replicaString.append(s", highWatermark=$highWatermarkMetadata")
+      replicaString.append(s", lastStableOffset=$lastStableOffsetMetadata")
     }
     replicaString.append(")")
     replicaString.toString
diff --git a/core/src/main/scala/kafka/log/Log.scala b/core/src/main/scala/kafka/log/Log.scala
index 9ab6fda..38c3617 100644
--- a/core/src/main/scala/kafka/log/Log.scala
+++ b/core/src/main/scala/kafka/log/Log.scala
@@ -2183,7 +2183,7 @@ object Log {
   private[log] val DeleteDirPattern = Pattern.compile(s"^(\\S+)-(\\S+)\\.(\\S+)$DeleteDirSuffix")
   private[log] val FutureDirPattern = Pattern.compile(s"^(\\S+)-(\\S+)\\.(\\S+)$FutureDirSuffix")
 
-  val UnknownLogStartOffset = -1L
+  val UnknownOffset = -1L
 
   def apply(dir: File,
             config: LogConfig,
diff --git a/core/src/main/scala/kafka/server/LogOffsetMetadata.scala b/core/src/main/scala/kafka/server/LogOffsetMetadata.scala
index effbaa0..67afac6 100644
--- a/core/src/main/scala/kafka/server/LogOffsetMetadata.scala
+++ b/core/src/main/scala/kafka/server/LogOffsetMetadata.scala
@@ -17,11 +17,11 @@
 
 package kafka.server
 
+import kafka.log.Log
 import org.apache.kafka.common.KafkaException
 
 object LogOffsetMetadata {
   val UnknownOffsetMetadata = new LogOffsetMetadata(-1, 0, 0)
-  val UnknownSegBaseOffset = -1L
   val UnknownFilePosition = -1
 
   class OffsetOrdering extends Ordering[LogOffsetMetadata] {
@@ -39,7 +39,7 @@ object LogOffsetMetadata {
  *  3. the physical position on the located segment
  */
 case class LogOffsetMetadata(messageOffset: Long,
-                             segmentBaseOffset: Long = LogOffsetMetadata.UnknownSegBaseOffset,
+                             segmentBaseOffset: Long = Log.UnknownOffset,
                              relativePositionInSegment: Int = LogOffsetMetadata.UnknownFilePosition) {
 
   // check if this offset is already on an older segment compared with the given offset
@@ -76,7 +76,7 @@ case class LogOffsetMetadata(messageOffset: Long,
 
   // decide if the offset metadata only contains message offset info
   def messageOffsetOnly: Boolean = {
-    segmentBaseOffset == LogOffsetMetadata.UnknownSegBaseOffset && relativePositionInSegment == LogOffsetMetadata.UnknownFilePosition
+    segmentBaseOffset == Log.UnknownOffset && relativePositionInSegment == LogOffsetMetadata.UnknownFilePosition
   }
 
   override def toString = s"(offset=$messageOffset segment=[$segmentBaseOffset:$relativePositionInSegment])"
diff --git a/core/src/main/scala/kafka/server/ReplicaAlterLogDirsThread.scala b/core/src/main/scala/kafka/server/ReplicaAlterLogDirsThread.scala
index 0622b30..8b45501 100644
--- a/core/src/main/scala/kafka/server/ReplicaAlterLogDirsThread.scala
+++ b/core/src/main/scala/kafka/server/ReplicaAlterLogDirsThread.scala
@@ -111,7 +111,7 @@ class ReplicaAlterLogDirsThread(name: String,
 
     val logAppendInfo = partition.appendRecordsToFollowerOrFutureReplica(records, isFuture = true)
     val futureReplicaHighWatermark = futureReplica.logEndOffset.min(partitionData.highWatermark)
-    futureReplica.highWatermark = new LogOffsetMetadata(futureReplicaHighWatermark)
+    futureReplica.highWatermark = futureReplicaHighWatermark
     futureReplica.maybeIncrementLogStartOffset(partitionData.logStartOffset)
 
     if (partition.maybeReplaceCurrentWithFutureReplica())
diff --git a/core/src/main/scala/kafka/server/ReplicaFetcherThread.scala b/core/src/main/scala/kafka/server/ReplicaFetcherThread.scala
index b1b5dd0..947e16a 100644
--- a/core/src/main/scala/kafka/server/ReplicaFetcherThread.scala
+++ b/core/src/main/scala/kafka/server/ReplicaFetcherThread.scala
@@ -168,7 +168,7 @@ class ReplicaFetcherThread(name: String,
     // for the follower replica, we do not need to keep
     // its segment base offset the physical position,
     // these values will be computed upon making the leader
-    replica.highWatermark = new LogOffsetMetadata(followerHighWatermark)
+    replica.highWatermark = followerHighWatermark
     replica.maybeIncrementLogStartOffset(leaderLogStartOffset)
     if (isTraceEnabled)
       trace(s"Follower set replica high watermark for partition $topicPartition to $followerHighWatermark")
@@ -282,9 +282,9 @@ class ReplicaFetcherThread(name: String,
 
     partition.truncateTo(offsetTruncationState.offset, isFuture = false)
 
-    if (offsetTruncationState.offset < replica.highWatermark.messageOffset)
+    if (offsetTruncationState.offset < replica.highWatermark)
       warn(s"Truncating $tp to offset ${offsetTruncationState.offset} below high watermark " +
-        s"${replica.highWatermark.messageOffset}")
+        s"${replica.highWatermark}")
 
     // mark the future replica for truncation only when we do last truncation
     if (offsetTruncationState.truncationCompleted)
diff --git a/core/src/main/scala/kafka/server/ReplicaManager.scala b/core/src/main/scala/kafka/server/ReplicaManager.scala
index e9ab738..8b383c2 100644
--- a/core/src/main/scala/kafka/server/ReplicaManager.scala
+++ b/core/src/main/scala/kafka/server/ReplicaManager.scala
@@ -108,16 +108,6 @@ case class FetchPartitionData(error: Errors = Errors.NONE,
                               lastStableOffset: Option[Long],
                               abortedTransactions: Option[List[AbortedTransaction]])
 
-object LogReadResult {
-  val UnknownLogReadResult = LogReadResult(info = FetchDataInfo(LogOffsetMetadata.UnknownOffsetMetadata, MemoryRecords.EMPTY),
-                                           highWatermark = -1L,
-                                           leaderLogStartOffset = -1L,
-                                           leaderLogEndOffset = -1L,
-                                           followerLogStartOffset = -1L,
-                                           fetchTimeMs = -1L,
-                                           readSize = -1,
-                                           lastStableOffset = None)
-}
 
 /**
  * Trait to represent the state of hosted partitions. We create a concrete (active) Partition
@@ -620,7 +610,7 @@ class ReplicaManager(val config: KafkaConfig,
             logManager.abortAndPauseCleaning(topicPartition)
 
             val initialFetchState = InitialFetchState(BrokerEndPoint(config.brokerId, "localhost", -1),
-              partition.getLeaderEpoch, futureReplica.highWatermark.messageOffset)
+              partition.getLeaderEpoch, futureReplica.highWatermark)
             replicaAlterLogDirsManager.addFetcherForPartitions(Map(topicPartition -> initialFetchState))
           }
 
@@ -685,7 +675,7 @@ class ReplicaManager(val config: KafkaConfig,
         if (isFuture)
           replica.logEndOffset - logEndOffset
         else
-          math.max(replica.highWatermark.messageOffset - logEndOffset, 0)
+          math.max(replica.highWatermark - logEndOffset, 0)
       case None =>
         // return -1L to indicate that the LEO lag is not available if the replica is not created or is offline
         DescribeLogDirsResponse.INVALID_OFFSET_LAG
@@ -849,7 +839,7 @@ class ReplicaManager(val config: KafkaConfig,
         hardMaxBytesLimit = hardMaxBytesLimit,
         readPartitionInfo = fetchInfos,
         quota = quota)
-      if (isFromFollower) updateFollowerLogReadResults(replicaId, result)
+      if (isFromFollower) updateFollowerFetchState(replicaId, result)
       else result
     }
 
@@ -967,14 +957,14 @@ class ReplicaManager(val config: KafkaConfig,
                  _: KafkaStorageException |
                  _: OffsetOutOfRangeException) =>
           LogReadResult(info = FetchDataInfo(LogOffsetMetadata.UnknownOffsetMetadata, MemoryRecords.EMPTY),
-                        highWatermark = -1L,
-                        leaderLogStartOffset = -1L,
-                        leaderLogEndOffset = -1L,
-                        followerLogStartOffset = -1L,
-                        fetchTimeMs = -1L,
-                        readSize = 0,
-                        lastStableOffset = None,
-                        exception = Some(e))
+            highWatermark = Log.UnknownOffset,
+            leaderLogStartOffset = Log.UnknownOffset,
+            leaderLogEndOffset = Log.UnknownOffset,
+            followerLogStartOffset = Log.UnknownOffset,
+            fetchTimeMs = -1L,
+            readSize = 0,
+            lastStableOffset = None,
+            exception = Some(e))
         case e: Throwable =>
           brokerTopicStats.topicStats(tp.topic).failedFetchRequestRate.mark()
           brokerTopicStats.allTopicsStats.failedFetchRequestRate.mark()
@@ -984,14 +974,14 @@ class ReplicaManager(val config: KafkaConfig,
             s"on partition $tp: $fetchInfo", e)
 
           LogReadResult(info = FetchDataInfo(LogOffsetMetadata.UnknownOffsetMetadata, MemoryRecords.EMPTY),
-                        highWatermark = -1L,
-                        leaderLogStartOffset = -1L,
-                        leaderLogEndOffset = -1L,
-                        followerLogStartOffset = -1L,
-                        fetchTimeMs = -1L,
-                        readSize = 0,
-                        lastStableOffset = None,
-                        exception = Some(e))
+            highWatermark = Log.UnknownOffset,
+            leaderLogStartOffset = Log.UnknownOffset,
+            leaderLogEndOffset = Log.UnknownOffset,
+            followerLogStartOffset = Log.UnknownOffset,
+            fetchTimeMs = -1L,
+            readSize = 0,
+            lastStableOffset = None,
+            exception = Some(e))
       }
     }
 
@@ -1161,7 +1151,7 @@ class ReplicaManager(val config: KafkaConfig,
               logManager.abortAndPauseCleaning(topicPartition)
 
               futureReplicasAndInitialOffset.put(topicPartition, InitialFetchState(leader,
-                partition.getLeaderEpoch, replica.highWatermark.messageOffset))
+                partition.getLeaderEpoch, replica.highWatermark))
             }
           }
         }
@@ -1356,7 +1346,7 @@ class ReplicaManager(val config: KafkaConfig,
         val partitionsToMakeFollowerWithLeaderAndOffset = partitionsToMakeFollower.map { partition =>
           val leader = metadataCache.getAliveBrokers.find(_.id == partition.leaderReplicaIdOpt.get).get
             .brokerEndPoint(config.interBrokerListenerName)
-          val fetchOffset = partition.localReplicaOrException.highWatermark.messageOffset
+          val fetchOffset = partition.localReplicaOrException.highWatermark
           partition.topicPartition -> InitialFetchState(leader, partition.getLeaderEpoch, fetchOffset)
        }.toMap
 
@@ -1394,33 +1384,41 @@ class ReplicaManager(val config: KafkaConfig,
   }
 
   /**
-   * Update the follower's fetch state in the leader based on the last fetch request and update `readResult`,
-   * if the follower replica is not recognized to be one of the assigned replicas. Do not update
-   * `readResult` otherwise, so that log start/end offset and high watermark is consistent with
+   * Update the follower's fetch state on the leader based on the last fetch request. If the follower
+   * replica is not recognized to be one of the assigned replicas, replace its `readResult` with an empty one.
+   * Otherwise, do not update `readResult` so that log start/end offset and high watermark are consistent with
    * records in fetch response. Log start/end offset and high watermark may change not only due to
    * this fetch request, e.g., rolling new log segment and removing old log segment may move log
    * start offset further than the last offset in the fetched records. The followers will get the
    * updated leader's state in the next fetch response.
    */
-  private def updateFollowerLogReadResults(replicaId: Int,
-                                           readResults: Seq[(TopicPartition, LogReadResult)]): Seq[(TopicPartition, LogReadResult)] = {
-    debug(s"Recording follower broker $replicaId log end offsets: $readResults")
+  private def updateFollowerFetchState(followerId: Int,
+                                       readResults: Seq[(TopicPartition, LogReadResult)]): Seq[(TopicPartition, LogReadResult)] = {
     readResults.map { case (topicPartition, readResult) =>
-      var updatedReadResult = readResult
-      nonOfflinePartition(topicPartition) match {
-        case Some(partition) =>
-          partition.getReplica(replicaId) match {
-            case Some(replica) =>
-              partition.updateReplicaLogReadResult(replica, readResult)
-            case None =>
-              warn(s"Leader $localBrokerId failed to record follower $replicaId's position " +
+      val updatedReadResult = if (readResult.error != Errors.NONE) {
+        debug(s"Skipping update of fetch state for follower $followerId since the " +
+          s"log read returned error ${readResult.error}")
+        readResult
+      } else {
+        nonOfflinePartition(topicPartition) match {
+          case Some(partition) =>
+            if (partition.updateFollowerFetchState(followerId,
+              followerFetchOffsetMetadata = readResult.info.fetchOffsetMetadata,
+              followerStartOffset = readResult.followerLogStartOffset,
+              followerFetchTimeMs = readResult.fetchTimeMs,
+              leaderEndOffset = readResult.leaderLogEndOffset)) {
+              readResult
+            } else {
+              warn(s"Leader $localBrokerId failed to record follower $followerId's position " +
                 s"${readResult.info.fetchOffsetMetadata.messageOffset} since the replica is not recognized to be " +
                 s"one of the assigned replicas ${partition.assignedReplicas.map(_.brokerId).mkString(",")} " +
                 s"for partition $topicPartition. Empty records will be returned for this partition.")
-              updatedReadResult = readResult.withEmptyFetchInfo
-          }
-        case None =>
-          warn(s"While recording the replica LEO, the partition $topicPartition hasn't been created.")
+              readResult.withEmptyFetchInfo
+            }
+          case None =>
+            warn(s"While recording the replica LEO, the partition $topicPartition hasn't been created.")
+            readResult
+        }
       }
       topicPartition -> updatedReadResult
     }
@@ -1442,7 +1440,7 @@ class ReplicaManager(val config: KafkaConfig,
     }.filter(_.log.isDefined).toBuffer
     val replicasByDir = replicas.groupBy(_.log.get.dir.getParent)
     for ((dir, reps) <- replicasByDir) {
-      val hwms = reps.map(r => r.topicPartition -> r.highWatermark.messageOffset).toMap
+      val hwms = reps.map(r => r.topicPartition -> r.highWatermark).toMap
       try {
         highWatermarkCheckpoints.get(dir).foreach(_.write(hwms))
       } catch {
diff --git a/core/src/test/scala/integration/kafka/api/ConsumerBounceTest.scala b/core/src/test/scala/integration/kafka/api/ConsumerBounceTest.scala
index ae6fc00..16a117b 100644
--- a/core/src/test/scala/integration/kafka/api/ConsumerBounceTest.scala
+++ b/core/src/test/scala/integration/kafka/api/ConsumerBounceTest.scala
@@ -132,7 +132,7 @@ class ConsumerBounceTest extends AbstractConsumerTest with Logging {
 
     // wait until all the followers have synced the last HW with leader
     TestUtils.waitUntilTrue(() => servers.forall(server =>
-      server.replicaManager.localReplica(tp).get.highWatermark.messageOffset == numRecords
+      server.replicaManager.localReplica(tp).get.highWatermark == numRecords
     ), "Failed to update high watermark for followers after timeout")
 
     val scheduler = new BounceBrokerScheduler(numIters)
diff --git a/core/src/test/scala/unit/kafka/admin/ReassignPartitionsClusterTest.scala b/core/src/test/scala/unit/kafka/admin/ReassignPartitionsClusterTest.scala
index c5763ad..d604cd0 100644
--- a/core/src/test/scala/unit/kafka/admin/ReassignPartitionsClusterTest.scala
+++ b/core/src/test/scala/unit/kafka/admin/ReassignPartitionsClusterTest.scala
@@ -104,10 +104,10 @@ class ReassignPartitionsClusterTest extends ZooKeeperTestHarness with Logging {
     )
 
     assertEquals(100, newLeaderServer.replicaManager.localReplicaOrException(topicPartition)
-      .highWatermark.messageOffset)
+      .highWatermark)
     val newFollowerServer = servers.find(_.config.brokerId == 102).get
     TestUtils.waitUntilTrue(() => newFollowerServer.replicaManager.localReplicaOrException(topicPartition)
-      .highWatermark.messageOffset == 100,
+      .highWatermark == 100,
       "partition follower's highWatermark should be 100")
   }
 
diff --git a/core/src/test/scala/unit/kafka/cluster/PartitionTest.scala b/core/src/test/scala/unit/kafka/cluster/PartitionTest.scala
index 2b12a04..650dce4 100644
--- a/core/src/test/scala/unit/kafka/cluster/PartitionTest.scala
+++ b/core/src/test/scala/unit/kafka/cluster/PartitionTest.scala
@@ -487,24 +487,20 @@ class PartitionTest {
 
     // after makeLeader(() call, partition should know about all the replicas
     val leaderReplica = partition.getReplica(leader).get
-    val follower1Replica = partition.getReplica(follower1).get
-    val follower2Replica = partition.getReplica(follower2).get
 
     // append records with initial leader epoch
     partition.appendRecordsToLeader(batch1, isFromClient = true)
     partition.appendRecordsToLeader(batch2, isFromClient = true)
-    assertEquals("Expected leader's HW not move", leaderReplica.logStartOffset, leaderReplica.highWatermark.messageOffset)
+    assertEquals("Expected leader's HW not move", leaderReplica.logStartOffset, leaderReplica.highWatermark)
 
     // let the follower in ISR move leader's HW to move further but below LEO
-    def readResult(fetchInfo: FetchDataInfo, leaderReplica: Replica): LogReadResult = {
-      LogReadResult(info = fetchInfo,
-        highWatermark = leaderReplica.highWatermark.messageOffset,
-        leaderLogStartOffset = leaderReplica.logStartOffset,
-        leaderLogEndOffset = leaderReplica.logEndOffset,
-        followerLogStartOffset = 0,
-        fetchTimeMs = time.milliseconds,
-        readSize = 10240,
-        lastStableOffset = None)
+    def updateFollowerFetchState(followerId: Int, fetchOffsetMetadata: LogOffsetMetadata): Unit = {
+      partition.updateFollowerFetchState(
+        followerId,
+        followerFetchOffsetMetadata = fetchOffsetMetadata,
+        followerStartOffset = 0L,
+        followerFetchTimeMs = time.milliseconds(),
+        leaderEndOffset = leaderReplica.logEndOffset)
     }
 
     def fetchOffsetsForTimestamp(timestamp: Long, isolation: Option[IsolationLevel]): Either[ApiException, Option[TimestampAndOffset]] = {
@@ -524,19 +520,14 @@ class PartitionTest {
       List(leader, follower2, follower1), 1)))
       .thenReturn(Some(2))
 
-    // Update follower 1
-    partition.updateReplicaLogReadResult(
-      follower1Replica, readResult(FetchDataInfo(LogOffsetMetadata(0), batch1), leaderReplica))
-    partition.updateReplicaLogReadResult(
-      follower1Replica, readResult(FetchDataInfo(LogOffsetMetadata(2), batch2), leaderReplica))
+    updateFollowerFetchState(follower1, LogOffsetMetadata(0))
+    updateFollowerFetchState(follower1, LogOffsetMetadata(2))
 
-    partition.updateReplicaLogReadResult(
-      follower2Replica, readResult(FetchDataInfo(LogOffsetMetadata(0), batch1), leaderReplica))
-    partition.updateReplicaLogReadResult(
-      follower2Replica, readResult(FetchDataInfo(LogOffsetMetadata(2), batch2), leaderReplica))
+    updateFollowerFetchState(follower2, LogOffsetMetadata(0))
+    updateFollowerFetchState(follower2, LogOffsetMetadata(2))
 
     // At this point, the leader has gotten 5 writes, but followers have only fetched two
-    assertEquals(2, partition.localReplica.get.highWatermark.messageOffset)
+    assertEquals(2, partition.localReplica.get.highWatermark)
 
     // Get the LEO
     fetchOffsetsForTimestamp(ListOffsetRequest.LATEST_TIMESTAMP, None) match {
@@ -609,10 +600,8 @@ class PartitionTest {
       .thenReturn(Some(2))
 
     // Next fetch from replicas, HW is moved up to 5 (ahead of the LEO)
-    partition.updateReplicaLogReadResult(
-      follower1Replica, readResult(FetchDataInfo(LogOffsetMetadata(5), MemoryRecords.EMPTY), leaderReplica))
-    partition.updateReplicaLogReadResult(
-      follower2Replica, readResult(FetchDataInfo(LogOffsetMetadata(5), MemoryRecords.EMPTY), leaderReplica))
+    updateFollowerFetchState(follower1, LogOffsetMetadata(5))
+    updateFollowerFetchState(follower2, LogOffsetMetadata(5))
 
     // Error goes away
     fetchOffsetsForTimestamp(ListOffsetRequest.LATEST_TIMESTAMP, Some(IsolationLevel.READ_UNCOMMITTED)) match {
@@ -747,7 +736,7 @@ class PartitionTest {
     assertEquals(0L, fetchLatestOffset(isolationLevel = Some(IsolationLevel.READ_UNCOMMITTED)).offset)
     assertEquals(0L, fetchLatestOffset(isolationLevel = Some(IsolationLevel.READ_COMMITTED)).offset)
 
-    replica.highWatermark = LogOffsetMetadata(1L)
+    replica.highWatermark = 1L
 
     assertEquals(3L, fetchLatestOffset(isolationLevel = None).offset)
     assertEquals(1L, fetchLatestOffset(isolationLevel = Some(IsolationLevel.READ_UNCOMMITTED)).offset)
@@ -822,30 +811,25 @@ class PartitionTest {
 
     // after makeLeader(() call, partition should know about all the replicas
     val leaderReplica = partition.getReplica(leader).get
-    val follower1Replica = partition.getReplica(follower1).get
-    val follower2Replica = partition.getReplica(follower2).get
 
     // append records with initial leader epoch
     val lastOffsetOfFirstBatch = partition.appendRecordsToLeader(batch1, isFromClient = true).lastOffset
     partition.appendRecordsToLeader(batch2, isFromClient = true)
-    assertEquals("Expected leader's HW not move", leaderReplica.logStartOffset, leaderReplica.highWatermark.messageOffset)
+    assertEquals("Expected leader's HW not move", leaderReplica.logStartOffset, leaderReplica.highWatermark)
 
     // let the follower in ISR move leader's HW to move further but below LEO
-    def readResult(fetchInfo: FetchDataInfo, leaderReplica: Replica): LogReadResult = {
-      LogReadResult(info = fetchInfo,
-                    highWatermark = leaderReplica.highWatermark.messageOffset,
-                    leaderLogStartOffset = leaderReplica.logStartOffset,
-                    leaderLogEndOffset = leaderReplica.logEndOffset,
-                    followerLogStartOffset = 0,
-                    fetchTimeMs = time.milliseconds,
-                    readSize = 10240,
-                    lastStableOffset = None)
+    def updateFollowerFetchState(followerId: Int, fetchOffsetMetadata: LogOffsetMetadata): Unit = {
+      partition.updateFollowerFetchState(
+        followerId,
+        followerFetchOffsetMetadata = fetchOffsetMetadata,
+        followerStartOffset = 0L,
+        followerFetchTimeMs = time.milliseconds(),
+        leaderEndOffset = leaderReplica.logEndOffset)
     }
-    partition.updateReplicaLogReadResult(
-      follower2Replica, readResult(FetchDataInfo(LogOffsetMetadata(0), batch1), leaderReplica))
-    partition.updateReplicaLogReadResult(
-      follower2Replica, readResult(FetchDataInfo(LogOffsetMetadata(lastOffsetOfFirstBatch), batch2), leaderReplica))
-    assertEquals("Expected leader's HW", lastOffsetOfFirstBatch, leaderReplica.highWatermark.messageOffset)
+
+    updateFollowerFetchState(follower2, LogOffsetMetadata(0))
+    updateFollowerFetchState(follower2, LogOffsetMetadata(lastOffsetOfFirstBatch))
+    assertEquals("Expected leader's HW", lastOffsetOfFirstBatch, leaderReplica.highWatermark)
 
     // current leader becomes follower and then leader again (without any new records appended)
     val followerState = new LeaderAndIsrRequest.PartitionState(controllerEpoch, follower2, leaderEpoch + 1, isr, 1,
@@ -862,18 +846,15 @@ class PartitionTest {
     partition.appendRecordsToLeader(batch3, isFromClient = true)
 
     // fetch from follower not in ISR from log start offset should not add this follower to ISR
-    partition.updateReplicaLogReadResult(follower1Replica,
-                                         readResult(FetchDataInfo(LogOffsetMetadata(0), batch1), leaderReplica))
-    partition.updateReplicaLogReadResult(follower1Replica,
-                                         readResult(FetchDataInfo(LogOffsetMetadata(lastOffsetOfFirstBatch), batch2), leaderReplica))
+    updateFollowerFetchState(follower1, LogOffsetMetadata(0))
+    updateFollowerFetchState(follower1, LogOffsetMetadata(lastOffsetOfFirstBatch))
     assertEquals("ISR", Set[Integer](leader, follower2), partition.inSyncReplicas.map(_.brokerId))
 
     // fetch from the follower not in ISR from start offset of the current leader epoch should
     // add this follower to ISR
     when(stateStore.expandIsr(controllerEpoch, new LeaderAndIsr(leader, leaderEpoch + 2,
       List(leader, follower2, follower1), 1))).thenReturn(Some(2))
-    partition.updateReplicaLogReadResult(follower1Replica,
-      readResult(FetchDataInfo(LogOffsetMetadata(currentLeaderEpochStartOffset), batch3), leaderReplica))
+    updateFollowerFetchState(follower1, LogOffsetMetadata(currentLeaderEpochStartOffset))
     assertEquals("ISR", Set[Integer](leader, follower1, follower2), partition.inSyncReplicas.map(_.brokerId))
   }
 
@@ -1013,18 +994,350 @@ class PartitionTest {
   }
 
   @Test
+  def testUpdateFollowerFetchState(): Unit = {
+    val log = logManager.getOrCreateLog(topicPartition, logConfig)
+    seedLogData(log, numRecords = 6, leaderEpoch = 4)
+
+    val controllerId = 0
+    val controllerEpoch = 0
+    val leaderEpoch = 5
+    val remoteBrokerId = brokerId + 1
+    val replicas = List[Integer](brokerId, remoteBrokerId).asJava
+    val isr = replicas
+
+    doNothing().when(delayedOperations).checkAndCompleteFetch()
+
+    partition.getOrCreateReplica(brokerId, isNew = false, offsetCheckpoints)
+
+    val initializeTimeMs = time.milliseconds()
+    assertTrue("Expected become leader transition to succeed",
+      partition.makeLeader(controllerId, new LeaderAndIsrRequest.PartitionState(controllerEpoch, brokerId,
+        leaderEpoch, isr, 1, replicas, true), 0, offsetCheckpoints))
+
+    val remoteReplica = partition.getReplica(remoteBrokerId).get
+    assertEquals(initializeTimeMs, remoteReplica.lastCaughtUpTimeMs)
+    assertEquals(LogOffsetMetadata.UnknownOffsetMetadata.messageOffset, remoteReplica.logEndOffset)
+    assertEquals(Log.UnknownOffset, remoteReplica.logStartOffset)
+
+    time.sleep(500)
+
+    partition.updateFollowerFetchState(remoteBrokerId,
+      followerFetchOffsetMetadata = LogOffsetMetadata(3),
+      followerStartOffset = 0L,
+      followerFetchTimeMs = time.milliseconds(),
+      leaderEndOffset = 6L)
+
+    assertEquals(initializeTimeMs, remoteReplica.lastCaughtUpTimeMs)
+    assertEquals(3L, remoteReplica.logEndOffset)
+    assertEquals(0L, remoteReplica.logStartOffset)
+
+    time.sleep(500)
+
+    partition.updateFollowerFetchState(remoteBrokerId,
+      followerFetchOffsetMetadata = LogOffsetMetadata(6L),
+      followerStartOffset = 0L,
+      followerFetchTimeMs = time.milliseconds(),
+      leaderEndOffset = 6L)
+
+    assertEquals(time.milliseconds(), remoteReplica.lastCaughtUpTimeMs)
+    assertEquals(6L, remoteReplica.logEndOffset)
+    assertEquals(0L, remoteReplica.logStartOffset)
+  }
+
+  @Test
+  def testIsrExpansion(): Unit = {
+    val log = logManager.getOrCreateLog(topicPartition, logConfig)
+    seedLogData(log, numRecords = 10, leaderEpoch = 4)
+
+    val controllerId = 0
+    val controllerEpoch = 0
+    val leaderEpoch = 5
+    val remoteBrokerId = brokerId + 1
+    val replicas = List[Integer](brokerId, remoteBrokerId).asJava
+    val isr = List[Integer](brokerId).asJava
+
+    doNothing().when(delayedOperations).checkAndCompleteFetch()
+
+    partition.getOrCreateReplica(brokerId, isNew = false, offsetCheckpoints)
+    assertTrue("Expected become leader transition to succeed",
+      partition.makeLeader(controllerId, new LeaderAndIsrRequest.PartitionState(controllerEpoch, brokerId,
+        leaderEpoch, isr, 1, replicas, true), 0, offsetCheckpoints))
+    assertEquals(Set(brokerId), partition.inSyncReplicas.map(_.brokerId))
+
+    val remoteReplica = partition.getReplica(remoteBrokerId).get
+    assertEquals(LogOffsetMetadata.UnknownOffsetMetadata.messageOffset, remoteReplica.logEndOffset)
+    assertEquals(Log.UnknownOffset, remoteReplica.logStartOffset)
+
+    partition.updateFollowerFetchState(remoteBrokerId,
+      followerFetchOffsetMetadata = LogOffsetMetadata(3),
+      followerStartOffset = 0L,
+      followerFetchTimeMs = time.milliseconds(),
+      leaderEndOffset = 6L)
+
+    assertEquals(Set(brokerId), partition.inSyncReplicas.map(_.brokerId))
+    assertEquals(3L, remoteReplica.logEndOffset)
+    assertEquals(0L, remoteReplica.logStartOffset)
+
+    // The next update should bring the follower back into the ISR
+    val updatedLeaderAndIsr = LeaderAndIsr(
+      leader = brokerId,
+      leaderEpoch = leaderEpoch,
+      isr = List(brokerId, remoteBrokerId),
+      zkVersion = 1)
+    when(stateStore.expandIsr(controllerEpoch, updatedLeaderAndIsr)).thenReturn(Some(2))
+
+    partition.updateFollowerFetchState(remoteBrokerId,
+      followerFetchOffsetMetadata = LogOffsetMetadata(10),
+      followerStartOffset = 0L,
+      followerFetchTimeMs = time.milliseconds(),
+      leaderEndOffset = 6L)
+
+    assertEquals(Set(brokerId, remoteBrokerId), partition.inSyncReplicas.map(_.brokerId))
+    assertEquals(10L, remoteReplica.logEndOffset)
+    assertEquals(0L, remoteReplica.logStartOffset)
+  }
+
+  @Test
+  def testIsrNotExpandedIfUpdateFails(): Unit = {
+    val log = logManager.getOrCreateLog(topicPartition, logConfig)
+    seedLogData(log, numRecords = 10, leaderEpoch = 4)
+
+    val controllerId = 0
+    val controllerEpoch = 0
+    val leaderEpoch = 5
+    val remoteBrokerId = brokerId + 1
+    val replicas = List[Integer](brokerId, remoteBrokerId).asJava
+    val isr = List[Integer](brokerId).asJava
+
+    doNothing().when(delayedOperations).checkAndCompleteFetch()
+
+    partition.getOrCreateReplica(brokerId, isNew = false, offsetCheckpoints)
+    assertTrue("Expected become leader transition to succeed",
+      partition.makeLeader(controllerId, new LeaderAndIsrRequest.PartitionState(controllerEpoch, brokerId,
+        leaderEpoch, isr, 1, replicas, true), 0, offsetCheckpoints))
+    assertEquals(Set(brokerId), partition.inSyncReplicas.map(_.brokerId))
+
+    val remoteReplica = partition.getReplica(remoteBrokerId).get
+    assertEquals(LogOffsetMetadata.UnknownOffsetMetadata.messageOffset, remoteReplica.logEndOffset)
+    assertEquals(Log.UnknownOffset, remoteReplica.logStartOffset)
+
+    // Mock the expected ISR update failure
+    val updatedLeaderAndIsr = LeaderAndIsr(
+      leader = brokerId,
+      leaderEpoch = leaderEpoch,
+      isr = List(brokerId, remoteBrokerId),
+      zkVersion = 1)
+    when(stateStore.expandIsr(controllerEpoch, updatedLeaderAndIsr)).thenReturn(None)
+
+    partition.updateFollowerFetchState(remoteBrokerId,
+      followerFetchOffsetMetadata = LogOffsetMetadata(10),
+      followerStartOffset = 0L,
+      followerFetchTimeMs = time.milliseconds(),
+      leaderEndOffset = 10L)
+
+    // Follower state is updated, but the ISR has not expanded
+    assertEquals(Set(brokerId), partition.inSyncReplicas.map(_.brokerId))
+    assertEquals(10L, remoteReplica.logEndOffset)
+    assertEquals(0L, remoteReplica.logStartOffset)
+  }
+
+  @Test
+  def testMaybeShrinkIsr(): Unit = {
+    val log = logManager.getOrCreateLog(topicPartition, logConfig)
+    seedLogData(log, numRecords = 10, leaderEpoch = 4)
+
+    val controllerId = 0
+    val controllerEpoch = 0
+    val leaderEpoch = 5
+    val remoteBrokerId = brokerId + 1
+    val replicas = List[Integer](brokerId, remoteBrokerId).asJava
+    val isr = List[Integer](brokerId, remoteBrokerId).asJava
+
+    doNothing().when(delayedOperations).checkAndCompleteFetch()
+
+    val initializeTimeMs = time.milliseconds()
+    partition.getOrCreateReplica(brokerId, isNew = false, offsetCheckpoints)
+    assertTrue("Expected become leader transition to succeed",
+      partition.makeLeader(controllerId, new LeaderAndIsrRequest.PartitionState(controllerEpoch, brokerId,
+        leaderEpoch, isr, 1, replicas, true), 0, offsetCheckpoints))
+    assertEquals(Set(brokerId, remoteBrokerId), partition.inSyncReplicas.map(_.brokerId))
+    assertEquals(0L, partition.localReplicaOrException.highWatermark)
+
+    val remoteReplica = partition.getReplica(remoteBrokerId).get
+    assertEquals(initializeTimeMs, remoteReplica.lastCaughtUpTimeMs)
+    assertEquals(LogOffsetMetadata.UnknownOffsetMetadata.messageOffset, remoteReplica.logEndOffset)
+    assertEquals(Log.UnknownOffset, remoteReplica.logStartOffset)
+
+    // On initialization, the replica is considered caught up and should not be removed
+    partition.maybeShrinkIsr(10000)
+    assertEquals(Set(brokerId, remoteBrokerId), partition.inSyncReplicas.map(_.brokerId))
+
+    // If enough time passes without a fetch update, the ISR should shrink
+    time.sleep(10001)
+    val updatedLeaderAndIsr = LeaderAndIsr(
+      leader = brokerId,
+      leaderEpoch = leaderEpoch,
+      isr = List(brokerId),
+      zkVersion = 1)
+    when(stateStore.shrinkIsr(controllerEpoch, updatedLeaderAndIsr)).thenReturn(Some(2))
+
+    partition.maybeShrinkIsr(10000)
+    assertEquals(Set(brokerId), partition.inSyncReplicas.map(_.brokerId))
+    assertEquals(10L, partition.localReplicaOrException.highWatermark)
+  }
+
+  @Test
+  def testShouldNotShrinkIsrIfPreviousFetchIsCaughtUp(): Unit = {
+    val log = logManager.getOrCreateLog(topicPartition, logConfig)
+    seedLogData(log, numRecords = 10, leaderEpoch = 4)
+
+    val controllerId = 0
+    val controllerEpoch = 0
+    val leaderEpoch = 5
+    val remoteBrokerId = brokerId + 1
+    val replicas = List[Integer](brokerId, remoteBrokerId).asJava
+    val isr = List[Integer](brokerId, remoteBrokerId).asJava
+
+    doNothing().when(delayedOperations).checkAndCompleteFetch()
+
+    val initializeTimeMs = time.milliseconds()
+    partition.getOrCreateReplica(brokerId, isNew = false, offsetCheckpoints)
+    assertTrue("Expected become leader transition to succeed",
+      partition.makeLeader(controllerId, new LeaderAndIsrRequest.PartitionState(controllerEpoch, brokerId,
+        leaderEpoch, isr, 1, replicas, true), 0, offsetCheckpoints))
+    assertEquals(Set(brokerId, remoteBrokerId), partition.inSyncReplicas.map(_.brokerId))
+    assertEquals(0L, partition.localReplicaOrException.highWatermark)
+
+    val remoteReplica = partition.getReplica(remoteBrokerId).get
+    assertEquals(initializeTimeMs, remoteReplica.lastCaughtUpTimeMs)
+    assertEquals(LogOffsetMetadata.UnknownOffsetMetadata.messageOffset, remoteReplica.logEndOffset)
+    assertEquals(Log.UnknownOffset, remoteReplica.logStartOffset)
+
+    // There is a short delay before the first fetch. The follower is not yet caught up to the log end.
+    time.sleep(5000)
+    val firstFetchTimeMs = time.milliseconds()
+    partition.updateFollowerFetchState(remoteBrokerId,
+      followerFetchOffsetMetadata = LogOffsetMetadata(5),
+      followerStartOffset = 0L,
+      followerFetchTimeMs = firstFetchTimeMs,
+      leaderEndOffset = 10L)
+    assertEquals(initializeTimeMs, remoteReplica.lastCaughtUpTimeMs)
+    assertEquals(5L, partition.localReplicaOrException.highWatermark)
+    assertEquals(5L, remoteReplica.logEndOffset)
+    assertEquals(0L, remoteReplica.logStartOffset)
+
+    // Some new data is appended, but the follower catches up to the old end offset.
+    // The total elapsed time from initialization is larger than the max allowed replica lag.
+    time.sleep(5001)
+    seedLogData(log, numRecords = 5, leaderEpoch = leaderEpoch)
+    partition.updateFollowerFetchState(remoteBrokerId,
+      followerFetchOffsetMetadata = LogOffsetMetadata(10),
+      followerStartOffset = 0L,
+      followerFetchTimeMs = time.milliseconds(),
+      leaderEndOffset = 15L)
+    assertEquals(firstFetchTimeMs, remoteReplica.lastCaughtUpTimeMs)
+    assertEquals(10L, partition.localReplicaOrException.highWatermark)
+    assertEquals(10L, remoteReplica.logEndOffset)
+    assertEquals(0L, remoteReplica.logStartOffset)
+
+    // The ISR should not be shrunk because the follower has caught up with the leader at the
+    // time of the first fetch.
+    partition.maybeShrinkIsr(10000)
+    assertEquals(Set(brokerId, remoteBrokerId), partition.inSyncReplicas.map(_.brokerId))
+  }
+
+  @Test
+  def testShouldNotShrinkIsrIfFollowerCaughtUpToLogEnd(): Unit = {
+    val log = logManager.getOrCreateLog(topicPartition, logConfig)
+    seedLogData(log, numRecords = 10, leaderEpoch = 4)
+
+    val controllerId = 0
+    val controllerEpoch = 0
+    val leaderEpoch = 5
+    val remoteBrokerId = brokerId + 1
+    val replicas = List[Integer](brokerId, remoteBrokerId).asJava
+    val isr = List[Integer](brokerId, remoteBrokerId).asJava
+
+    doNothing().when(delayedOperations).checkAndCompleteFetch()
+
+    val initializeTimeMs = time.milliseconds()
+    partition.getOrCreateReplica(brokerId, isNew = false, offsetCheckpoints)
+    assertTrue("Expected become leader transition to succeed",
+      partition.makeLeader(controllerId, new LeaderAndIsrRequest.PartitionState(controllerEpoch, brokerId,
+        leaderEpoch, isr, 1, replicas, true), 0, offsetCheckpoints))
+    assertEquals(Set(brokerId, remoteBrokerId), partition.inSyncReplicas.map(_.brokerId))
+    assertEquals(0L, partition.localReplicaOrException.highWatermark)
+
+    val remoteReplica = partition.getReplica(remoteBrokerId).get
+    assertEquals(initializeTimeMs, remoteReplica.lastCaughtUpTimeMs)
+    assertEquals(LogOffsetMetadata.UnknownOffsetMetadata.messageOffset, remoteReplica.logEndOffset)
+    assertEquals(Log.UnknownOffset, remoteReplica.logStartOffset)
+
+    // The follower catches up to the log end immediately.
+    partition.updateFollowerFetchState(remoteBrokerId,
+      followerFetchOffsetMetadata = LogOffsetMetadata(10),
+      followerStartOffset = 0L,
+      followerFetchTimeMs = time.milliseconds(),
+      leaderEndOffset = 10L)
+    assertEquals(initializeTimeMs, remoteReplica.lastCaughtUpTimeMs)
+    assertEquals(10L, partition.localReplicaOrException.highWatermark)
+    assertEquals(10L, remoteReplica.logEndOffset)
+    assertEquals(0L, remoteReplica.logStartOffset)
+
+    // Sleep longer than the max allowed follower lag
+    time.sleep(10001)
+
+    // The ISR should not be shrunk because the follower is caught up to the leader's log end
+    partition.maybeShrinkIsr(10000)
+    assertEquals(Set(brokerId, remoteBrokerId), partition.inSyncReplicas.map(_.brokerId))
+  }
+
+  @Test
+  def testIsrNotShrunkIfUpdateFails(): Unit = {
+    val log = logManager.getOrCreateLog(topicPartition, logConfig)
+    seedLogData(log, numRecords = 10, leaderEpoch = 4)
+
+    val controllerId = 0
+    val controllerEpoch = 0
+    val leaderEpoch = 5
+    val remoteBrokerId = brokerId + 1
+    val replicas = List[Integer](brokerId, remoteBrokerId).asJava
+    val isr = List[Integer](brokerId, remoteBrokerId).asJava
+
+    doNothing().when(delayedOperations).checkAndCompleteFetch()
+
+    val initializeTimeMs = time.milliseconds()
+    partition.getOrCreateReplica(brokerId, isNew = false, offsetCheckpoints)
+    assertTrue("Expected become leader transition to succeed",
+      partition.makeLeader(controllerId, new LeaderAndIsrRequest.PartitionState(controllerEpoch, brokerId,
+        leaderEpoch, isr, 1, replicas, true), 0, offsetCheckpoints))
+    assertEquals(Set(brokerId, remoteBrokerId), partition.inSyncReplicas.map(_.brokerId))
+    assertEquals(0L, partition.localReplicaOrException.highWatermark)
+
+    val remoteReplica = partition.getReplica(remoteBrokerId).get
+    assertEquals(initializeTimeMs, remoteReplica.lastCaughtUpTimeMs)
+    assertEquals(LogOffsetMetadata.UnknownOffsetMetadata.messageOffset, remoteReplica.logEndOffset)
+    assertEquals(Log.UnknownOffset, remoteReplica.logStartOffset)
+
+    time.sleep(10001)
+
+    // Mock the expected ISR update failure
+    val updatedLeaderAndIsr = LeaderAndIsr(
+      leader = brokerId,
+      leaderEpoch = leaderEpoch,
+      isr = List(brokerId),
+      zkVersion = 1)
+    when(stateStore.shrinkIsr(controllerEpoch, updatedLeaderAndIsr)).thenReturn(None)
+
+    partition.maybeShrinkIsr(10000)
+    assertEquals(Set(brokerId, remoteBrokerId), partition.inSyncReplicas.map(_.brokerId))
+    assertEquals(0L, partition.localReplicaOrException.highWatermark)
+  }
+
+  @Test
   def testUseCheckpointToInitializeHighWatermark(): Unit = {
     val log = logManager.getOrCreateLog(topicPartition, logConfig)
-    log.appendAsLeader(MemoryRecords.withRecords(0L, CompressionType.NONE, 0,
-      new SimpleRecord("k1".getBytes, "v1".getBytes),
-      new SimpleRecord("k2".getBytes, "v2".getBytes),
-      new SimpleRecord("k3".getBytes, "v3".getBytes),
-      new SimpleRecord("k4".getBytes, "v4".getBytes)
-    ), leaderEpoch = 0)
-    log.appendAsLeader(MemoryRecords.withRecords(0L, CompressionType.NONE, 5,
-      new SimpleRecord("k5".getBytes, "v5".getBytes),
-      new SimpleRecord("k5".getBytes, "v5".getBytes)
-    ), leaderEpoch = 5)
+    seedLogData(log, numRecords = 6, leaderEpoch = 5)
 
     when(offsetCheckpoints.fetch(logDir1.getAbsolutePath, topicPartition))
       .thenReturn(Some(4L))
@@ -1035,7 +1348,7 @@ class PartitionTest {
     val leaderState = new LeaderAndIsrRequest.PartitionState(controllerEpoch, brokerId,
       6, replicas, 1, replicas, false)
     partition.makeLeader(controllerId, leaderState, 0, offsetCheckpoints)
-    assertEquals(4, partition.localReplicaOrException.highWatermark.messageOffset)
+    assertEquals(4, partition.localReplicaOrException.highWatermark)
   }
 
   @Test
@@ -1061,4 +1374,12 @@ class PartitionTest {
     assertEquals(Set(), Metrics.defaultRegistry().allMetrics().asScala.keySet.filter(_.getType == "Partition"))
   }
 
+  private def seedLogData(log: Log, numRecords: Int, leaderEpoch: Int): Unit = {
+    for (i <- 0 until numRecords) {
+      val records = MemoryRecords.withRecords(0L, CompressionType.NONE, leaderEpoch,
+        new SimpleRecord(s"k$i".getBytes, s"v$i".getBytes))
+      log.appendAsLeader(records, leaderEpoch)
+    }
+  }
+
 }
diff --git a/core/src/test/scala/unit/kafka/cluster/ReplicaTest.scala b/core/src/test/scala/unit/kafka/cluster/ReplicaTest.scala
index d63901a..bacdc81 100644
--- a/core/src/test/scala/unit/kafka/cluster/ReplicaTest.scala
+++ b/core/src/test/scala/unit/kafka/cluster/ReplicaTest.scala
@@ -76,7 +76,7 @@ class ReplicaTest {
       initialHighWatermarkValue = initialHighWatermark,
       log = Some(log))
 
-    assertEquals(initialHighWatermark, replica.highWatermark.messageOffset)
+    assertEquals(initialHighWatermark, replica.highWatermark)
 
     val expiredTimestamp = time.milliseconds() - 1000
     for (i <- 0 until 100) {
@@ -100,13 +100,13 @@ class ReplicaTest {
 
     // ensure we have at least a few segments so the test case is not trivial
     assertTrue(log.numberOfSegments > 5)
-    assertEquals(0L, replica.highWatermark.messageOffset)
+    assertEquals(0L, replica.highWatermark)
     assertEquals(0L, replica.logStartOffset)
     assertEquals(100L, replica.logEndOffset)
 
     for (hw <- 0 to 100) {
-      replica.highWatermark = new LogOffsetMetadata(hw)
-      assertEquals(hw, replica.highWatermark.messageOffset)
+      replica.highWatermark = hw
+      assertEquals(hw, replica.highWatermark)
       log.deleteOldSegments()
       assertTrue(replica.logStartOffset <= hw)
 
@@ -134,7 +134,7 @@ class ReplicaTest {
       log.appendAsLeader(records, leaderEpoch = 0)
     }
 
-    replica.highWatermark = new LogOffsetMetadata(25L)
+    replica.highWatermark = 25L
     replica.maybeIncrementLogStartOffset(26L)
   }
 
diff --git a/core/src/test/scala/unit/kafka/server/HighwatermarkPersistenceTest.scala b/core/src/test/scala/unit/kafka/server/HighwatermarkPersistenceTest.scala
index 3da22bb..61c521b 100755
--- a/core/src/test/scala/unit/kafka/server/HighwatermarkPersistenceTest.scala
+++ b/core/src/test/scala/unit/kafka/server/HighwatermarkPersistenceTest.scala
@@ -81,12 +81,12 @@ class HighwatermarkPersistenceTest {
       partition0.addReplicaIfNotExists(followerReplicaPartition0)
       replicaManager.checkpointHighWatermarks()
       fooPartition0Hw = hwmFor(replicaManager, topic, 0)
-      assertEquals(leaderReplicaPartition0.highWatermark.messageOffset, fooPartition0Hw)
+      assertEquals(leaderReplicaPartition0.highWatermark, fooPartition0Hw)
       // set the high watermark for local replica
-      partition0.localReplica.get.highWatermark = new LogOffsetMetadata(5L)
+      partition0.localReplica.get.highWatermark = 5L
       replicaManager.checkpointHighWatermarks()
       fooPartition0Hw = hwmFor(replicaManager, topic, 0)
-      assertEquals(leaderReplicaPartition0.highWatermark.messageOffset, fooPartition0Hw)
+      assertEquals(leaderReplicaPartition0.highWatermark, fooPartition0Hw)
       EasyMock.verify(zkClient)
     } finally {
       // shutdown the replica manager upon test completion
@@ -125,12 +125,12 @@ class HighwatermarkPersistenceTest {
       topic1Partition0.addReplicaIfNotExists(leaderReplicaTopic1Partition0)
       replicaManager.checkpointHighWatermarks()
       topic1Partition0Hw = hwmFor(replicaManager, topic1, 0)
-      assertEquals(leaderReplicaTopic1Partition0.highWatermark.messageOffset, topic1Partition0Hw)
+      assertEquals(leaderReplicaTopic1Partition0.highWatermark, topic1Partition0Hw)
       // set the high watermark for local replica
-      topic1Partition0.localReplica.get.highWatermark = new LogOffsetMetadata(5L)
+      topic1Partition0.localReplica.get.highWatermark = 5L
       replicaManager.checkpointHighWatermarks()
       topic1Partition0Hw = hwmFor(replicaManager, topic1, 0)
-      assertEquals(5L, leaderReplicaTopic1Partition0.highWatermark.messageOffset)
+      assertEquals(5L, leaderReplicaTopic1Partition0.highWatermark)
       assertEquals(5L, topic1Partition0Hw)
       // add another partition and set highwatermark
       val t2p0 = new TopicPartition(topic2, 0)
@@ -142,13 +142,13 @@ class HighwatermarkPersistenceTest {
       topic2Partition0.addReplicaIfNotExists(leaderReplicaTopic2Partition0)
       replicaManager.checkpointHighWatermarks()
       var topic2Partition0Hw = hwmFor(replicaManager, topic2, 0)
-      assertEquals(leaderReplicaTopic2Partition0.highWatermark.messageOffset, topic2Partition0Hw)
+      assertEquals(leaderReplicaTopic2Partition0.highWatermark, topic2Partition0Hw)
       // set the highwatermark for local replica
-      topic2Partition0.localReplica.get.highWatermark = new LogOffsetMetadata(15L)
-      assertEquals(15L, leaderReplicaTopic2Partition0.highWatermark.messageOffset)
+      topic2Partition0.localReplica.get.highWatermark = 15L
+      assertEquals(15L, leaderReplicaTopic2Partition0.highWatermark)
       // change the highwatermark for topic1
-      topic1Partition0.localReplica.get.highWatermark = new LogOffsetMetadata(10L)
-      assertEquals(10L, leaderReplicaTopic1Partition0.highWatermark.messageOffset)
+      topic1Partition0.localReplica.get.highWatermark = 10L
+      assertEquals(10L, leaderReplicaTopic1Partition0.highWatermark)
       replicaManager.checkpointHighWatermarks()
       // verify checkpointed hw for topic 2
       topic2Partition0Hw = hwmFor(replicaManager, topic2, 0)
diff --git a/core/src/test/scala/unit/kafka/server/ISRExpirationTest.scala b/core/src/test/scala/unit/kafka/server/ISRExpirationTest.scala
index 1dd4b24..0ee0fa1 100644
--- a/core/src/test/scala/unit/kafka/server/ISRExpirationTest.scala
+++ b/core/src/test/scala/unit/kafka/server/ISRExpirationTest.scala
@@ -25,7 +25,6 @@ import kafka.log.{Log, LogManager}
 import kafka.utils._
 import org.apache.kafka.common.TopicPartition
 import org.apache.kafka.common.metrics.Metrics
-import org.apache.kafka.common.record.MemoryRecords
 import org.apache.kafka.common.utils.Time
 import org.easymock.EasyMock
 import org.junit.Assert._
@@ -82,14 +81,11 @@ class IsrExpirationTest {
 
     // let the follower catch up to the Leader logEndOffset - 1
     for (replica <- partition0.assignedReplicas - leaderReplica)
-      replica.updateLogReadResult(new LogReadResult(info = FetchDataInfo(new LogOffsetMetadata(leaderLogEndOffset - 1), MemoryRecords.EMPTY),
-                                                    highWatermark = leaderLogEndOffset - 1,
-                                                    leaderLogStartOffset = 0L,
-                                                    leaderLogEndOffset = leaderLogEndOffset,
-                                                    followerLogStartOffset = 0L,
-                                                    fetchTimeMs = time.milliseconds,
-                                                    readSize = -1,
-                                                    lastStableOffset = None))
+      replica.updateFetchState(
+        followerFetchOffsetMetadata = new LogOffsetMetadata(leaderLogEndOffset - 1),
+        followerStartOffset = 0L,
+        followerFetchTimeMs= time.milliseconds,
+        leaderEndOffset = leaderLogEndOffset)
     var partition0OSR = partition0.getOutOfSyncReplicas(leaderReplica, configs.head.replicaLagTimeMaxMs)
     assertEquals("No replica should be out of sync", Set.empty[Int], partition0OSR.map(_.brokerId))
 
@@ -137,14 +133,11 @@ class IsrExpirationTest {
 
     // Make the remote replica not read to the end of log. It should not be out of sync for at least 100 ms
     for (replica <- partition0.assignedReplicas - leaderReplica)
-      replica.updateLogReadResult(new LogReadResult(info = FetchDataInfo(new LogOffsetMetadata(leaderLogEndOffset - 2), MemoryRecords.EMPTY),
-                                                    highWatermark = leaderLogEndOffset - 2,
-                                                    leaderLogStartOffset = 0L,
-                                                    leaderLogEndOffset = leaderLogEndOffset,
-                                                    followerLogStartOffset = 0L,
-                                                    fetchTimeMs = time.milliseconds,
-                                                    readSize = -1,
-                                                    lastStableOffset = None))
+      replica.updateFetchState(
+        followerFetchOffsetMetadata = new LogOffsetMetadata(leaderLogEndOffset - 2),
+        followerStartOffset = 0L,
+        followerFetchTimeMs= time.milliseconds,
+        leaderEndOffset = leaderLogEndOffset)
 
     // Simulate 2 fetch requests spanning more than 100 ms which do not read to the end of the log.
     // The replicas will no longer be in ISR. We do 2 fetches because we want to simulate the case where the replica is lagging but is not stuck
@@ -154,14 +147,11 @@ class IsrExpirationTest {
     time.sleep(75)
 
     (partition0.assignedReplicas - leaderReplica).foreach { r =>
-      r.updateLogReadResult(new LogReadResult(info = FetchDataInfo(new LogOffsetMetadata(leaderLogEndOffset - 1), MemoryRecords.EMPTY),
-                            highWatermark = leaderLogEndOffset - 1,
-                            leaderLogStartOffset = 0L,
-                            leaderLogEndOffset = leaderLogEndOffset,
-                            followerLogStartOffset = 0L,
-                            fetchTimeMs = time.milliseconds,
-                            readSize = -1,
-                            lastStableOffset = None))
+      r.updateFetchState(
+        followerFetchOffsetMetadata = new LogOffsetMetadata(leaderLogEndOffset - 1),
+        followerStartOffset = 0L,
+        followerFetchTimeMs= time.milliseconds,
+        leaderEndOffset = leaderLogEndOffset)
     }
     partition0OSR = partition0.getOutOfSyncReplicas(leaderReplica, configs.head.replicaLagTimeMaxMs)
     assertEquals("No replica should be out of sync", Set.empty[Int], partition0OSR.map(_.brokerId))
@@ -174,14 +164,11 @@ class IsrExpirationTest {
 
     // Now actually make a fetch to the end of the log. The replicas should be back in ISR
     (partition0.assignedReplicas - leaderReplica).foreach { r =>
-      r.updateLogReadResult(new LogReadResult(info = FetchDataInfo(new LogOffsetMetadata(leaderLogEndOffset), MemoryRecords.EMPTY),
-                            highWatermark = leaderLogEndOffset,
-                            leaderLogStartOffset = 0L,
-                            leaderLogEndOffset = leaderLogEndOffset,
-                            followerLogStartOffset = 0L,
-                            fetchTimeMs = time.milliseconds,
-                            readSize = -1,
-                            lastStableOffset = None))
+      r.updateFetchState(
+        followerFetchOffsetMetadata = new LogOffsetMetadata(leaderLogEndOffset),
+        followerStartOffset = 0L,
+        followerFetchTimeMs= time.milliseconds,
+        leaderEndOffset = leaderLogEndOffset)
     }
     partition0OSR = partition0.getOutOfSyncReplicas(leaderReplica, configs.head.replicaLagTimeMaxMs)
     assertEquals("No replica should be out of sync", Set.empty[Int], partition0OSR.map(_.brokerId))
@@ -203,14 +190,12 @@ class IsrExpirationTest {
 
     // let the follower catch up to the Leader logEndOffset
     for (replica <- partition0.assignedReplicas - leaderReplica)
-      replica.updateLogReadResult(new LogReadResult(info = FetchDataInfo(new LogOffsetMetadata(leaderLogEndOffset), MemoryRecords.EMPTY),
-        highWatermark = leaderLogEndOffset,
-        leaderLogStartOffset = 0L,
-        leaderLogEndOffset = leaderLogEndOffset,
-        followerLogStartOffset = 0L,
-        fetchTimeMs = time.milliseconds,
-        readSize = -1,
-        lastStableOffset = None))
+      replica.updateFetchState(
+        followerFetchOffsetMetadata = new LogOffsetMetadata(leaderLogEndOffset),
+        followerStartOffset = 0L,
+        followerFetchTimeMs= time.milliseconds,
+        leaderEndOffset = leaderLogEndOffset)
+
     var partition0OSR = partition0.getOutOfSyncReplicas(leaderReplica, configs.head.replicaLagTimeMaxMs)
     assertEquals("No replica should be out of sync", Set.empty[Int], partition0OSR.map(_.brokerId))
 
@@ -236,14 +221,12 @@ class IsrExpirationTest {
     partition.inSyncReplicas = allReplicas.toSet
     // set lastCaughtUpTime to current time
     for (replica <- partition.assignedReplicas - leaderReplica)
-      replica.updateLogReadResult(new LogReadResult(info = FetchDataInfo(new LogOffsetMetadata(0L), MemoryRecords.EMPTY),
-                                                    highWatermark = 0L,
-                                                    leaderLogStartOffset = 0L,
-                                                    leaderLogEndOffset = 0L,
-                                                    followerLogStartOffset = 0L,
-                                                    fetchTimeMs = time.milliseconds,
-                                                    readSize = -1,
-                                                    lastStableOffset = None))
+      replica.updateFetchState(
+        followerFetchOffsetMetadata = new LogOffsetMetadata(0L),
+        followerStartOffset = 0L,
+        followerFetchTimeMs= time.milliseconds,
+        leaderEndOffset = 0L)
+
     // set the leader and its hw and the hw update time
     partition.leaderReplicaIdOpt = Some(leaderId)
     partition
diff --git a/core/src/test/scala/unit/kafka/server/LogRecoveryTest.scala b/core/src/test/scala/unit/kafka/server/LogRecoveryTest.scala
index 6ab3138..e868f6c 100755
--- a/core/src/test/scala/unit/kafka/server/LogRecoveryTest.scala
+++ b/core/src/test/scala/unit/kafka/server/LogRecoveryTest.scala
@@ -104,7 +104,7 @@ class LogRecoveryTest extends ZooKeeperTestHarness {
 
     // give some time for the follower 1 to record leader HW
     TestUtils.waitUntilTrue(() =>
-      server2.replicaManager.localReplica(topicPartition).get.highWatermark.messageOffset == numMessages,
+      server2.replicaManager.localReplica(topicPartition).get.highWatermark == numMessages,
       "Failed to update high watermark for follower after timeout")
 
     servers.foreach(_.replicaManager.checkpointHighWatermarks())
@@ -166,7 +166,7 @@ class LogRecoveryTest extends ZooKeeperTestHarness {
 
     // give some time for follower 1 to record leader HW of 60
     TestUtils.waitUntilTrue(() =>
-      server2.replicaManager.localReplica(topicPartition).get.highWatermark.messageOffset == hw,
+      server2.replicaManager.localReplica(topicPartition).get.highWatermark == hw,
       "Failed to update high watermark for follower after timeout")
     // shutdown the servers to allow the hw to be checkpointed
     servers.foreach(_.shutdown())
@@ -180,7 +180,7 @@ class LogRecoveryTest extends ZooKeeperTestHarness {
     val hw = 20L
     // give some time for follower 1 to record leader HW of 20
     TestUtils.waitUntilTrue(() =>
-      server2.replicaManager.localReplica(topicPartition).get.highWatermark.messageOffset == hw,
+      server2.replicaManager.localReplica(topicPartition).get.highWatermark == hw,
       "Failed to update high watermark for follower after timeout")
     // shutdown the servers to allow the hw to be checkpointed
     servers.foreach(_.shutdown())
@@ -199,7 +199,7 @@ class LogRecoveryTest extends ZooKeeperTestHarness {
 
     // allow some time for the follower to get the leader HW
     TestUtils.waitUntilTrue(() =>
-      server2.replicaManager.localReplica(topicPartition).get.highWatermark.messageOffset == hw,
+      server2.replicaManager.localReplica(topicPartition).get.highWatermark == hw,
       "Failed to update high watermark for follower after timeout")
     // kill the server hosting the preferred replica
     server1.shutdown()
@@ -230,7 +230,7 @@ class LogRecoveryTest extends ZooKeeperTestHarness {
       "Failed to create replica in follower after timeout")
     // allow some time for the follower to get the leader HW
     TestUtils.waitUntilTrue(() =>
-      server1.replicaManager.localReplica(topicPartition).get.highWatermark.messageOffset == hw,
+      server1.replicaManager.localReplica(topicPartition).get.highWatermark == hw,
       "Failed to update high watermark for follower after timeout")
     // shutdown the servers to allow the hw to be checkpointed
     servers.foreach(_.shutdown())
diff --git a/core/src/test/scala/unit/kafka/server/ReplicaFetcherThreadTest.scala b/core/src/test/scala/unit/kafka/server/ReplicaFetcherThreadTest.scala
index a51641a..5400f86 100644
--- a/core/src/test/scala/unit/kafka/server/ReplicaFetcherThreadTest.scala
+++ b/core/src/test/scala/unit/kafka/server/ReplicaFetcherThreadTest.scala
@@ -87,7 +87,7 @@ class ReplicaFetcherThreadTest {
 
     //Stubs
     expect(replica.logEndOffset).andReturn(0).anyTimes()
-    expect(replica.highWatermark).andReturn(new LogOffsetMetadata(0)).anyTimes()
+    expect(replica.highWatermark).andReturn(0L).anyTimes()
     expect(replica.latestEpoch).andReturn(Some(leaderEpoch)).once()
     expect(replica.latestEpoch).andReturn(Some(leaderEpoch)).once()
     expect(replica.latestEpoch).andReturn(None).once()  // t2p1 doesnt support epochs
@@ -218,7 +218,7 @@ class ReplicaFetcherThreadTest {
 
     //Stubs
     expect(replica.logEndOffset).andReturn(0).anyTimes()
-    expect(replica.highWatermark).andReturn(new LogOffsetMetadata(0)).anyTimes()
+    expect(replica.highWatermark).andReturn(0L).anyTimes()
     expect(replica.latestEpoch).andReturn(Some(leaderEpoch)).anyTimes()
     expect(replica.endOffsetForEpoch(leaderEpoch)).andReturn(
       Some(OffsetAndEpoch(0, leaderEpoch))).anyTimes()
@@ -280,7 +280,7 @@ class ReplicaFetcherThreadTest {
     //Stubs
     expect(partition.truncateTo(capture(truncateToCapture), anyBoolean())).anyTimes()
     expect(replica.logEndOffset).andReturn(initialLEO).anyTimes()
-    expect(replica.highWatermark).andReturn(new LogOffsetMetadata(initialLEO - 1)).anyTimes()
+    expect(replica.highWatermark).andReturn(initialLEO - 1).anyTimes()
     expect(replica.latestEpoch).andReturn(Some(leaderEpoch)).anyTimes()
     expect(replica.endOffsetForEpoch(leaderEpoch)).andReturn(
       Some(OffsetAndEpoch(initialLEO, leaderEpoch))).anyTimes()
@@ -329,7 +329,7 @@ class ReplicaFetcherThreadTest {
     //Stubs
     expect(partition.truncateTo(capture(truncateToCapture), anyBoolean())).anyTimes()
     expect(replica.logEndOffset).andReturn(initialLEO).anyTimes()
-    expect(replica.highWatermark).andReturn(new LogOffsetMetadata(initialLEO - 3)).anyTimes()
+    expect(replica.highWatermark).andReturn(initialLEO - 3).anyTimes()
     expect(replica.latestEpoch).andReturn(Some(leaderEpochAtFollower)).anyTimes()
     expect(replica.endOffsetForEpoch(leaderEpochAtLeader)).andReturn(None).anyTimes()
     expect(replicaManager.logManager).andReturn(logManager).anyTimes()
@@ -379,7 +379,7 @@ class ReplicaFetcherThreadTest {
     // Stubs
     expect(partition.truncateTo(capture(truncateToCapture), anyBoolean())).anyTimes()
     expect(replica.logEndOffset).andReturn(initialLEO).anyTimes()
-    expect(replica.highWatermark).andReturn(new LogOffsetMetadata(initialLEO - 2)).anyTimes()
+    expect(replica.highWatermark).andReturn(initialLEO - 2).anyTimes()
     expect(replica.latestEpoch).andReturn(Some(5)).anyTimes()
     expect(replica.endOffsetForEpoch(4)).andReturn(
       Some(OffsetAndEpoch(120, 3))).anyTimes()
@@ -450,7 +450,7 @@ class ReplicaFetcherThreadTest {
     // Stubs
     expect(partition.truncateTo(capture(truncateToCapture), anyBoolean())).anyTimes()
     expect(replica.logEndOffset).andReturn(initialLEO).anyTimes()
-    expect(replica.highWatermark).andReturn(new LogOffsetMetadata(initialLEO - 2)).anyTimes()
+    expect(replica.highWatermark).andReturn(initialLEO - 2).anyTimes()
     expect(replica.latestEpoch).andReturn(Some(5)).anyTimes()
     expect(replica.endOffsetForEpoch(4)).andReturn(
       Some(OffsetAndEpoch(120, 3))).anyTimes()
@@ -512,7 +512,7 @@ class ReplicaFetcherThreadTest {
     //Stubs
     expect(partition.truncateTo(capture(truncated), anyBoolean())).anyTimes()
     expect(replica.logEndOffset).andReturn(initialLeo).anyTimes()
-    expect(replica.highWatermark).andReturn(new LogOffsetMetadata(initialFetchOffset)).anyTimes()
+    expect(replica.highWatermark).andReturn(initialFetchOffset).anyTimes()
     expect(replica.latestEpoch).andReturn(Some(5))
     expect(replicaManager.logManager).andReturn(logManager).anyTimes()
     expect(replicaManager.replicaAlterLogDirsManager).andReturn(replicaAlterLogDirsManager).anyTimes()
@@ -554,7 +554,7 @@ class ReplicaFetcherThreadTest {
     val initialLeo = 300
 
     //Stubs
-    expect(replica.highWatermark).andReturn(new LogOffsetMetadata(highWaterMark)).anyTimes()
+    expect(replica.highWatermark).andReturn(highWaterMark).anyTimes()
     expect(partition.truncateTo(capture(truncated), anyBoolean())).anyTimes()
     expect(replica.logEndOffset).andReturn(initialLeo).anyTimes()
     expect(replica.latestEpoch).andReturn(Some(leaderEpoch)).anyTimes()
@@ -611,7 +611,7 @@ class ReplicaFetcherThreadTest {
     //Stub return values
     expect(partition.truncateTo(0L, false)).times(2)
     expect(replica.logEndOffset).andReturn(0).anyTimes()
-    expect(replica.highWatermark).andReturn(new LogOffsetMetadata(0)).anyTimes()
+    expect(replica.highWatermark).andReturn(0).anyTimes()
     expect(replica.latestEpoch).andReturn(Some(leaderEpoch)).anyTimes()
     expect(replica.endOffsetForEpoch(leaderEpoch)).andReturn(
       Some(OffsetAndEpoch(0, leaderEpoch))).anyTimes()
@@ -662,7 +662,7 @@ class ReplicaFetcherThreadTest {
     //Stub return values
     expect(partition.truncateTo(capture(truncateToCapture), anyBoolean())).once
     expect(replica.logEndOffset).andReturn(initialLEO).anyTimes()
-    expect(replica.highWatermark).andReturn(new LogOffsetMetadata(initialLEO - 2)).anyTimes()
+    expect(replica.highWatermark).andReturn(initialLEO - 2).anyTimes()
     expect(replica.latestEpoch).andReturn(Some(5)).anyTimes()
     expect(replica.endOffsetForEpoch(5)).andReturn(Some(OffsetAndEpoch(initialLEO, 5))).anyTimes()
     expect(replicaManager.logManager).andReturn(logManager).anyTimes()
diff --git a/core/src/test/scala/unit/kafka/server/ReplicaManagerQuotasTest.scala b/core/src/test/scala/unit/kafka/server/ReplicaManagerQuotasTest.scala
index c2d92df..d298003 100644
--- a/core/src/test/scala/unit/kafka/server/ReplicaManagerQuotasTest.scala
+++ b/core/src/test/scala/unit/kafka/server/ReplicaManagerQuotasTest.scala
@@ -239,17 +239,17 @@ class ReplicaManagerQuotasTest {
     for ((p, _) <- fetchInfo) {
       val partition = replicaManager.createPartition(p)
       val leaderReplica = new Replica(configs.head.brokerId, p, time, 0, Some(log))
-      leaderReplica.highWatermark = new LogOffsetMetadata(5)
+      leaderReplica.highWatermark = 5
       partition.leaderReplicaIdOpt = Some(leaderReplica.brokerId)
       val followerReplica = new Replica(configs.last.brokerId, p, time, 0, Some(log))
       val allReplicas = Set(leaderReplica, followerReplica)
       allReplicas.foreach(partition.addReplicaIfNotExists)
       if (bothReplicasInSync) {
         partition.inSyncReplicas = allReplicas
-        followerReplica.highWatermark = new LogOffsetMetadata(5)
+        followerReplica.highWatermark = 5
       } else {
         partition.inSyncReplicas = Set(leaderReplica)
-        followerReplica.highWatermark = new LogOffsetMetadata(0)
+        followerReplica.highWatermark = 0
       }
     }
   }
diff --git a/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala b/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala
index 7c14cd2..59248f0 100644
--- a/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala
+++ b/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala
@@ -34,7 +34,7 @@ import org.I0Itec.zkclient.ZkClient
 import org.apache.kafka.common.metrics.Metrics
 import org.apache.kafka.common.protocol.{ApiKeys, Errors}
 import org.apache.kafka.common.record._
-import org.apache.kafka.common.requests.{EpochEndOffset, IsolationLevel, LeaderAndIsrRequest}
+import org.apache.kafka.common.requests.{EpochEndOffset, FetchRequest, IsolationLevel, LeaderAndIsrRequest}
 import org.apache.kafka.common.requests.ProduceResponse.PartitionResponse
 import org.apache.kafka.common.requests.FetchRequest.PartitionData
 import org.apache.kafka.common.requests.FetchResponse.AbortedTransaction
@@ -454,6 +454,92 @@ class ReplicaManagerTest {
     }
   }
 
+  @Test
+  def testFollowerStateNotUpdatedIfLogReadFails(): Unit = {
+    val maxFetchBytes = 1024 * 1024
+    val aliveBrokersIds = Seq(0, 1)
+    val leaderEpoch = 5
+    val replicaManager = setupReplicaManagerWithMockedPurgatories(new MockTimer, aliveBrokersIds)
+    try {
+      val tp = new TopicPartition(topic, 0)
+      val replicas = aliveBrokersIds.toList.map(Int.box).asJava
+
+      // Broker 0 becomes leader of the partition
+      val leaderAndIsrPartitionState = new LeaderAndIsrRequest.PartitionState(0, 0, leaderEpoch,
+        replicas, 0, replicas, true)
+      val leaderAndIsrRequest = new LeaderAndIsrRequest.Builder(ApiKeys.LEADER_AND_ISR.latestVersion, 0, 0, brokerEpoch,
+        Map(tp -> leaderAndIsrPartitionState).asJava,
+        Set(new Node(0, "host1", 0), new Node(1, "host2", 1)).asJava).build()
+      val leaderAndIsrResponse = replicaManager.becomeLeaderOrFollower(0, leaderAndIsrRequest, (_, _) => ())
+      assertEquals(Errors.NONE, leaderAndIsrResponse.error)
+
+      // Follower replica state is initialized, but initial state is not known
+      assertTrue(replicaManager.nonOfflinePartition(tp).isDefined)
+      val partition = replicaManager.nonOfflinePartition(tp).get
+
+      assertTrue(partition.getReplica(1).isDefined)
+      val followerReplica = partition.getReplica(1).get
+      assertEquals(None, followerReplica.log)
+      assertEquals(-1L, followerReplica.logStartOffset)
+      assertEquals(-1L, followerReplica.logEndOffset)
+
+      // Leader appends some data
+      for (i <- 1 to 5) {
+        appendRecords(replicaManager, tp, TestUtils.singletonRecords(s"message $i".getBytes)).onFire { response =>
+          assertEquals(Errors.NONE, response.error)
+        }
+      }
+
+      // We receive one valid request from the follower and replica state is updated
+      var successfulFetch: Option[FetchPartitionData] = None
+      def callback(response: Seq[(TopicPartition, FetchPartitionData)]): Unit = {
+        successfulFetch = response.headOption.filter(_._1 == tp).map(_._2)
+      }
+
+      val validFetchPartitionData = new FetchRequest.PartitionData(0L, 0L, maxFetchBytes,
+        Optional.of(leaderEpoch))
+
+      replicaManager.fetchMessages(
+        timeout = 0L,
+        replicaId = 1,
+        fetchMinBytes = 1,
+        fetchMaxBytes = maxFetchBytes,
+        hardMaxBytesLimit = false,
+        fetchInfos = Seq(tp -> validFetchPartitionData),
+        isolationLevel = IsolationLevel.READ_UNCOMMITTED,
+        responseCallback = callback
+      )
+
+      assertTrue(successfulFetch.isDefined)
+      assertEquals(0L, followerReplica.logStartOffset)
+      assertEquals(0L, followerReplica.logEndOffset)
+
+
+      // Next we receive an invalid request with a higher fetch offset, but an old epoch.
+      // We expect that the replica state does not get updated.
+      val invalidFetchPartitionData = new FetchRequest.PartitionData(3L, 0L, maxFetchBytes,
+        Optional.of(leaderEpoch - 1))
+
+      replicaManager.fetchMessages(
+        timeout = 0L,
+        replicaId = 1,
+        fetchMinBytes = 1,
+        fetchMaxBytes = maxFetchBytes,
+        hardMaxBytesLimit = false,
+        fetchInfos = Seq(tp -> invalidFetchPartitionData),
+        isolationLevel = IsolationLevel.READ_UNCOMMITTED,
+        responseCallback = callback
+      )
+
+      assertTrue(successfulFetch.isDefined)
+      assertEquals(0L, followerReplica.logStartOffset)
+      assertEquals(0L, followerReplica.logEndOffset)
+
+    } finally {
+      replicaManager.shutdown(checkpointHW = false)
+    }
+  }
+
   /**
    * If a follower sends a fetch request for 2 partitions and it's no longer the follower for one of them, the other
    * partition should not be affected.
@@ -525,12 +611,12 @@ class ReplicaManagerTest {
       )
       val tp0Replica = replicaManager.localReplica(tp0)
       assertTrue(tp0Replica.isDefined)
-      assertEquals("hw should be incremented", 1, tp0Replica.get.highWatermark.messageOffset)
+      assertEquals("hw should be incremented", 1, tp0Replica.get.highWatermark)
 
       replicaManager.localReplica(tp1)
       val tp1Replica = replicaManager.localReplica(tp1)
       assertTrue(tp1Replica.isDefined)
-      assertEquals("hw should not be incremented", 0, tp1Replica.get.highWatermark.messageOffset)
+      assertEquals("hw should not be incremented", 0, tp1Replica.get.highWatermark)
 
     } finally {
       replicaManager.shutdown(checkpointHW = false)
diff --git a/core/src/test/scala/unit/kafka/server/SimpleFetchTest.scala b/core/src/test/scala/unit/kafka/server/SimpleFetchTest.scala
index 41c6b3e..e7e6e8f 100644
--- a/core/src/test/scala/unit/kafka/server/SimpleFetchTest.scala
+++ b/core/src/test/scala/unit/kafka/server/SimpleFetchTest.scala
@@ -121,20 +121,17 @@ class SimpleFetchTest {
 
     // create the leader replica with the local log
     val leaderReplica = new Replica(configs.head.brokerId, partition.topicPartition, time, 0, Some(log))
-    leaderReplica.highWatermark = new LogOffsetMetadata(partitionHW)
+    leaderReplica.highWatermark = partitionHW
     partition.leaderReplicaIdOpt = Some(leaderReplica.brokerId)
 
     // create the follower replica with defined log end offset
     val followerReplica= new Replica(configs(1).brokerId, partition.topicPartition, time)
     val leo = new LogOffsetMetadata(followerLEO, 0L, followerLEO.toInt)
-    followerReplica.updateLogReadResult(new LogReadResult(info = FetchDataInfo(leo, MemoryRecords.EMPTY),
-                                                          highWatermark = leo.messageOffset,
-                                                          leaderLogStartOffset = 0L,
-                                                          leaderLogEndOffset = leo.messageOffset,
-                                                          followerLogStartOffset = 0L,
-                                                          fetchTimeMs = time.milliseconds,
-                                                          readSize = -1,
-                                                          lastStableOffset = None))
+    followerReplica.updateFetchState(
+      followerFetchOffsetMetadata = leo,
+      followerStartOffset = 0L,
+      followerFetchTimeMs= time.milliseconds,
+      leaderEndOffset = leo.messageOffset)
 
     // add both of them to ISR
     val allReplicas = List(leaderReplica, followerReplica)