Posted to commits@kafka.apache.org by jg...@apache.org on 2020/04/03 19:16:24 UTC

[kafka] branch 2.5 updated: KAFKA-9750; Fix race condition with log dir reassign completion (#8412)

This is an automated email from the ASF dual-hosted git repository.

jgus pushed a commit to branch 2.5
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/2.5 by this push:
     new d6016b6  KAFKA-9750; Fix race condition with log dir reassign completion (#8412)
d6016b6 is described below

commit d6016b6906ae3df54462aaf2d9b27419495c203d
Author: Jason Gustafson <ja...@confluent.io>
AuthorDate: Fri Apr 3 11:51:04 2020 -0700

    KAFKA-9750; Fix race condition with log dir reassign completion (#8412)
    
    There is a race when receiving a LeaderAndIsr request for a replica with an active log dir reassignment. If the reassignment completes just before the LeaderAndIsr handler updates epoch information, it can lead to an illegal state error since no future log dir exists. This patch fixes the problem by ensuring that the future log dir exists when the fetcher is started. Removal cannot happen concurrently because it requires access to the same partition state lock.
    
    Reviewers: Chia-Ping Tsai <ch...@gmail.com>, David Arthur <mu...@gmail.com>
    
    Co-authored-by: Chia-Ping Tsai <ch...@gmail.com>
---
 .../kafka/server/AbstractFetcherManager.scala      |  14 +-
 .../scala/kafka/server/AbstractFetcherThread.scala |  10 +-
 .../kafka/server/ReplicaAlterLogDirsManager.scala  |  16 ++
 .../kafka/server/ReplicaAlterLogDirsThread.scala   |  22 +-
 .../main/scala/kafka/server/ReplicaManager.scala   |   4 +
 .../kafka/server/AbstractFetcherManagerTest.scala  |   2 +
 .../server/ReplicaAlterLogDirsThreadTest.scala     | 239 +++++++++++++++++++--
 .../unit/kafka/server/ReplicaManagerTest.scala     |  46 ++--
 8 files changed, 311 insertions(+), 42 deletions(-)
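
To make the ordering concrete: the race is between the ReplicaAlterLogDirsThread completing a reassignment (which removes the future log) and the LeaderAndIsr handler re-adding the partition to the fetcher. The patch closes it by checking, while holding the partition state lock, that a future log still exists before adding fetch state. Below is a condensed, hypothetical sketch of that invariant using simplified stand-in types (FutureLogRegistry, LogDirFetcher) rather than the real Kafka classes; the actual change is in the diff that follows.

    import java.util.concurrent.locks.ReentrantLock
    import scala.collection.mutable

    // Simplified stand-ins; names are illustrative only, not Kafka's classes.
    final case class TopicPartition(topic: String, partition: Int)

    final class FutureLogRegistry {
      private val futureLogs = mutable.Set.empty[TopicPartition]
      def add(tp: TopicPartition): Unit = futureLogs.synchronized { futureLogs += tp }
      // Plays the role of ReplicaManager.futureLogExists in the patch below.
      def futureLogExists(tp: TopicPartition): Boolean =
        futureLogs.synchronized { futureLogs.contains(tp) }
      def remove(tp: TopicPartition): Unit = futureLogs.synchronized { futureLogs -= tp }
    }

    final class LogDirFetcher(registry: FutureLogRegistry) {
      private val partitionMapLock = new ReentrantLock
      private val partitionStates = mutable.Map.empty[TopicPartition, Long]

      // Only partitions that still have a future log are added; the set of
      // partitions actually added is returned, as in the patched addPartitions.
      def addPartitions(initialOffsets: Map[TopicPartition, Long]): Set[TopicPartition] = {
        partitionMapLock.lockInterruptibly()
        try {
          val stillReassigning = initialOffsets.filter { case (tp, _) =>
            registry.futureLogExists(tp)
          }
          partitionStates ++= stillReassigning
          stillReassigning.keySet
        } finally partitionMapLock.unlock()
      }

      // Completing a reassignment removes the future log and the fetch state
      // under the same lock (a simplification of the partition state lock the
      // commit message refers to), so it cannot interleave with addPartitions.
      def completeReassignment(tp: TopicPartition): Unit = {
        partitionMapLock.lockInterruptibly()
        try {
          registry.remove(tp)
          partitionStates -= tp
        } finally partitionMapLock.unlock()
      }
    }

    object RaceSketch extends App {
      val registry = new FutureLogRegistry
      val fetcher  = new LogDirFetcher(registry)
      val tp       = TopicPartition("topic", 0)

      // Reassignment already finished: no future log, so the partition is
      // skipped instead of producing an illegal state error later.
      println(fetcher.addPartitions(Map(tp -> 0L))) // Set()

      registry.add(tp)
      println(fetcher.addPartitions(Map(tp -> 0L))) // Set(TopicPartition(topic,0))
    }
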

diff --git a/core/src/main/scala/kafka/server/AbstractFetcherManager.scala b/core/src/main/scala/kafka/server/AbstractFetcherManager.scala
index 04cee8c..38457ae 100755
--- a/core/src/main/scala/kafka/server/AbstractFetcherManager.scala
+++ b/core/src/main/scala/kafka/server/AbstractFetcherManager.scala
@@ -120,7 +120,8 @@ abstract class AbstractFetcherManager[T <: AbstractFetcherThread](val name: Stri
         BrokerAndFetcherId(brokerAndInitialFetchOffset.leader, getFetcherId(topicPartition))
       }
 
-      def addAndStartFetcherThread(brokerAndFetcherId: BrokerAndFetcherId, brokerIdAndFetcherId: BrokerIdAndFetcherId): AbstractFetcherThread = {
+      def addAndStartFetcherThread(brokerAndFetcherId: BrokerAndFetcherId,
+                                   brokerIdAndFetcherId: BrokerIdAndFetcherId): T = {
         val fetcherThread = createFetcherThread(brokerAndFetcherId.fetcherId, brokerAndFetcherId.broker)
         fetcherThreadMap.put(brokerIdAndFetcherId, fetcherThread)
         fetcherThread.start()
@@ -144,14 +145,17 @@ abstract class AbstractFetcherManager[T <: AbstractFetcherThread](val name: Stri
           tp -> OffsetAndEpoch(brokerAndInitOffset.initOffset, brokerAndInitOffset.currentLeaderEpoch)
         }
 
-        fetcherThread.addPartitions(initialOffsetAndEpochs)
-        info(s"Added fetcher to broker ${brokerAndFetcherId.broker} for partitions $initialOffsetAndEpochs")
-
-        failedPartitions.removeAll(partitionAndOffsets.keySet)
+        addPartitionsToFetcherThread(fetcherThread, initialOffsetAndEpochs)
       }
     }
   }
 
+  protected def addPartitionsToFetcherThread(fetcherThread: T,
+                                             initialOffsetAndEpochs: collection.Map[TopicPartition, OffsetAndEpoch]): Unit = {
+    fetcherThread.addPartitions(initialOffsetAndEpochs)
+    info(s"Added fetcher to broker ${fetcherThread.sourceBroker.id} for partitions $initialOffsetAndEpochs")
+  }
+
   def removeFetcherForPartitions(partitions: Set[TopicPartition]): Unit = {
     lock synchronized {
       for (fetcher <- fetcherThreadMap.values)
diff --git a/core/src/main/scala/kafka/server/AbstractFetcherThread.scala b/core/src/main/scala/kafka/server/AbstractFetcherThread.scala
index be68074..c7cdedc 100755
--- a/core/src/main/scala/kafka/server/AbstractFetcherThread.scala
+++ b/core/src/main/scala/kafka/server/AbstractFetcherThread.scala
@@ -62,7 +62,7 @@ abstract class AbstractFetcherThread(name: String,
   type EpochData = OffsetsForLeaderEpochRequest.PartitionData
 
   private val partitionStates = new PartitionStates[PartitionFetchState]
-  private val partitionMapLock = new ReentrantLock
+  protected val partitionMapLock = new ReentrantLock
   private val partitionMapCond = partitionMapLock.newCondition()
 
   private val metricId = ClientIdAndBroker(clientId, sourceBroker.host, sourceBroker.port)
@@ -199,7 +199,7 @@ abstract class AbstractFetcherThread(name: String,
     * - Send OffsetsForLeaderEpochRequest, retrieving the latest offset for each partition's
     *   leader epoch. This is the offset the follower should truncate to ensure
     *   accurate log replication.
-    * - Finally truncate the logs for partitions in the truncating phase and mark them
+    * - Finally truncate the logs for partitions in the truncating phase and mark the
     *   truncation complete. Do this within a lock to ensure no leadership changes can
     *   occur during truncation.
     */
@@ -419,10 +419,11 @@ abstract class AbstractFetcherThread(name: String,
     warn(s"Partition $topicPartition marked as failed")
   }
 
-
-  def addPartitions(initialFetchStates: Map[TopicPartition, OffsetAndEpoch]): Unit = {
+  def addPartitions(initialFetchStates: Map[TopicPartition, OffsetAndEpoch]): Set[TopicPartition] = {
     partitionMapLock.lockInterruptibly()
     try {
+      failedPartitions.removeAll(initialFetchStates.keySet)
+
       initialFetchStates.foreach { case (tp, initialFetchState) =>
         // We can skip the truncation step iff the leader epoch matches the existing epoch
         val currentState = partitionStates.stateValue(tp)
@@ -437,6 +438,7 @@ abstract class AbstractFetcherThread(name: String,
       }
 
       partitionMapCond.signalAll()
+      initialFetchStates.keySet
     } finally partitionMapLock.unlock()
   }
 
diff --git a/core/src/main/scala/kafka/server/ReplicaAlterLogDirsManager.scala b/core/src/main/scala/kafka/server/ReplicaAlterLogDirsManager.scala
index 5b6c5b1..85eaeaa 100644
--- a/core/src/main/scala/kafka/server/ReplicaAlterLogDirsManager.scala
+++ b/core/src/main/scala/kafka/server/ReplicaAlterLogDirsManager.scala
@@ -18,6 +18,7 @@
 package kafka.server
 
 import kafka.cluster.BrokerEndPoint
+import org.apache.kafka.common.TopicPartition
 
 class ReplicaAlterLogDirsManager(brokerConfig: KafkaConfig,
                                  replicaManager: ReplicaManager,
@@ -34,6 +35,21 @@ class ReplicaAlterLogDirsManager(brokerConfig: KafkaConfig,
       quotaManager, brokerTopicStats)
   }
 
+  override protected def addPartitionsToFetcherThread(fetcherThread: ReplicaAlterLogDirsThread,
+                                                      initialOffsetAndEpochs: collection.Map[TopicPartition, OffsetAndEpoch]): Unit = {
+    val addedPartitions = fetcherThread.addPartitions(initialOffsetAndEpochs)
+    val (addedInitialOffsets, notAddedInitialOffsets) = initialOffsetAndEpochs.partition { case (tp, _) =>
+      addedPartitions.contains(tp)
+    }
+
+    if (addedInitialOffsets.nonEmpty)
+      info(s"Added log dir fetcher for partitions with initial offsets $addedInitialOffsets")
+
+    if (notAddedInitialOffsets.nonEmpty)
+      info(s"Failed to add log dir fetch for partitions ${notAddedInitialOffsets.keySet} " +
+        s"since the log dir reassignment has already completed")
+  }
+
   def shutdown(): Unit = {
     info("shutting down")
     closeAllFetchers()
diff --git a/core/src/main/scala/kafka/server/ReplicaAlterLogDirsThread.scala b/core/src/main/scala/kafka/server/ReplicaAlterLogDirsThread.scala
index 6b3ea85..9119b16 100644
--- a/core/src/main/scala/kafka/server/ReplicaAlterLogDirsThread.scala
+++ b/core/src/main/scala/kafka/server/ReplicaAlterLogDirsThread.scala
@@ -91,7 +91,7 @@ class ReplicaAlterLogDirsThread(name: String,
       Request.FutureLocalReplicaId,
       request.minBytes,
       request.maxBytes,
-      request.version <= 2,
+      false,
       request.fetchData.asScala.toSeq,
       UnboundedQuota,
       processResponseCallback,
@@ -116,7 +116,11 @@ class ReplicaAlterLogDirsThread(name: String,
       throw new IllegalStateException("Offset mismatch for the future replica %s: fetched offset = %d, log end offset = %d.".format(
         topicPartition, fetchOffset, futureLog.logEndOffset))
 
-    val logAppendInfo = partition.appendRecordsToFollowerOrFutureReplica(records, isFuture = true)
+    val logAppendInfo = if (records.sizeInBytes() > 0)
+      partition.appendRecordsToFollowerOrFutureReplica(records, isFuture = true)
+    else
+      None
+
     futureLog.updateHighWatermark(partitionData.highWatermark)
     futureLog.maybeIncrementLogStartOffset(partitionData.logStartOffset)
 
@@ -127,6 +131,20 @@ class ReplicaAlterLogDirsThread(name: String,
     logAppendInfo
   }
 
+  override def addPartitions(initialFetchStates: Map[TopicPartition, OffsetAndEpoch]): Set[TopicPartition] = {
+    partitionMapLock.lockInterruptibly()
+    try {
+      // It is possible that the log dir fetcher completed just before this call, so we
+      // filter only the partitions which still have a future log dir.
+      val filteredFetchStates = initialFetchStates.filter { case (tp, _) =>
+        replicaMgr.futureLogExists(tp)
+      }
+      super.addPartitions(filteredFetchStates)
+    } finally {
+      partitionMapLock.unlock()
+    }
+  }
+
   override protected def fetchEarliestOffsetFromLeader(topicPartition: TopicPartition, leaderEpoch: Int): Long = {
     val partition = replicaMgr.getPartitionOrException(topicPartition, expectLeader = false)
     partition.localLogOrException.logStartOffset
diff --git a/core/src/main/scala/kafka/server/ReplicaManager.scala b/core/src/main/scala/kafka/server/ReplicaManager.scala
index e0994fe..e289bb4 100644
--- a/core/src/main/scala/kafka/server/ReplicaManager.scala
+++ b/core/src/main/scala/kafka/server/ReplicaManager.scala
@@ -475,6 +475,10 @@ class ReplicaManager(val config: KafkaConfig,
     getPartitionOrException(topicPartition, expectLeader = false).futureLocalLogOrException
   }
 
+  def futureLogExists(topicPartition: TopicPartition): Boolean = {
+    getPartitionOrException(topicPartition, expectLeader = false).futureLog.isDefined
+  }
+
   def localLog(topicPartition: TopicPartition): Option[Log] = {
     nonOfflinePartition(topicPartition).flatMap(_.log)
   }
diff --git a/core/src/test/scala/unit/kafka/server/AbstractFetcherManagerTest.scala b/core/src/test/scala/unit/kafka/server/AbstractFetcherManagerTest.scala
index ecd92bf..e56ff75 100644
--- a/core/src/test/scala/unit/kafka/server/AbstractFetcherManagerTest.scala
+++ b/core/src/test/scala/unit/kafka/server/AbstractFetcherManagerTest.scala
@@ -58,6 +58,7 @@ class AbstractFetcherManagerTest {
 
     EasyMock.expect(fetcher.start())
     EasyMock.expect(fetcher.addPartitions(Map(tp -> OffsetAndEpoch(fetchOffset, leaderEpoch))))
+        .andReturn(Set(tp))
     EasyMock.expect(fetcher.fetchState(tp))
       .andReturn(Some(PartitionFetchState(fetchOffset, None, leaderEpoch, Truncating)))
     EasyMock.expect(fetcher.removePartitions(Set(tp)))
@@ -116,6 +117,7 @@ class AbstractFetcherManagerTest {
 
     EasyMock.expect(fetcher.start())
     EasyMock.expect(fetcher.addPartitions(Map(tp -> OffsetAndEpoch(fetchOffset, leaderEpoch))))
+        .andReturn(Set(tp))
     EasyMock.expect(fetcher.isThreadFailed).andReturn(true)
     EasyMock.replay(fetcher)
 
diff --git a/core/src/test/scala/unit/kafka/server/ReplicaAlterLogDirsThreadTest.scala b/core/src/test/scala/unit/kafka/server/ReplicaAlterLogDirsThreadTest.scala
index 17fed26..c2f9033 100644
--- a/core/src/test/scala/unit/kafka/server/ReplicaAlterLogDirsThreadTest.scala
+++ b/core/src/test/scala/unit/kafka/server/ReplicaAlterLogDirsThreadTest.scala
@@ -18,22 +18,27 @@ package kafka.server
 
 import java.util.Optional
 
+import kafka.api.Request
 import kafka.cluster.{BrokerEndPoint, Partition}
 import kafka.log.{Log, LogManager}
 import kafka.server.AbstractFetcherThread.ResultWithPartitions
+import kafka.server.QuotaFactory.UnboundedQuota
 import kafka.utils.{DelayedItem, TestUtils}
-import org.apache.kafka.common.TopicPartition
 import org.apache.kafka.common.errors.KafkaStorageException
 import org.apache.kafka.common.protocol.Errors
-import org.apache.kafka.common.requests.{EpochEndOffset, OffsetsForLeaderEpochRequest}
+import org.apache.kafka.common.record.MemoryRecords
 import org.apache.kafka.common.requests.EpochEndOffset.{UNDEFINED_EPOCH, UNDEFINED_EPOCH_OFFSET}
+import org.apache.kafka.common.requests.{EpochEndOffset, FetchRequest, OffsetsForLeaderEpochRequest}
+import org.apache.kafka.common.{IsolationLevel, TopicPartition}
 import org.easymock.EasyMock._
 import org.easymock.{Capture, CaptureType, EasyMock, IExpectationSetters}
 import org.junit.Assert._
 import org.junit.Test
+import org.mockito.Mockito.{doNothing, when}
+import org.mockito.{ArgumentCaptor, ArgumentMatchers, Mockito}
 
-import scala.collection.JavaConverters._
 import scala.collection.{Map, Seq}
+import scala.jdk.CollectionConverters._
 
 class ReplicaAlterLogDirsThreadTest {
 
@@ -46,6 +51,207 @@ class ReplicaAlterLogDirsThreadTest {
   }
 
   @Test
+  def shouldNotAddPartitionIfFutureLogIsNotDefined(): Unit = {
+    val brokerId = 1
+    val config = KafkaConfig.fromProps(TestUtils.createBrokerConfig(brokerId, "localhost:1234"))
+
+    val replicaManager = Mockito.mock(classOf[ReplicaManager])
+    val quotaManager = Mockito.mock(classOf[ReplicationQuotaManager])
+
+    when(replicaManager.futureLogExists(t1p0)).thenReturn(false)
+
+    val endPoint = new BrokerEndPoint(0, "localhost", 1000)
+    val thread = new ReplicaAlterLogDirsThread(
+      "alter-logs-dirs-thread",
+      sourceBroker = endPoint,
+      brokerConfig = config,
+      failedPartitions = failedPartitions,
+      replicaMgr = replicaManager,
+      quota = quotaManager,
+      brokerTopicStats = new BrokerTopicStats)
+
+    val addedPartitions = thread.addPartitions(Map(t1p0 -> offsetAndEpoch(0L)))
+    assertEquals(Set.empty, addedPartitions)
+    assertEquals(0, thread.partitionCount())
+    assertEquals(None, thread.fetchState(t1p0))
+  }
+
+  @Test
+  def shouldUpdateLeaderEpochAfterFencedEpochError(): Unit = {
+    val brokerId = 1
+    val config = KafkaConfig.fromProps(TestUtils.createBrokerConfig(brokerId, "localhost:1234"))
+
+    val partition = Mockito.mock(classOf[Partition])
+    val replicaManager = Mockito.mock(classOf[ReplicaManager])
+    val quotaManager = Mockito.mock(classOf[ReplicationQuotaManager])
+    val futureLog = Mockito.mock(classOf[Log])
+
+    val leaderEpoch = 5
+    val logEndOffset = 0
+
+    when(replicaManager.futureLocalLogOrException(t1p0)).thenReturn(futureLog)
+    when(replicaManager.futureLogExists(t1p0)).thenReturn(true)
+    when(replicaManager.nonOfflinePartition(t1p0)).thenReturn(Some(partition))
+    when(replicaManager.getPartitionOrException(t1p0, expectLeader = false)).thenReturn(partition)
+
+    when(quotaManager.isQuotaExceeded).thenReturn(false)
+
+    when(partition.lastOffsetForLeaderEpoch(Optional.empty(), leaderEpoch, fetchOnlyFromLeader = false))
+      .thenReturn(new EpochEndOffset(leaderEpoch, logEndOffset))
+    when(partition.futureLocalLogOrException).thenReturn(futureLog)
+    doNothing().when(partition).truncateTo(offset = 0, isFuture = true)
+    when(partition.maybeReplaceCurrentWithFutureReplica()).thenReturn(true)
+
+    when(futureLog.logStartOffset).thenReturn(0L)
+    when(futureLog.logEndOffset).thenReturn(0L)
+    when(futureLog.latestEpoch).thenReturn(None)
+
+    val fencedRequestData = new FetchRequest.PartitionData(0L, 0L,
+      config.replicaFetchMaxBytes, Optional.of(leaderEpoch - 1))
+    val fencedResponseData = FetchPartitionData(
+      error = Errors.FENCED_LEADER_EPOCH,
+      highWatermark = -1,
+      logStartOffset = -1,
+      records = MemoryRecords.EMPTY,
+      lastStableOffset = None,
+      abortedTransactions = None,
+      preferredReadReplica = None,
+      isReassignmentFetch = false)
+    mockFetchFromCurrentLog(t1p0, fencedRequestData, config, replicaManager, fencedResponseData)
+
+    val endPoint = new BrokerEndPoint(0, "localhost", 1000)
+    val thread = new ReplicaAlterLogDirsThread(
+      "alter-logs-dirs-thread",
+      sourceBroker = endPoint,
+      brokerConfig = config,
+      failedPartitions = failedPartitions,
+      replicaMgr = replicaManager,
+      quota = quotaManager,
+      brokerTopicStats = new BrokerTopicStats)
+
+    // Initially we add the partition with an older epoch which results in an error
+    thread.addPartitions(Map(t1p0 -> offsetAndEpoch(fetchOffset = 0L, leaderEpoch - 1)))
+    assertTrue(thread.fetchState(t1p0).isDefined)
+    assertEquals(1, thread.partitionCount())
+
+    thread.doWork()
+
+    assertTrue(failedPartitions.contains(t1p0))
+    assertEquals(None, thread.fetchState(t1p0))
+    assertEquals(0, thread.partitionCount())
+
+    // Next we update the epoch and assert that we can continue
+    thread.addPartitions(Map(t1p0 -> offsetAndEpoch(fetchOffset = 0L, leaderEpoch)))
+    assertEquals(Some(leaderEpoch), thread.fetchState(t1p0).map(_.currentLeaderEpoch))
+    assertEquals(1, thread.partitionCount())
+
+    val requestData = new FetchRequest.PartitionData(0L, 0L,
+      config.replicaFetchMaxBytes, Optional.of(leaderEpoch))
+    val responseData = FetchPartitionData(
+      error = Errors.NONE,
+      highWatermark = 0L,
+      logStartOffset = 0L,
+      records = MemoryRecords.EMPTY,
+      lastStableOffset = None,
+      abortedTransactions = None,
+      preferredReadReplica = None,
+      isReassignmentFetch = false)
+    mockFetchFromCurrentLog(t1p0, requestData, config, replicaManager, responseData)
+
+    thread.doWork()
+
+    assertFalse(failedPartitions.contains(t1p0))
+    assertEquals(None, thread.fetchState(t1p0))
+    assertEquals(0, thread.partitionCount())
+  }
+
+  @Test
+  def shouldReplaceCurrentLogDirWhenCaughtUp(): Unit = {
+    val brokerId = 1
+    val config = KafkaConfig.fromProps(TestUtils.createBrokerConfig(brokerId, "localhost:1234"))
+
+    val partition = Mockito.mock(classOf[Partition])
+    val replicaManager = Mockito.mock(classOf[ReplicaManager])
+    val quotaManager = Mockito.mock(classOf[ReplicationQuotaManager])
+    val futureLog = Mockito.mock(classOf[Log])
+
+    val leaderEpoch = 5
+    val logEndOffset = 0
+
+    when(replicaManager.futureLocalLogOrException(t1p0)).thenReturn(futureLog)
+    when(replicaManager.futureLogExists(t1p0)).thenReturn(true)
+    when(replicaManager.nonOfflinePartition(t1p0)).thenReturn(Some(partition))
+    when(replicaManager.getPartitionOrException(t1p0, expectLeader = false)).thenReturn(partition)
+
+    when(quotaManager.isQuotaExceeded).thenReturn(false)
+
+    when(partition.lastOffsetForLeaderEpoch(Optional.empty(), leaderEpoch, fetchOnlyFromLeader = false))
+      .thenReturn(new EpochEndOffset(leaderEpoch, logEndOffset))
+    when(partition.futureLocalLogOrException).thenReturn(futureLog)
+    doNothing().when(partition).truncateTo(offset = 0, isFuture = true)
+    when(partition.maybeReplaceCurrentWithFutureReplica()).thenReturn(true)
+
+    when(futureLog.logStartOffset).thenReturn(0L)
+    when(futureLog.logEndOffset).thenReturn(0L)
+    when(futureLog.latestEpoch).thenReturn(None)
+
+    val requestData = new FetchRequest.PartitionData(0L, 0L,
+      config.replicaFetchMaxBytes, Optional.of(leaderEpoch))
+    val responseData = FetchPartitionData(
+      error = Errors.NONE,
+      highWatermark = 0L,
+      logStartOffset = 0L,
+      records = MemoryRecords.EMPTY,
+      lastStableOffset = None,
+      abortedTransactions = None,
+      preferredReadReplica = None,
+      isReassignmentFetch = false)
+    mockFetchFromCurrentLog(t1p0, requestData, config, replicaManager, responseData)
+
+    val endPoint = new BrokerEndPoint(0, "localhost", 1000)
+    val thread = new ReplicaAlterLogDirsThread(
+      "alter-logs-dirs-thread",
+      sourceBroker = endPoint,
+      brokerConfig = config,
+      failedPartitions = failedPartitions,
+      replicaMgr = replicaManager,
+      quota = quotaManager,
+      brokerTopicStats = new BrokerTopicStats)
+
+    thread.addPartitions(Map(t1p0 -> offsetAndEpoch(fetchOffset = 0L, leaderEpoch)))
+    assertTrue(thread.fetchState(t1p0).isDefined)
+    assertEquals(1, thread.partitionCount())
+
+    thread.doWork()
+
+    assertEquals(None, thread.fetchState(t1p0))
+    assertEquals(0, thread.partitionCount())
+  }
+
+  private def mockFetchFromCurrentLog(topicPartition: TopicPartition,
+                                      requestData: FetchRequest.PartitionData,
+                                      config: KafkaConfig,
+                                      replicaManager: ReplicaManager,
+                                      responseData: FetchPartitionData): Unit = {
+    val callbackCaptor: ArgumentCaptor[Seq[(TopicPartition, FetchPartitionData)] => Unit] =
+      ArgumentCaptor.forClass(classOf[Seq[(TopicPartition, FetchPartitionData)] => Unit])
+    when(replicaManager.fetchMessages(
+      timeout = ArgumentMatchers.eq(0L),
+      replicaId = ArgumentMatchers.eq(Request.FutureLocalReplicaId),
+      fetchMinBytes = ArgumentMatchers.eq(0),
+      fetchMaxBytes = ArgumentMatchers.eq(config.replicaFetchResponseMaxBytes),
+      hardMaxBytesLimit = ArgumentMatchers.eq(false),
+      fetchInfos = ArgumentMatchers.eq(Seq(topicPartition -> requestData)),
+      quota = ArgumentMatchers.eq(UnboundedQuota),
+      responseCallback = callbackCaptor.capture(),
+      isolationLevel = ArgumentMatchers.eq(IsolationLevel.READ_UNCOMMITTED),
+      clientMetadata = ArgumentMatchers.eq(None)
+    )).thenAnswer(_ => {
+      callbackCaptor.getValue.apply(Seq((topicPartition, responseData)))
+    })
+  }
+
+  @Test
   def issuesEpochRequestFromLocalReplica(): Unit = {
     val config = KafkaConfig.fromProps(TestUtils.createBrokerConfig(1, "localhost:1234"))
 
@@ -79,7 +285,7 @@ class ReplicaAlterLogDirsThreadTest {
       "alter-logs-dirs-thread-test1",
       sourceBroker = endPoint,
       brokerConfig = config,
-      failedPartitions : FailedPartitions,
+      failedPartitions = failedPartitions,
       replicaMgr = replicaManager,
       quota = null,
       brokerTopicStats = null)
@@ -125,7 +331,7 @@ class ReplicaAlterLogDirsThreadTest {
       "alter-logs-dirs-thread-test1",
       sourceBroker = endPoint,
       brokerConfig = config,
-      failedPartitions: FailedPartitions,
+      failedPartitions = failedPartitions,
       replicaMgr = replicaManager,
       quota = null,
       brokerTopicStats = null)
@@ -174,7 +380,9 @@ class ReplicaAlterLogDirsThreadTest {
     expect(replicaManager.getPartitionOrException(t1p1, expectLeader = false))
       .andStubReturn(partitionT1p1)
     expect(replicaManager.futureLocalLogOrException(t1p0)).andStubReturn(futureLogT1p0)
+    expect(replicaManager.futureLogExists(t1p0)).andStubReturn(true)
     expect(replicaManager.futureLocalLogOrException(t1p1)).andStubReturn(futureLogT1p1)
+    expect(replicaManager.futureLogExists(t1p1)).andStubReturn(true)
     expect(partitionT1p0.truncateTo(capture(truncateCaptureT1p0), anyBoolean())).anyTimes()
     expect(partitionT1p1.truncateTo(capture(truncateCaptureT1p1), anyBoolean())).anyTimes()
 
@@ -206,7 +414,7 @@ class ReplicaAlterLogDirsThreadTest {
       "alter-logs-dirs-thread-test1",
       sourceBroker = endPoint,
       brokerConfig = config,
-      failedPartitions: FailedPartitions,
+      failedPartitions = failedPartitions,
       replicaMgr = replicaManager,
       quota = quotaManager,
       brokerTopicStats = null)
@@ -247,6 +455,7 @@ class ReplicaAlterLogDirsThreadTest {
     expect(replicaManager.getPartitionOrException(t1p0, expectLeader = false))
       .andStubReturn(partition)
     expect(replicaManager.futureLocalLogOrException(t1p0)).andStubReturn(futureLog)
+    expect(replicaManager.futureLogExists(t1p0)).andStubReturn(true)
 
     expect(partition.truncateTo(capture(truncateToCapture), EasyMock.eq(true))).anyTimes()
     expect(futureLog.logEndOffset).andReturn(futureReplicaLEO).anyTimes()
@@ -278,7 +487,7 @@ class ReplicaAlterLogDirsThreadTest {
       "alter-logs-dirs-thread-test1",
       sourceBroker = endPoint,
       brokerConfig = config,
-      failedPartitions : FailedPartitions,
+      failedPartitions = failedPartitions,
       replicaMgr = replicaManager,
       quota = quotaManager,
       brokerTopicStats = null)
@@ -317,6 +526,7 @@ class ReplicaAlterLogDirsThreadTest {
       .andStubReturn(partition)
     expect(partition.truncateTo(capture(truncated), isFuture = EasyMock.eq(true))).anyTimes()
     expect(replicaManager.futureLocalLogOrException(t1p0)).andStubReturn(futureLog)
+    expect(replicaManager.futureLogExists(t1p0)).andStubReturn(true)
 
     expect(replicaManager.logManager).andReturn(logManager).anyTimes()
 
@@ -332,7 +542,7 @@ class ReplicaAlterLogDirsThreadTest {
       "alter-logs-dirs-thread-test1",
       sourceBroker = endPoint,
       brokerConfig = config,
-      failedPartitions: FailedPartitions,
+      failedPartitions = failedPartitions,
       replicaMgr = replicaManager,
       quota = quotaManager,
       brokerTopicStats = null)
@@ -372,12 +582,12 @@ class ReplicaAlterLogDirsThreadTest {
     expect(partition.truncateTo(capture(truncated), isFuture = EasyMock.eq(true))).once()
 
     expect(replicaManager.futureLocalLogOrException(t1p0)).andStubReturn(futureLog)
+    expect(replicaManager.futureLogExists(t1p0)).andStubReturn(true)
     expect(futureLog.logEndOffset).andReturn(futureReplicaLEO).anyTimes()
     expect(futureLog.latestEpoch).andStubReturn(Some(futureReplicaLeaderEpoch))
     expect(futureLog.endOffsetForEpoch(futureReplicaLeaderEpoch)).andReturn(
       Some(OffsetAndEpoch(futureReplicaLEO, futureReplicaLeaderEpoch)))
     expect(replicaManager.localLog(t1p0)).andReturn(Some(log)).anyTimes()
-    expect(replicaManager.futureLocalLogOrException(t1p0)).andReturn(futureLog).anyTimes()
 
     // this will cause fetchEpochsFromLeader return an error with undefined offset
     expect(partition.lastOffsetForLeaderEpoch(Optional.of(1), futureReplicaLeaderEpoch, fetchOnlyFromLeader = false))
@@ -407,7 +617,7 @@ class ReplicaAlterLogDirsThreadTest {
       "alter-logs-dirs-thread-test1",
       sourceBroker = endPoint,
       brokerConfig = config,
-      failedPartitions: FailedPartitions,
+      failedPartitions = failedPartitions,
       replicaMgr = replicaManager,
       quota = quotaManager,
       brokerTopicStats = null)
@@ -452,6 +662,7 @@ class ReplicaAlterLogDirsThreadTest {
     expect(partition.truncateTo(futureReplicaLEO, isFuture = true)).once()
 
     expect(replicaManager.futureLocalLogOrException(t1p0)).andStubReturn(futureLog)
+    expect(replicaManager.futureLogExists(t1p0)).andStubReturn(true)
     expect(futureLog.latestEpoch).andStubReturn(Some(leaderEpoch))
     expect(futureLog.logEndOffset).andStubReturn(futureReplicaLEO)
     expect(futureLog.endOffsetForEpoch(leaderEpoch)).andReturn(
@@ -467,7 +678,7 @@ class ReplicaAlterLogDirsThreadTest {
       "alter-logs-dirs-thread-test1",
       sourceBroker = endPoint,
       brokerConfig = config,
-      failedPartitions: FailedPartitions,
+      failedPartitions = failedPartitions,
       replicaMgr = replicaManager,
       quota = quotaManager,
       brokerTopicStats = null)
@@ -507,7 +718,7 @@ class ReplicaAlterLogDirsThreadTest {
       "alter-logs-dirs-thread-test1",
       sourceBroker = endPoint,
       brokerConfig = config,
-      failedPartitions: FailedPartitions,
+      failedPartitions = failedPartitions,
       replicaMgr = replicaManager,
       quota = quotaManager,
       brokerTopicStats = null)
@@ -558,7 +769,7 @@ class ReplicaAlterLogDirsThreadTest {
       "alter-logs-dirs-thread-test1",
       sourceBroker = endPoint,
       brokerConfig = config,
-      failedPartitions: FailedPartitions,
+      failedPartitions = failedPartitions,
       replicaMgr = replicaManager,
       quota = quotaManager,
       brokerTopicStats = null)
@@ -607,10 +818,12 @@ class ReplicaAlterLogDirsThreadTest {
     expect(replicaManager.localLog(t1p0)).andReturn(Some(logT1p0)).anyTimes()
     expect(replicaManager.localLogOrException(t1p0)).andReturn(logT1p0).anyTimes()
     expect(replicaManager.futureLocalLogOrException(t1p0)).andReturn(futureLog).anyTimes()
+    expect(replicaManager.futureLogExists(t1p0)).andStubReturn(true)
     expect(replicaManager.nonOfflinePartition(t1p0)).andReturn(Some(partition)).anyTimes()
     expect(replicaManager.localLog(t1p1)).andReturn(Some(logT1p1)).anyTimes()
     expect(replicaManager.localLogOrException(t1p1)).andReturn(logT1p1).anyTimes()
     expect(replicaManager.futureLocalLogOrException(t1p1)).andReturn(futureLog).anyTimes()
+    expect(replicaManager.futureLogExists(t1p1)).andStubReturn(true)
     expect(replicaManager.nonOfflinePartition(t1p1)).andReturn(Some(partition)).anyTimes()
   }
 
diff --git a/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala b/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala
index e20ed10..20609f5 100644
--- a/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala
+++ b/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala
@@ -211,13 +211,19 @@ class ReplicaManagerTest {
 
   @Test
   def testFencedErrorCausedByBecomeLeader(): Unit = {
+    testFencedErrorCausedByBecomeLeader(0)
+    testFencedErrorCausedByBecomeLeader(1)
+    testFencedErrorCausedByBecomeLeader(10)
+  }
+
+  private[this] def testFencedErrorCausedByBecomeLeader(loopEpochChange: Int): Unit = {
     val replicaManager = setupReplicaManagerWithMockedPurgatories(new MockTimer)
     try {
       val brokerList = Seq[Integer](0, 1).asJava
       val topicPartition = new TopicPartition(topic, 0)
       replicaManager.createPartition(topicPartition)
         .createLogIfNotExists(0, isNew = false, isFutureReplica = false,
-        new LazyOffsetCheckpoints(replicaManager.highWatermarkCheckpoints))
+          new LazyOffsetCheckpoints(replicaManager.highWatermarkCheckpoints))
 
       def leaderAndIsrRequest(epoch: Int): LeaderAndIsrRequest = new LeaderAndIsrRequest.Builder(ApiKeys.LEADER_AND_ISR.latestVersion, 0, 0, brokerEpoch,
         Seq(new LeaderAndIsrPartitionState()
@@ -234,30 +240,34 @@ class ReplicaManagerTest {
 
       replicaManager.becomeLeaderOrFollower(0, leaderAndIsrRequest(0), (_, _) => ())
       val partition = replicaManager.getPartitionOrException(new TopicPartition(topic, 0), expectLeader = true)
-        .localLogOrException
-      assertEquals(1, replicaManager.logManager.liveLogDirs.filterNot(_ == partition.dir.getParentFile).size)
+      assertEquals(1, replicaManager.logManager.liveLogDirs.filterNot(_ == partition.log.get.dir.getParentFile).size)
 
+      val previousReplicaFolder = partition.log.get.dir.getParentFile
       // find the live and different folder
-      val newReplicaFolder = replicaManager.logManager.liveLogDirs.filterNot(_ == partition.dir.getParentFile).head
+      val newReplicaFolder = replicaManager.logManager.liveLogDirs.filterNot(_ == partition.log.get.dir.getParentFile).head
       assertEquals(0, replicaManager.replicaAlterLogDirsManager.fetcherThreadMap.size)
       replicaManager.alterReplicaLogDirs(Map(topicPartition -> newReplicaFolder.getAbsolutePath))
+      // make sure the future log is created
       replicaManager.futureLocalLogOrException(topicPartition)
       assertEquals(1, replicaManager.replicaAlterLogDirsManager.fetcherThreadMap.size)
-      // change the epoch from 0 to 1 in order to make fenced error
-      replicaManager.becomeLeaderOrFollower(0, leaderAndIsrRequest(1), (_, _) => ())
-      TestUtils.waitUntilTrue(() => replicaManager.replicaAlterLogDirsManager.fetcherThreadMap.values.forall(_.partitionCount() == 0),
-        s"the partition=$topicPartition should be removed from pending state")
-      // the partition is added to failedPartitions if fenced error happens
-      // if the thread is done before ReplicaManager#becomeLeaderOrFollower updates epoch,the fenced error does
-      // not happen and failedPartitions is empty.
-      if (replicaManager.replicaAlterLogDirsManager.failedPartitions.size != 0) {
+      (1 to loopEpochChange).foreach(epoch => replicaManager.becomeLeaderOrFollower(0, leaderAndIsrRequest(epoch), (_, _) => ()))
+      // wait for the ReplicaAlterLogDirsThread to complete
+      TestUtils.waitUntilTrue(() => {
         replicaManager.replicaAlterLogDirsManager.shutdownIdleFetcherThreads()
-        assertEquals(0, replicaManager.replicaAlterLogDirsManager.fetcherThreadMap.size)
-        // send request again
-        replicaManager.alterReplicaLogDirs(Map(topicPartition -> newReplicaFolder.getAbsolutePath))
-        // the future folder exists so it fails to invoke thread
-        assertEquals(1, replicaManager.replicaAlterLogDirsManager.fetcherThreadMap.size)
-      }
+        replicaManager.replicaAlterLogDirsManager.fetcherThreadMap.isEmpty
+      }, s"ReplicaAlterLogDirsThread should be gone")
+
+      // the fenced error should be recoverable
+      assertEquals(0, replicaManager.replicaAlterLogDirsManager.failedPartitions.size)
+      // the replica change is completed after retrying
+      assertTrue(partition.futureLog.isEmpty)
+      assertEquals(newReplicaFolder.getAbsolutePath, partition.log.get.dir.getParent)
+      // change the replica folder again
+      val response = replicaManager.alterReplicaLogDirs(Map(topicPartition -> previousReplicaFolder.getAbsolutePath))
+      assertNotEquals(0, response.size)
+      response.values.foreach(assertEquals(Errors.NONE, _))
+      // should succeed to invoke ReplicaAlterLogDirsThread again
+      assertEquals(1, replicaManager.replicaAlterLogDirsManager.fetcherThreadMap.size)
     } finally replicaManager.shutdown(checkpointHW = false)
   }