Posted to commits@kafka.apache.org by da...@apache.org on 2023/02/20 07:26:31 UTC

[kafka] branch trunk updated: KAFKA-14673; Add high watermark listener to Partition/Log layers (#13196)

This is an automated email from the ASF dual-hosted git repository.

dajac pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 4dd27a9f211 KAFKA-14673; Add high watermark listener to Partition/Log layers (#13196)
4dd27a9f211 is described below

commit 4dd27a9f2116edfd22b4dc7368a62fedef28fe3b
Author: David Jacot <dj...@confluent.io>
AuthorDate: Mon Feb 20 08:26:17 2023 +0100

    KAFKA-14673; Add high watermark listener to Partition/Log layers (#13196)
    
    In the context of KIP-848, we are implementing a new group coordinator in Java. This new coordinator follows the architecture of the new quorum controller. It is basically a replicated state machine that writes to the log and commits its internal state once the writes are committed. At the moment, the only way to know when a write is committed is to use a delayed fetch. This is not ideal in our context because a delayed fetch can be completed before the write is actually committed to the log.
    
    This patch proposes to introduce a high watermark listener to the Partition/Log layers. This allows the new group coordinator to simply listen for high watermark changes and commit its state accordingly. This mechanism is also simpler and lighter than relying on delayed fetches.
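
    As an illustration, here is a minimal sketch (not part of this patch) of how a client such as the new group coordinator could use this API. `PartitionListener`, `ReplicaManager.maybeAddListener` and `removeListener` are the interfaces added by this change; the `CoordinatorHighWatermarkListener` class and its bookkeeping are hypothetical. Because the callbacks run on the thread that advances the high watermark and may hold locks, the sketch does nothing beyond completing deferred callbacks.

    import java.util.concurrent.ConcurrentSkipListMap

    import kafka.cluster.PartitionListener
    import org.apache.kafka.common.TopicPartition

    // Hypothetical coordinator-side listener: it defers callbacks until the
    // offsets they wait for are covered by the high watermark.
    class CoordinatorHighWatermarkListener extends PartitionListener {
      // last offset of a pending write -> callback to run once it is committed
      private val pending = new ConcurrentSkipListMap[java.lang.Long, () => Unit]()

      def deferUntilCommitted(lastOffset: Long)(callback: () => Unit): Unit =
        pending.put(lastOffset, callback)

      override def onHighWatermarkUpdated(partition: TopicPartition, offset: Long): Unit = {
        // Records with offsets strictly below the high watermark are committed.
        val committed = pending.headMap(offset, false)
        committed.values.forEach(callback => callback())
        committed.clear()
      }

      override def onFailed(partition: TopicPartition): Unit = pending.clear()
      override def onDeleted(partition: TopicPartition): Unit = pending.clear()
    }

    // Registration only succeeds while the partition is online:
    //   val listener = new CoordinatorHighWatermarkListener
    //   if (replicaManager.maybeAddListener(tp, listener)) { /* defer writes */ }
    //   // and later, when done: replicaManager.removeListener(tp, listener)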
    
    Reviewers: Christo Lolov <lo...@amazon.com>, Justine Olshan <jo...@confluent.io>, Jason Gustafson <ja...@confluent.io>
---
 core/src/main/scala/kafka/cluster/Partition.scala  | 111 +++++++-
 core/src/main/scala/kafka/log/UnifiedLog.scala     |  19 +-
 .../main/scala/kafka/server/ReplicaManager.scala   |  32 ++-
 .../scala/unit/kafka/cluster/PartitionTest.scala   | 291 +++++++++++++++++++++
 .../test/scala/unit/kafka/log/LogTestUtils.scala   |   8 +-
 .../test/scala/unit/kafka/log/UnifiedLogTest.scala |  86 +++++-
 .../unit/kafka/server/ReplicaManagerTest.scala     |  89 +++++++
 .../storage/internals/log/LogOffsetsListener.java  |  36 +++
 8 files changed, 647 insertions(+), 25 deletions(-)

diff --git a/core/src/main/scala/kafka/cluster/Partition.scala b/core/src/main/scala/kafka/cluster/Partition.scala
index f27c0f9dc23..310d231c54b 100755
--- a/core/src/main/scala/kafka/cluster/Partition.scala
+++ b/core/src/main/scala/kafka/cluster/Partition.scala
@@ -18,7 +18,7 @@ package kafka.cluster
 
 import java.util.concurrent.locks.ReentrantReadWriteLock
 import java.util.Optional
-import java.util.concurrent.CompletableFuture
+import java.util.concurrent.{CompletableFuture, CopyOnWriteArrayList}
 import kafka.api.LeaderAndIsr
 import kafka.common.UnexpectedAppendOffsetException
 import kafka.controller.{KafkaController, StateChangeLogger}
@@ -44,11 +44,41 @@ import org.apache.kafka.common.utils.Time
 import org.apache.kafka.common.{IsolationLevel, TopicPartition, Uuid}
 import org.apache.kafka.metadata.LeaderRecoveryState
 import org.apache.kafka.server.common.MetadataVersion
-import org.apache.kafka.storage.internals.log.{AppendOrigin, FetchDataInfo, FetchIsolation, FetchParams, LogOffsetMetadata}
+import org.apache.kafka.storage.internals.log.{AppendOrigin, FetchDataInfo, FetchIsolation, FetchParams, LogOffsetMetadata, LogOffsetsListener}
 
 import scala.collection.{Map, Seq}
 import scala.jdk.CollectionConverters._
 
+/**
+ * A listener that receives notifications from an Online Partition.
+ *
+ * A listener can be (re-)registered to an Online partition only. The listener
+ * is notified as long as the partition remains Online. When the partition fails
+ * or is deleted, `onFailed` or `onDeleted` is called once, respectively. No further
+ * notifications are sent after that point.
+ *
+ * Note that the callbacks are executed in the thread that triggers the change
+ * AND that locks may be held during their execution. They are meant to be used
+ * as a notification mechanism only.
+ */
+trait PartitionListener {
+  /**
+   * Called when the Log increments its high watermark.
+   */
+  def onHighWatermarkUpdated(partition: TopicPartition, offset: Long): Unit = {}
+
+  /**
+   * Called when the Partition (or replica) on this broker has a failure (e.g. goes offline).
+   */
+  def onFailed(partition: TopicPartition): Unit = {}
+
+  /**
+   * Called when the Partition (or replica) on this broker is deleted. Note that it does not mean
+   * that the partition was deleted but only that this broker does not host a replica of it any more.
+   */
+  def onDeleted(partition: TopicPartition): Unit = {}
+}
+
 trait AlterPartitionListener {
   def markIsrExpand(): Unit
   def markIsrShrink(): Unit
@@ -285,6 +315,17 @@ class Partition(val topicPartition: TopicPartition,
   // If ReplicaAlterLogDir command is in progress, this is future location of the log
   @volatile var futureLog: Option[UnifiedLog] = None
 
+  // Partition listeners
+  private val listeners = new CopyOnWriteArrayList[PartitionListener]()
+
+  private val logOffsetsListener = new LogOffsetsListener {
+    override def onHighWatermarkUpdated(offset: Long): Unit = {
+      listeners.forEach { listener =>
+        listener.onHighWatermarkUpdated(topicPartition, offset)
+      }
+    }
+  }
+
   /* Epoch of the controller that last changed the leader. This needs to be initialized correctly upon broker startup.
    * One way of doing that is through the controller's start replica state change command. When a new broker starts up
    * the controller sends it a start replica command containing the leader for each partition that the broker hosts.
@@ -318,6 +359,24 @@ class Partition(val topicPartition: TopicPartition,
 
   def inSyncReplicaIds: Set[Int] = partitionState.isr
 
+  def maybeAddListener(listener: PartitionListener): Boolean = {
+    inReadLock(leaderIsrUpdateLock) {
+      // `log` is set to `None` when the partition is failed or deleted.
+      log match {
+        case Some(_) =>
+          listeners.add(listener)
+          true
+
+        case None =>
+          false
+      }
+    }
+  }
+
+  def removeListener(listener: PartitionListener): Unit = {
+    listeners.remove(listener)
+  }
+
   /**
     * Create the future replica if 1) the current replica is not in the given log directory and 2) the future replica
     * does not exist. This method assumes that the current replica has already been created.
@@ -387,6 +446,7 @@ class Partition(val topicPartition: TopicPartition,
     var maybeLog: Option[UnifiedLog] = None
     try {
       val log = logManager.getOrCreateLog(topicPartition, isNew, isFutureReplica, topicId)
+      if (!isFutureReplica) log.setLogOffsetsListener(logOffsetsListener)
       maybeLog = Some(log)
       updateHighWatermark(log)
       log
@@ -469,10 +529,12 @@ class Partition(val topicPartition: TopicPartition,
 
   // Visible for testing -- Used by unit tests to set log for this partition
   def setLog(log: UnifiedLog, isFutureLog: Boolean): Unit = {
-    if (isFutureLog)
+    if (isFutureLog) {
       futureLog = Some(log)
-    else
+    } else {
+      log.setLogOffsetsListener(logOffsetsListener)
       this.log = Some(log)
+    }
   }
 
   /**
@@ -517,6 +579,7 @@ class Partition(val topicPartition: TopicPartition,
             case Some(futurePartitionLog) =>
               if (log.exists(_.logEndOffset == futurePartitionLog.logEndOffset)) {
                 logManager.replaceCurrentWithFutureLog(topicPartition)
+                futurePartitionLog.setLogOffsetsListener(logOffsetsListener)
                 log = futureLog
                 removeFutureLocalReplica(false)
                 true
@@ -540,17 +603,41 @@ class Partition(val topicPartition: TopicPartition,
   def delete(): Unit = {
     // need to hold the lock to prevent appendMessagesToLeader() from hitting I/O exceptions due to log being deleted
     inWriteLock(leaderIsrUpdateLock) {
-      remoteReplicasMap.clear()
-      assignmentState = SimpleAssignmentState(Seq.empty)
-      log = None
-      futureLog = None
-      partitionState = CommittedPartitionState(Set.empty, LeaderRecoveryState.RECOVERED)
-      leaderReplicaIdOpt = None
-      leaderEpochStartOffsetOpt = None
-      Partition.removeMetrics(topicPartition)
+      clear()
+
+      listeners.forEach { listener =>
+        listener.onDeleted(topicPartition)
+      }
+      listeners.clear()
     }
   }
 
+  /**
+   * Fail the partition. This is called by the ReplicaManager when the partition
+   * transitions to Offline.
+   */
+  def markOffline(): Unit = {
+    inWriteLock(leaderIsrUpdateLock) {
+      clear()
+
+      listeners.forEach { listener =>
+        listener.onFailed(topicPartition)
+      }
+      listeners.clear()
+    }
+  }
+
+  private def clear(): Unit = {
+    remoteReplicasMap.clear()
+    assignmentState = SimpleAssignmentState(Seq.empty)
+    log = None
+    futureLog = None
+    partitionState = CommittedPartitionState(Set.empty, LeaderRecoveryState.RECOVERED)
+    leaderReplicaIdOpt = None
+    leaderEpochStartOffsetOpt = None
+    Partition.removeMetrics(topicPartition)
+  }
+
   def getLeaderEpoch: Int = this.leaderEpoch
 
   def getPartitionEpoch: Int = this.partitionEpoch
diff --git a/core/src/main/scala/kafka/log/UnifiedLog.scala b/core/src/main/scala/kafka/log/UnifiedLog.scala
index 712774cc05a..81c241bfce7 100644
--- a/core/src/main/scala/kafka/log/UnifiedLog.scala
+++ b/core/src/main/scala/kafka/log/UnifiedLog.scala
@@ -45,7 +45,7 @@ import org.apache.kafka.server.record.BrokerCompressionType
 import org.apache.kafka.server.util.Scheduler
 import org.apache.kafka.storage.internals.checkpoint.LeaderEpochCheckpointFile
 import org.apache.kafka.storage.internals.epoch.LeaderEpochFileCache
-import org.apache.kafka.storage.internals.log.{AbortedTxn, AppendOrigin, BatchMetadata, CompletedTxn, EpochEntry, FetchDataInfo, FetchIsolation, LastRecord, LogConfig, LogDirFailureChannel, LogOffsetMetadata, LogValidator, ProducerAppendInfo, ProducerStateManager, ProducerStateManagerConfig}
+import org.apache.kafka.storage.internals.log.{AbortedTxn, AppendOrigin, BatchMetadata, CompletedTxn, EpochEntry, FetchDataInfo, FetchIsolation, LastRecord, LogConfig, LogDirFailureChannel, LogOffsetMetadata, LogOffsetsListener, LogValidator, ProducerAppendInfo, ProducerStateManager, ProducerStateManagerConfig}
 
 import scala.annotation.nowarn
 import scala.collection.mutable.ListBuffer
@@ -247,7 +247,8 @@ class UnifiedLog(@volatile var logStartOffset: Long,
                  @volatile private var _topicId: Option[Uuid],
                  val keepPartitionMetadataFile: Boolean,
                  val remoteStorageSystemEnable: Boolean = false,
-                 remoteLogManager: Option[RemoteLogManager] = None) extends Logging with KafkaMetricsGroup {
+                 remoteLogManager: Option[RemoteLogManager] = None,
+                 @volatile private var logOffsetsListener: LogOffsetsListener = LogOffsetsListener.NO_OP_OFFSETS_LISTENER) extends Logging with KafkaMetricsGroup {
 
   import kafka.log.UnifiedLog._
 
@@ -288,6 +289,12 @@ class UnifiedLog(@volatile var logStartOffset: Long,
     updateLogStartOffset(logStartOffset)
     maybeIncrementFirstUnstableOffset()
     initializeTopicId()
+
+    logOffsetsListener.onHighWatermarkUpdated(highWatermarkMetadata.messageOffset)
+  }
+
+  def setLogOffsetsListener(listener: LogOffsetsListener): Unit = {
+    logOffsetsListener = listener
   }
 
   def remoteLogEnabled(): Boolean = {
@@ -486,6 +493,7 @@ class UnifiedLog(@volatile var logStartOffset: Long,
 
       highWatermarkMetadata = newHighWatermark
       producerStateManager.onHighWatermarkUpdated(newHighWatermark.messageOffset)
+      logOffsetsListener.onHighWatermarkUpdated(newHighWatermark.messageOffset)
       maybeIncrementFirstUnstableOffset()
     }
     trace(s"Setting high watermark $newHighWatermark")
@@ -712,6 +720,7 @@ class UnifiedLog(@volatile var logStartOffset: Long,
   def close(): Unit = {
     debug("Closing log")
     lock synchronized {
+      logOffsetsListener = LogOffsetsListener.NO_OP_OFFSETS_LISTENER
       maybeFlushMetadataFile()
       localLog.checkIfMemoryMappedBufferClosed()
       producerExpireCheck.cancel(true)
@@ -1892,7 +1901,8 @@ object UnifiedLog extends Logging {
             keepPartitionMetadataFile: Boolean,
             numRemainingSegments: ConcurrentMap[String, Int] = new ConcurrentHashMap[String, Int],
             remoteStorageSystemEnable: Boolean = false,
-            remoteLogManager: Option[RemoteLogManager] = None): UnifiedLog = {
+            remoteLogManager: Option[RemoteLogManager] = None,
+            logOffsetsListener: LogOffsetsListener = LogOffsetsListener.NO_OP_OFFSETS_LISTENER): UnifiedLog = {
     // create the log directory if it doesn't exist
     Files.createDirectories(dir.toPath)
     val topicPartition = UnifiedLog.parseTopicPartitionName(dir)
@@ -1931,7 +1941,8 @@ object UnifiedLog extends Logging {
       topicId,
       keepPartitionMetadataFile,
       remoteStorageSystemEnable,
-      remoteLogManager)
+      remoteLogManager,
+      logOffsetsListener)
   }
 
   def logFile(dir: File, offset: Long, suffix: String = ""): File = LocalLog.logFile(dir, offset, suffix)
diff --git a/core/src/main/scala/kafka/server/ReplicaManager.scala b/core/src/main/scala/kafka/server/ReplicaManager.scala
index b2bf076c98d..9ef2c3ab3af 100644
--- a/core/src/main/scala/kafka/server/ReplicaManager.scala
+++ b/core/src/main/scala/kafka/server/ReplicaManager.scala
@@ -23,7 +23,7 @@ import java.util.concurrent.atomic.AtomicBoolean
 import java.util.concurrent.locks.Lock
 import com.yammer.metrics.core.Meter
 import kafka.api._
-import kafka.cluster.{BrokerEndPoint, Partition}
+import kafka.cluster.{BrokerEndPoint, Partition, PartitionListener}
 import kafka.controller.{KafkaController, StateChangeLogger}
 import kafka.log.{LeaderHwChange, LogAppendInfo, LogManager, LogReadInfo, UnifiedLog}
 import kafka.log.remote.RemoteLogManager
@@ -334,6 +334,29 @@ class ReplicaManager(val config: KafkaConfig,
     topicPartitions.foreach(tp => delayedFetchPurgatory.checkAndComplete(TopicPartitionOperationKey(tp)))
   }
 
+  /**
+   * Registers the provided listener to the partition iff the partition is online.
+   */
+  def maybeAddListener(partition: TopicPartition, listener: PartitionListener): Boolean = {
+    getPartition(partition) match {
+      case HostedPartition.Online(partition) =>
+        partition.maybeAddListener(listener)
+      case _ =>
+        false
+    }
+  }
+
+  /**
+   * Removes the provided listener from the partition.
+   */
+  def removeListener(partition: TopicPartition, listener: PartitionListener): Unit = {
+    getPartition(partition) match {
+      case HostedPartition.Online(partition) =>
+        partition.removeListener(listener)
+      case _ => // Ignore
+    }
+  }
+
   def stopReplicas(correlationId: Int,
                    controllerId: Int,
                    controllerEpoch: Int,
@@ -1843,8 +1866,11 @@ class ReplicaManager(val config: KafkaConfig,
   }
 
   def markPartitionOffline(tp: TopicPartition): Unit = replicaStateChangeLock synchronized {
-    allPartitions.put(tp, HostedPartition.Offline)
-    Partition.removeMetrics(tp)
+    allPartitions.put(tp, HostedPartition.Offline) match {
+      case HostedPartition.Online(partition) =>
+        partition.markOffline()
+      case _ => // Nothing
+    }
   }
 
   /**
diff --git a/core/src/test/scala/unit/kafka/cluster/PartitionTest.scala b/core/src/test/scala/unit/kafka/cluster/PartitionTest.scala
index aa6397c8819..e8ac6f50ace 100644
--- a/core/src/test/scala/unit/kafka/cluster/PartitionTest.scala
+++ b/core/src/test/scala/unit/kafka/cluster/PartitionTest.scala
@@ -63,6 +63,49 @@ import scala.compat.java8.OptionConverters._
 import scala.jdk.CollectionConverters._
 
 object PartitionTest {
+  class MockPartitionListener extends PartitionListener {
+    private var highWatermark: Long = -1L
+    private var failed: Boolean = false
+    private var deleted: Boolean = false
+
+    override def onHighWatermarkUpdated(partition: TopicPartition, offset: Long): Unit = {
+      highWatermark = offset
+    }
+
+    override def onFailed(partition: TopicPartition): Unit = {
+      failed = true
+    }
+
+    override def onDeleted(partition: TopicPartition): Unit = {
+      deleted = true
+    }
+
+    private def clear(): Unit = {
+      highWatermark = -1L
+      failed = false
+      deleted = false
+    }
+
+    /**
+     * Verifies the callbacks that have been triggered since the last
+     * verification. Values different than `-1` are the ones that have
+     * been updated.
+     */
+    def verify(
+      expectedHighWatermark: Long = -1L,
+      expectedFailed: Boolean = false,
+      expectedDeleted: Boolean = false
+    ): Unit = {
+      assertEquals(expectedHighWatermark, highWatermark,
+        "Unexpected high watermark")
+      assertEquals(expectedFailed, failed,
+        "Unexpected failed")
+      assertEquals(expectedDeleted, deleted,
+        "Unexpected deleted")
+      clear()
+    }
+  }
+
   def followerFetchParams(
     replicaId: Int,
     maxWaitMs: Long = 0L,
@@ -2798,6 +2841,254 @@ class PartitionTest extends AbstractPartitionTest {
     assertEquals(replicas, partition.assignmentState.replicas)
   }
 
+  @Test
+  def testAddAndRemoveListeners(): Unit = {
+    partition.createLogIfNotExists(isNew = false, isFutureReplica = false, offsetCheckpoints, topicId = None)
+
+    partition.makeLeader(
+      new LeaderAndIsrPartitionState()
+        .setControllerEpoch(0)
+        .setLeader(brokerId)
+        .setLeaderEpoch(0)
+        .setIsr(List(brokerId, brokerId + 1).map(Int.box).asJava)
+        .setReplicas(List(brokerId, brokerId + 1).map(Int.box).asJava)
+        .setPartitionEpoch(1)
+        .setIsNew(true),
+      offsetCheckpoints,
+      topicId = None)
+
+    val listener1 = new MockPartitionListener()
+    val listener2 = new MockPartitionListener()
+
+    assertTrue(partition.maybeAddListener(listener1))
+    listener1.verify()
+
+    partition.appendRecordsToLeader(
+      records = TestUtils.records(List(new SimpleRecord("k1".getBytes, "v1".getBytes))),
+      origin = AppendOrigin.CLIENT,
+      requiredAcks = 0,
+      requestLocal = RequestLocal.NoCaching
+    )
+
+    listener1.verify()
+    listener2.verify()
+
+    assertTrue(partition.maybeAddListener(listener2))
+    listener2.verify()
+
+    partition.appendRecordsToLeader(
+      records = TestUtils.records(List(new SimpleRecord("k2".getBytes, "v2".getBytes))),
+      origin = AppendOrigin.CLIENT,
+      requiredAcks = 0,
+      requestLocal = RequestLocal.NoCaching
+    )
+
+    fetchFollower(
+      partition = partition,
+      replicaId = brokerId + 1,
+      fetchOffset = partition.localLogOrException.logEndOffset
+    )
+
+    listener1.verify(expectedHighWatermark = partition.localLogOrException.logEndOffset)
+    listener2.verify(expectedHighWatermark = partition.localLogOrException.logEndOffset)
+
+    partition.removeListener(listener1)
+
+    partition.appendRecordsToLeader(
+      records = TestUtils.records(List(new SimpleRecord("k3".getBytes, "v3".getBytes))),
+      origin = AppendOrigin.CLIENT,
+      requiredAcks = 0,
+      requestLocal = RequestLocal.NoCaching
+    )
+
+    fetchFollower(
+      partition = partition,
+      replicaId = brokerId + 1,
+      fetchOffset = partition.localLogOrException.logEndOffset
+    )
+
+    listener1.verify()
+    listener2.verify(expectedHighWatermark = partition.localLogOrException.logEndOffset)
+  }
+
+  @Test
+  def testAddListenerFailsWhenPartitionIsDeleted(): Unit = {
+    partition.createLogIfNotExists(isNew = false, isFutureReplica = false, offsetCheckpoints, topicId = None)
+
+    partition.makeLeader(
+      new LeaderAndIsrPartitionState()
+        .setControllerEpoch(0)
+        .setLeader(brokerId)
+        .setLeaderEpoch(0)
+        .setIsr(List(brokerId, brokerId + 1).map(Int.box).asJava)
+        .setReplicas(List(brokerId, brokerId + 1).map(Int.box).asJava)
+        .setPartitionEpoch(1)
+        .setIsNew(true),
+      offsetCheckpoints,
+      topicId = None)
+
+    partition.delete()
+
+    assertFalse(partition.maybeAddListener(new MockPartitionListener()))
+  }
+
+  @Test
+  def testPartitionListenerWhenLogOffsetsChanged(): Unit = {
+    partition.createLogIfNotExists(isNew = false, isFutureReplica = false, offsetCheckpoints, topicId = None)
+
+    partition.makeLeader(
+      new LeaderAndIsrPartitionState()
+        .setControllerEpoch(0)
+        .setLeader(brokerId)
+        .setLeaderEpoch(0)
+        .setIsr(List(brokerId, brokerId + 1).map(Int.box).asJava)
+        .setReplicas(List(brokerId, brokerId + 1).map(Int.box).asJava)
+        .setPartitionEpoch(1)
+        .setIsNew(true),
+      offsetCheckpoints,
+      topicId = None)
+
+    val listener = new MockPartitionListener()
+    assertTrue(partition.maybeAddListener(listener))
+    listener.verify()
+
+    partition.appendRecordsToLeader(
+      records = TestUtils.records(List(new SimpleRecord("k1".getBytes, "v1".getBytes))),
+      origin = AppendOrigin.CLIENT,
+      requiredAcks = 0,
+      requestLocal = RequestLocal.NoCaching
+    )
+
+    listener.verify()
+
+    fetchFollower(
+      partition = partition,
+      replicaId = brokerId + 1,
+      fetchOffset = partition.localLogOrException.logEndOffset
+    )
+
+    listener.verify(expectedHighWatermark = partition.localLogOrException.logEndOffset)
+
+    partition.truncateFullyAndStartAt(0L, false)
+
+    listener.verify(expectedHighWatermark = 0L)
+  }
+
+  @Test
+  def testPartitionListenerWhenPartitionFailed(): Unit = {
+    partition.createLogIfNotExists(isNew = false, isFutureReplica = false, offsetCheckpoints, topicId = None)
+
+    partition.makeLeader(
+      new LeaderAndIsrPartitionState()
+        .setControllerEpoch(0)
+        .setLeader(brokerId)
+        .setLeaderEpoch(0)
+        .setIsr(List(brokerId, brokerId + 1).map(Int.box).asJava)
+        .setReplicas(List(brokerId, brokerId + 1).map(Int.box).asJava)
+        .setPartitionEpoch(1)
+        .setIsNew(true),
+      offsetCheckpoints,
+      topicId = None)
+
+    val listener = new MockPartitionListener()
+    assertTrue(partition.maybeAddListener(listener))
+    listener.verify()
+
+    partition.markOffline()
+    listener.verify(expectedFailed = true)
+  }
+
+  @Test
+  def testPartitionListenerWhenPartitionIsDeleted(): Unit = {
+    partition.createLogIfNotExists(isNew = false, isFutureReplica = false, offsetCheckpoints, topicId = None)
+
+    partition.makeLeader(
+      new LeaderAndIsrPartitionState()
+        .setControllerEpoch(0)
+        .setLeader(brokerId)
+        .setLeaderEpoch(0)
+        .setIsr(List(brokerId, brokerId + 1).map(Int.box).asJava)
+        .setReplicas(List(brokerId, brokerId + 1).map(Int.box).asJava)
+        .setPartitionEpoch(1)
+        .setIsNew(true),
+      offsetCheckpoints,
+      topicId = None)
+
+    val listener = new MockPartitionListener()
+    assertTrue(partition.maybeAddListener(listener))
+    listener.verify()
+
+    partition.delete()
+    listener.verify(expectedDeleted = true)
+  }
+
+  @Test
+  def testPartitionListenerWhenCurrentIsReplacedWithFutureLog(): Unit = {
+    logManager.maybeUpdatePreferredLogDir(topicPartition, logDir1.getAbsolutePath)
+    partition.createLogIfNotExists(isNew = true, isFutureReplica = false, offsetCheckpoints, topicId = None)
+    assertTrue(partition.log.isDefined)
+
+    partition.makeLeader(
+      new LeaderAndIsrPartitionState()
+        .setControllerEpoch(0)
+        .setLeader(brokerId)
+        .setLeaderEpoch(0)
+        .setIsr(List(brokerId, brokerId + 1).map(Int.box).asJava)
+        .setReplicas(List(brokerId, brokerId + 1).map(Int.box).asJava)
+        .setPartitionEpoch(1)
+        .setIsNew(true),
+      offsetCheckpoints,
+      topicId = None)
+
+    val listener = new MockPartitionListener()
+    assertTrue(partition.maybeAddListener(listener))
+    listener.verify()
+
+    val records = TestUtils.records(List(
+      new SimpleRecord("k1".getBytes, "v1".getBytes),
+      new SimpleRecord("k2".getBytes, "v2".getBytes)
+    ))
+
+    partition.appendRecordsToLeader(
+      records = records,
+      origin = AppendOrigin.CLIENT,
+      requiredAcks = 0,
+      requestLocal = RequestLocal.NoCaching
+    )
+
+    listener.verify()
+
+    logManager.maybeUpdatePreferredLogDir(topicPartition, logDir2.getAbsolutePath)
+    partition.maybeCreateFutureReplica(logDir2.getAbsolutePath, offsetCheckpoints)
+    assertTrue(partition.futureLog.isDefined)
+    val futureLog = partition.futureLog.get
+
+    partition.appendRecordsToFollowerOrFutureReplica(
+      records = records,
+      isFuture = true
+    )
+
+    listener.verify()
+
+    assertTrue(partition.maybeReplaceCurrentWithFutureReplica())
+    assertEquals(futureLog, partition.log.get)
+
+    partition.appendRecordsToLeader(
+      records = TestUtils.records(List(new SimpleRecord("k3".getBytes, "v3".getBytes))),
+      origin = AppendOrigin.CLIENT,
+      requiredAcks = 0,
+      requestLocal = RequestLocal.NoCaching
+    )
+
+    fetchFollower(
+      partition = partition,
+      replicaId = brokerId + 1,
+      fetchOffset = partition.localLogOrException.logEndOffset
+    )
+
+    listener.verify(expectedHighWatermark = partition.localLogOrException.logEndOffset)
+  }
+
   private def makeLeader(
     topicId: Option[Uuid],
     controllerEpoch: Int,
diff --git a/core/src/test/scala/unit/kafka/log/LogTestUtils.scala b/core/src/test/scala/unit/kafka/log/LogTestUtils.scala
index c69f8ab5bb1..5710a968f42 100644
--- a/core/src/test/scala/unit/kafka/log/LogTestUtils.scala
+++ b/core/src/test/scala/unit/kafka/log/LogTestUtils.scala
@@ -33,7 +33,7 @@ import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap}
 import org.apache.kafka.common.config.TopicConfig
 import org.apache.kafka.server.util.Scheduler
 import org.apache.kafka.storage.internals.checkpoint.LeaderEpochCheckpointFile
-import org.apache.kafka.storage.internals.log.{AbortedTxn, AppendOrigin, FetchDataInfo, FetchIsolation, LazyIndex, LogConfig, LogDirFailureChannel, LogFileUtils, ProducerStateManager, ProducerStateManagerConfig, TransactionIndex}
+import org.apache.kafka.storage.internals.log.{AbortedTxn, AppendOrigin, FetchDataInfo, FetchIsolation, LazyIndex, LogConfig, LogDirFailureChannel, LogFileUtils, LogOffsetsListener, ProducerStateManager, ProducerStateManagerConfig, TransactionIndex}
 
 import scala.jdk.CollectionConverters._
 
@@ -94,7 +94,8 @@ object LogTestUtils {
                 keepPartitionMetadataFile: Boolean = true,
                 numRemainingSegments: ConcurrentMap[String, Int] = new ConcurrentHashMap[String, Int],
                 remoteStorageSystemEnable: Boolean = false,
-                remoteLogManager: Option[RemoteLogManager] = None): UnifiedLog = {
+                remoteLogManager: Option[RemoteLogManager] = None,
+                logOffsetsListener: LogOffsetsListener = LogOffsetsListener.NO_OP_OFFSETS_LISTENER): UnifiedLog = {
     UnifiedLog(
       dir = dir,
       config = config,
@@ -112,7 +113,8 @@ object LogTestUtils {
       keepPartitionMetadataFile = keepPartitionMetadataFile,
       numRemainingSegments = numRemainingSegments,
       remoteStorageSystemEnable = remoteStorageSystemEnable,
-      remoteLogManager = remoteLogManager
+      remoteLogManager = remoteLogManager,
+      logOffsetsListener = logOffsetsListener
     )
   }
 
diff --git a/core/src/test/scala/unit/kafka/log/UnifiedLogTest.scala b/core/src/test/scala/unit/kafka/log/UnifiedLogTest.scala
index 5be226eaf0c..70a2284a7a4 100755
--- a/core/src/test/scala/unit/kafka/log/UnifiedLogTest.scala
+++ b/core/src/test/scala/unit/kafka/log/UnifiedLogTest.scala
@@ -36,7 +36,7 @@ import org.apache.kafka.server.metrics.KafkaYammerMetrics
 import org.apache.kafka.server.util.{KafkaScheduler, Scheduler}
 import org.apache.kafka.storage.internals.checkpoint.LeaderEpochCheckpointFile
 import org.apache.kafka.storage.internals.epoch.LeaderEpochFileCache
-import org.apache.kafka.storage.internals.log.{AbortedTxn, AppendOrigin, EpochEntry, FetchIsolation, LogConfig, LogFileUtils, LogOffsetMetadata, ProducerStateManager, ProducerStateManagerConfig, RecordValidationException}
+import org.apache.kafka.storage.internals.log.{AbortedTxn, AppendOrigin, EpochEntry, FetchIsolation, LogConfig, LogFileUtils, LogOffsetMetadata, LogOffsetsListener, ProducerStateManager, ProducerStateManagerConfig, RecordValidationException}
 import org.junit.jupiter.api.Assertions._
 import org.junit.jupiter.api.{AfterEach, BeforeEach, Test}
 import org.mockito.ArgumentMatchers
@@ -3588,6 +3588,85 @@ class UnifiedLogTest {
       assertEquals(log.logStartOffset, log.localLogStartOffset())
     }
 
+  private class MockLogOffsetsListener extends LogOffsetsListener {
+    private var highWatermark: Long = -1L
+
+    override def onHighWatermarkUpdated(offset: Long): Unit = {
+      highWatermark = offset
+    }
+
+    private def clear(): Unit = {
+      highWatermark = -1L
+    }
+
+    /**
+     * Verifies the callbacks that have been triggered since the last
+     * verification. Values different than `-1` are the ones that have
+     * been updated.
+     */
+    def verify(expectedHighWatermark: Long = -1L): Unit = {
+      assertEquals(expectedHighWatermark, highWatermark,
+        "Unexpected high watermark")
+      clear()
+    }
+  }
+
+  @Test
+  def testLogOffsetsListener(): Unit = {
+    def records(offset: Long): MemoryRecords = TestUtils.records(List(
+      new SimpleRecord(mockTime.milliseconds, "a".getBytes, "value".getBytes),
+      new SimpleRecord(mockTime.milliseconds, "b".getBytes, "value".getBytes),
+      new SimpleRecord(mockTime.milliseconds, "c".getBytes, "value".getBytes)
+    ), baseOffset = offset, partitionLeaderEpoch = 0)
+
+    val listener = new MockLogOffsetsListener()
+    listener.verify()
+
+    val logConfig = LogTestUtils.createLogConfig(segmentBytes = 1024 * 1024)
+    val log = createLog(logDir, logConfig, logOffsetsListener = listener)
+
+    listener.verify(expectedHighWatermark = 0)
+
+    log.appendAsLeader(records(0), 0)
+    log.appendAsLeader(records(0), 0)
+
+    log.maybeIncrementHighWatermark(new LogOffsetMetadata(4))
+    listener.verify(expectedHighWatermark = 4)
+
+    log.truncateTo(3)
+    listener.verify(expectedHighWatermark = 3)
+
+    log.appendAsLeader(records(0), 0)
+    log.truncateFullyAndStartAt(4)
+    listener.verify(expectedHighWatermark = 4)
+  }
+
+  @Test
+  def testUpdateLogOffsetsListener(): Unit = {
+    def records(offset: Long): MemoryRecords = TestUtils.records(List(
+      new SimpleRecord(mockTime.milliseconds, "a".getBytes, "value".getBytes),
+      new SimpleRecord(mockTime.milliseconds, "b".getBytes, "value".getBytes),
+      new SimpleRecord(mockTime.milliseconds, "c".getBytes, "value".getBytes)
+    ), baseOffset = offset, partitionLeaderEpoch = 0)
+
+    val logConfig = LogTestUtils.createLogConfig(segmentBytes = 1024 * 1024)
+    val log = createLog(logDir, logConfig)
+
+    log.appendAsLeader(records(0), 0)
+    log.maybeIncrementHighWatermark(new LogOffsetMetadata(2))
+    log.maybeIncrementLogStartOffset(1, SegmentDeletion)
+
+    val listener = new MockLogOffsetsListener()
+    listener.verify()
+
+    log.setLogOffsetsListener(listener)
+    listener.verify() // it is still empty because we don't call the listener when it is set.
+
+    log.appendAsLeader(records(0), 0)
+    log.maybeIncrementHighWatermark(new LogOffsetMetadata(4))
+    listener.verify(expectedHighWatermark = 4)
+  }
+
   private def appendTransactionalToBuffer(buffer: ByteBuffer,
                                           producerId: Long,
                                           producerEpoch: Short,
@@ -3644,11 +3723,12 @@ class UnifiedLogTest {
                         topicId: Option[Uuid] = None,
                         keepPartitionMetadataFile: Boolean = true,
                         remoteStorageSystemEnable: Boolean = false,
-                        remoteLogManager: Option[RemoteLogManager] = None): UnifiedLog = {
+                        remoteLogManager: Option[RemoteLogManager] = None,
+                        logOffsetsListener: LogOffsetsListener = LogOffsetsListener.NO_OP_OFFSETS_LISTENER): UnifiedLog = {
     LogTestUtils.createLog(dir, config, brokerTopicStats, scheduler, time, logStartOffset, recoveryPoint,
       maxTransactionTimeoutMs, producerStateManagerConfig, producerIdExpirationCheckIntervalMs,
       lastShutdownClean, topicId, keepPartitionMetadataFile, new ConcurrentHashMap[String, Int],
-      remoteStorageSystemEnable, remoteLogManager)
+      remoteStorageSystemEnable, remoteLogManager, logOffsetsListener)
   }
 
   private def createLogWithOffsetOverflow(logConfig: LogConfig): (UnifiedLog, LogSegment) = {
diff --git a/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala b/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala
index 42812c9f6a3..5602acc7985 100644
--- a/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala
+++ b/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala
@@ -26,6 +26,7 @@ import java.util.concurrent.{CountDownLatch, TimeUnit}
 import java.util.stream.IntStream
 import java.util.{Collections, Optional, OptionalLong, Properties}
 import kafka.api._
+import kafka.cluster.PartitionTest.MockPartitionListener
 import kafka.cluster.{BrokerEndPoint, Partition}
 import kafka.log._
 import kafka.server.QuotaFactory.{QuotaManagers, UnboundedQuota}
@@ -4137,6 +4138,94 @@ class ReplicaManagerTest {
     TestUtils.assertNoNonDaemonThreads(this.getClass.getName)
   }
 
+  @Test
+  def testPartitionListener(): Unit = {
+    val maxFetchBytes = 1024 * 1024
+    val aliveBrokersIds = Seq(0, 1)
+    val leaderEpoch = 5
+    val replicaManager = setupReplicaManagerWithMockedPurgatories(new MockTimer(time),
+      brokerId = 0, aliveBrokersIds)
+    try {
+      val tp = new TopicPartition(topic, 0)
+      val tidp = new TopicIdPartition(topicId, tp)
+      val replicas = aliveBrokersIds.toList.map(Int.box).asJava
+
+      val listener = new MockPartitionListener
+      listener.verify()
+
+      // Registering a listener should fail because the partition does not exist yet.
+      assertFalse(replicaManager.maybeAddListener(tp, listener))
+
+      // Broker 0 becomes leader of the partition
+      val leaderAndIsrPartitionState = new LeaderAndIsrPartitionState()
+        .setTopicName(topic)
+        .setPartitionIndex(0)
+        .setControllerEpoch(0)
+        .setLeader(0)
+        .setLeaderEpoch(leaderEpoch)
+        .setIsr(replicas)
+        .setPartitionEpoch(0)
+        .setReplicas(replicas)
+        .setIsNew(true)
+      val leaderAndIsrRequest = new LeaderAndIsrRequest.Builder(ApiKeys.LEADER_AND_ISR.latestVersion, 0, 0, brokerEpoch,
+        Seq(leaderAndIsrPartitionState).asJava,
+        Collections.singletonMap(topic, topicId),
+        Set(new Node(0, "host1", 0), new Node(1, "host2", 1)).asJava).build()
+      val leaderAndIsrResponse = replicaManager.becomeLeaderOrFollower(0, leaderAndIsrRequest, (_, _) => ())
+      assertEquals(Errors.NONE, leaderAndIsrResponse.error)
+
+      // Registering it should succeed now.
+      assertTrue(replicaManager.maybeAddListener(tp, listener))
+      listener.verify()
+
+      // Leader appends some data
+      for (i <- 1 to 5) {
+        appendRecords(replicaManager, tp, TestUtils.singletonRecords(s"message $i".getBytes)).onFire { response =>
+          assertEquals(Errors.NONE, response.error)
+        }
+      }
+
+      // Follower fetches up to offset 2.
+      fetchPartitionAsFollower(
+        replicaManager,
+        tidp,
+        new FetchRequest.PartitionData(
+          Uuid.ZERO_UUID,
+          2L,
+          0L,
+          maxFetchBytes,
+          Optional.of(leaderEpoch)
+        ),
+        replicaId = 1
+      )
+
+      // Listener is updated.
+      listener.verify(expectedHighWatermark = 2L)
+
+      // Listener is removed.
+      replicaManager.removeListener(tp, listener)
+
+      // Follower fetches up to offset 4.
+      fetchPartitionAsFollower(
+        replicaManager,
+        tidp,
+        new FetchRequest.PartitionData(
+          Uuid.ZERO_UUID,
+          4L,
+          0L,
+          maxFetchBytes,
+          Optional.of(leaderEpoch)
+        ),
+        replicaId = 1
+      )
+
+      // Listener is not updated anymore.
+      listener.verify()
+    } finally {
+      replicaManager.shutdown(checkpointHW = false)
+    }
+  }
+
   private def topicsCreateDelta(startId: Int, isStartIdLeader: Boolean): TopicsDelta = {
     val leader = if (isStartIdLeader) startId else startId + 1
     val delta = new TopicsDelta(TopicsImage.EMPTY)
diff --git a/storage/src/main/java/org/apache/kafka/storage/internals/log/LogOffsetsListener.java b/storage/src/main/java/org/apache/kafka/storage/internals/log/LogOffsetsListener.java
new file mode 100644
index 00000000000..2dcb2bdc9b5
--- /dev/null
+++ b/storage/src/main/java/org/apache/kafka/storage/internals/log/LogOffsetsListener.java
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.storage.internals.log;
+
+/**
+ * A listener that receives notifications from the Log.
+ *
+ * Note that the callbacks are executed in the thread that triggers the change
+ * AND that locks may be held during their execution. They are meant to be used
+ * as a notification mechanism only.
+ */
+public interface LogOffsetsListener {
+    /**
+     * A default no op offsets listener.
+     */
+    LogOffsetsListener NO_OP_OFFSETS_LISTENER = new LogOffsetsListener() { };
+
+    /**
+     * Called when the Log increments its high watermark.
+     */
+    default void onHighWatermarkUpdated(long offset) {}
+}