Posted to commits@kafka.apache.org by jg...@apache.org on 2019/05/17 22:39:34 UTC

[kafka] branch trunk updated: KAFKA-8346; Improve replica fetcher behavior for handling partition failure [KIP-461] (#6716)

This is an automated email from the ASF dual-hosted git repository.

jgus pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 414852c  KAFKA-8346; Improve replica fetcher behavior for handling partition failure [KIP-461] (#6716)
414852c is described below

commit 414852c701763b6f8362b44d156753b6c3ef247a
Author: Aishwarya Gune <ai...@gmail.com>
AuthorDate: Fri May 17 15:39:20 2019 -0700

    KAFKA-8346; Improve replica fetcher behavior for handling partition failure [KIP-461] (#6716)
    
    Previously, the replica fetcher thread was terminated whenever a single partition failed, leaving all of the thread's other partitions under-replicated. With this change, the failed partition is dropped from the fetcher instead, so the thread keeps fetching the remaining partitions; the thread is shut down only once all of its partitions have failed. This is documented in KIP-461: https://cwiki.apache.org/confluence/display/KAFKA/KIP-461+-+Improve+Replica+Fetcher+behavior+at+handling+partition+failure.
    
    Reviewers: Jun Rao <ju...@gmail.com>, Jason Gustafson <ja...@confluent.io>
---
 .../kafka/server/AbstractFetcherManager.scala      | 48 +++++++++++-
 .../scala/kafka/server/AbstractFetcherThread.scala | 88 +++++++++++++---------
 .../kafka/server/ReplicaAlterLogDirsManager.scala  |  2 +-
 .../kafka/server/ReplicaAlterLogDirsThread.scala   |  2 +
 .../scala/kafka/server/ReplicaFetcherManager.scala |  2 +-
 .../scala/kafka/server/ReplicaFetcherThread.scala  |  2 +
 .../kafka/server/AbstractFetcherManagerTest.scala  | 41 +++++++++-
 .../kafka/server/AbstractFetcherThreadTest.scala   | 78 ++++++++++++++++++-
 .../server/ReplicaAlterLogDirsThreadTest.scala     | 10 +++
 .../kafka/server/ReplicaFetcherThreadTest.scala    | 24 +++---
 .../unit/kafka/server/ReplicaManagerTest.scala     |  2 +-
 11 files changed, 244 insertions(+), 55 deletions(-)
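
In effect, the patch narrows the failure domain from the whole fetcher thread to the
single failing partition. The sketch below condenses the new control flow; safeProcess
and fetchAndProcess are hypothetical stand-ins for the AbstractFetcherThread internals
shown in the diffs that follow, and locking, metrics, and the retriable-error backoff
path are omitted.

    import org.apache.kafka.common.TopicPartition
    import org.apache.kafka.common.errors.KafkaStorageException

    object FailureHandlingSketch {
      // Stand-in for the real fetch/append work done by the fetcher thread.
      def fetchAndProcess(tp: TopicPartition): Unit = ???

      def markPartitionFailed(tp: TopicPartition): Unit =
        println(s"Partition $tp marked as failed")

      def safeProcess(tp: TopicPartition): Unit = {
        try fetchAndProcess(tp)
        catch {
          case e: KafkaStorageException =>
            markPartitionFailed(tp) // drop tp; keep fetching the other partitions
          case t: Throwable =>
            markPartitionFailed(tp) // previously rethrown, terminating the thread
        }
      }
    }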

diff --git a/core/src/main/scala/kafka/server/AbstractFetcherManager.scala b/core/src/main/scala/kafka/server/AbstractFetcherManager.scala
index 10ae8df..a5faf0e 100755
--- a/core/src/main/scala/kafka/server/AbstractFetcherManager.scala
+++ b/core/src/main/scala/kafka/server/AbstractFetcherManager.scala
@@ -34,13 +34,14 @@ abstract class AbstractFetcherManager[T <: AbstractFetcherThread](val name: Stri
   private[server] val fetcherThreadMap = new mutable.HashMap[BrokerIdAndFetcherId, T]
   private val lock = new Object
   private var numFetchersPerBroker = numFetchers
+  val failedPartitions = new FailedPartitions
   this.logIdent = "[" + name + "] "
 
   newGauge(
     "MaxLag",
     new Gauge[Long] {
       // current max lag across all fetchers/topics/partitions
-      def value = fetcherThreadMap.foldLeft(0L)((curMaxAll, fetcherThreadMapEntry) => {
+      def value: Long = fetcherThreadMap.foldLeft(0L)((curMaxAll, fetcherThreadMapEntry) => {
         fetcherThreadMapEntry._2.fetcherLagStats.stats.foldLeft(0L)((curMaxThread, fetcherLagStatsEntry) => {
           curMaxThread.max(fetcherLagStatsEntry._2.lag)
         }).max(curMaxAll)
@@ -53,7 +54,7 @@ abstract class AbstractFetcherManager[T <: AbstractFetcherThread](val name: Stri
   "MinFetchRate", {
     new Gauge[Double] {
       // current min fetch rate across all fetchers/topics/partitions
-      def value = {
+      def value: Double = {
         val headRate: Double =
           fetcherThreadMap.headOption.map(_._2.fetcherStats.requestRate.oneMinuteRate).getOrElse(0)
 
@@ -66,6 +67,15 @@ abstract class AbstractFetcherManager[T <: AbstractFetcherThread](val name: Stri
   Map("clientId" -> clientId)
   )
 
+  val failedPartitionsCount = newGauge(
+    "FailedPartitionsCount", {
+      new Gauge[Int] {
+        def value: Int = failedPartitions.size
+      }
+    },
+    Map("clientId" -> clientId)
+  )
+
   def resizeThreadPool(newSize: Int): Unit = {
     def migratePartitions(newSize: Int): Unit = {
       fetcherThreadMap.foreach { case (id, thread) =>
@@ -152,6 +162,8 @@ abstract class AbstractFetcherManager[T <: AbstractFetcherThread](val name: Stri
 
         fetcherThread.addPartitions(initialOffsetAndEpochs)
         info(s"Added fetcher to broker ${brokerAndFetcherId.broker} for partitions $initialOffsetAndEpochs")
+
+        failedPartitions.removeAll(partitionAndOffsets.keySet)
       }
     }
   }
@@ -160,6 +172,7 @@ abstract class AbstractFetcherManager[T <: AbstractFetcherThread](val name: Stri
     lock synchronized {
       for (fetcher <- fetcherThreadMap.values)
         fetcher.removePartitions(partitions)
+      failedPartitions.removeAll(partitions)
     }
     info(s"Removed fetcher for partitions $partitions")
   }
@@ -191,6 +204,37 @@ abstract class AbstractFetcherManager[T <: AbstractFetcherThread](val name: Stri
   }
 }
 
+/**
+  * Keeps track of partitions marked as failed, either during truncation or appending, as the
+  * result of one of the following errors:
+  * <ol>
+  *   <li> Storage exception
+  *   <li> Fenced epoch
+  *   <li> Unexpected errors
+  * </ol>
+  * Partitions that fail due to a storage error are eventually removed from this set once the
+  * log directory is taken offline.
+  */
+class FailedPartitions {
+  private val failedPartitionsSet = new mutable.HashSet[TopicPartition]
+
+  def size: Int = synchronized {
+    failedPartitionsSet.size
+  }
+
+  def add(topicPartition: TopicPartition): Unit = synchronized {
+    failedPartitionsSet += topicPartition
+  }
+
+  def removeAll(topicPartitions: Set[TopicPartition]): Unit = synchronized {
+    failedPartitionsSet --= topicPartitions
+  }
+
+  def contains(topicPartition: TopicPartition): Boolean = synchronized {
+    failedPartitionsSet.contains(topicPartition)
+  }
+}
+
 case class BrokerAndFetcherId(broker: BrokerEndPoint, fetcherId: Int)
 
 case class InitialFetchState(leader: BrokerEndPoint, currentLeaderEpoch: Int, initOffset: Long)
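
The lifecycle of an entry in the new FailedPartitions set, as wired up above: a fetcher
thread adds a partition when it hits a fatal error, and the manager clears it again when
the partition is re-added with a new leader epoch (addFetcherForPartitions) or removed
outright (removeFetcherForPartitions). A small usage sketch against the class as defined
above:

    import org.apache.kafka.common.TopicPartition

    val failed = new FailedPartitions
    val tp = new TopicPartition("topic", 0)

    failed.add(tp)                // done by the fetcher thread on a fatal error
    assert(failed.contains(tp))   // the FailedPartitionsCount gauge now reads 1

    failed.removeAll(Set(tp))     // done by the manager when tp is (re)assigned
    assert(!failed.contains(tp))  // gauge back to 0; tp may be fetched again
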
diff --git a/core/src/main/scala/kafka/server/AbstractFetcherThread.scala b/core/src/main/scala/kafka/server/AbstractFetcherThread.scala
index 3cc6137..203cc62 100755
--- a/core/src/main/scala/kafka/server/AbstractFetcherThread.scala
+++ b/core/src/main/scala/kafka/server/AbstractFetcherThread.scala
@@ -39,7 +39,7 @@ import java.util.function.Consumer
 
 import com.yammer.metrics.core.Gauge
 import kafka.log.LogAppendInfo
-import org.apache.kafka.common.{KafkaException, TopicPartition}
+import org.apache.kafka.common.TopicPartition
 import org.apache.kafka.common.internals.PartitionStates
 import org.apache.kafka.common.record.{FileRecords, MemoryRecords, Records}
 import org.apache.kafka.common.requests._
@@ -52,6 +52,7 @@ import scala.math._
 abstract class AbstractFetcherThread(name: String,
                                      clientId: String,
                                      val sourceBroker: BrokerEndPoint,
+                                     failedPartitions: FailedPartitions,
                                      fetchBackOffMs: Int = 0,
                                      isInterruptible: Boolean = true)
   extends ShutdownableThread(name, isInterruptible) {
@@ -135,9 +136,10 @@ abstract class AbstractFetcherThread(name: String,
 
   // deal with partitions with errors, potentially due to leadership changes
   private def handlePartitionsWithErrors(partitions: Iterable[TopicPartition], methodName: String) {
-    if (partitions.nonEmpty)
+    if (partitions.nonEmpty) {
       debug(s"Handling errors in $methodName for partitions $partitions")
       delayPartitions(partitions, fetchBackOffMs)
+    }
   }
 
   /**
@@ -175,6 +177,24 @@ abstract class AbstractFetcherThread(name: String,
     }
   }
 
+  private def doTruncate(topicPartition: TopicPartition, truncationState: OffsetTruncationState): Boolean = {
+    try {
+      truncate(topicPartition, truncationState)
+      true
+    }
+    catch {
+      case e: KafkaStorageException =>
+        error(s"Failed to truncate $topicPartition at offset ${truncationState.offset}", e)
+        markPartitionFailed(topicPartition)
+        false
+      case t: Throwable =>
+        error(s"Unexpected error occurred during truncation for $topicPartition "
+          + s"at offset ${truncationState.offset}", t)
+        markPartitionFailed(topicPartition)
+        false
+    }
+  }
+
   /**
     * - Build a leader epoch fetch based on partitions that are in the Truncating phase
     * - Send OffsetsForLeaderEpochRequest, retrieving the latest offset for each partition's
@@ -208,28 +228,19 @@ abstract class AbstractFetcherThread(name: String,
   // Visible for testing
   private[server] def truncateToHighWatermark(partitions: Set[TopicPartition]): Unit = inLock(partitionMapLock) {
     val fetchOffsets = mutable.HashMap.empty[TopicPartition, OffsetTruncationState]
-    val partitionsWithError = mutable.HashSet.empty[TopicPartition]
 
     for (tp <- partitions) {
       val partitionState = partitionStates.stateValue(tp)
       if (partitionState != null) {
-        try {
-          val highWatermark = partitionState.fetchOffset
-          val truncationState = OffsetTruncationState(highWatermark, truncationCompleted = true)
-
-          info(s"Truncating partition $tp to local high watermark $highWatermark")
-          truncate(tp, truncationState)
+        val highWatermark = partitionState.fetchOffset
+        val truncationState = OffsetTruncationState(highWatermark, truncationCompleted = true)
 
+        info(s"Truncating partition $tp to local high watermark $highWatermark")
+        if (doTruncate(tp, truncationState))
           fetchOffsets.put(tp, truncationState)
-        } catch {
-          case e: KafkaStorageException =>
-            info(s"Failed to truncate $tp", e)
-            partitionsWithError += tp
-        }
       }
     }
 
-    handlePartitionsWithErrors(partitionsWithError, "truncateToHighWatermark")
     updateFetchOffsetAndMaybeMarkTruncationComplete(fetchOffsets)
   }
 
@@ -238,23 +249,17 @@ abstract class AbstractFetcherThread(name: String,
     val partitionsWithError = mutable.HashSet.empty[TopicPartition]
 
     fetchedEpochs.foreach { case (tp, leaderEpochOffset) =>
-      try {
-        leaderEpochOffset.error match {
-          case Errors.NONE =>
-            val offsetTruncationState = getOffsetTruncationState(tp, leaderEpochOffset)
-            truncate(tp, offsetTruncationState)
+      leaderEpochOffset.error match {
+        case Errors.NONE =>
+          val offsetTruncationState = getOffsetTruncationState(tp, leaderEpochOffset)
+          if (doTruncate(tp, offsetTruncationState))
             fetchOffsets.put(tp, offsetTruncationState)
 
-          case Errors.FENCED_LEADER_EPOCH =>
-            onPartitionFenced(tp)
+        case Errors.FENCED_LEADER_EPOCH =>
+          onPartitionFenced(tp)
 
-          case error =>
-            info(s"Retrying leaderEpoch request for partition $tp as the leader reported an error: $error")
-            partitionsWithError += tp
-        }
-      } catch {
-        case e: KafkaStorageException =>
-          info(s"Failed to truncate $tp", e)
+        case error =>
+          info(s"Retrying leaderEpoch request for partition $tp as the leader reported an error: $error")
           partitionsWithError += tp
       }
     }
@@ -267,7 +272,7 @@ abstract class AbstractFetcherThread(name: String,
       val currentLeaderEpoch = currentFetchState.currentLeaderEpoch
       info(s"Partition $tp has an older epoch ($currentLeaderEpoch) than the current leader. Will await " +
         s"the new LeaderAndIsr state before resuming fetching.")
-      partitionStates.remove(tp)
+      markPartitionFailed(tp)
     }
   }
 
@@ -336,11 +341,14 @@ abstract class AbstractFetcherThread(name: String,
                         s"offset ${currentFetchState.fetchOffset}", ime)
                       partitionsWithError += topicPartition
                     case e: KafkaStorageException =>
-                      error(s"Error while processing data for partition $topicPartition", e)
-                      partitionsWithError += topicPartition
-                    case e: Throwable =>
-                      throw new KafkaException(s"Error processing data for partition $topicPartition " +
-                        s"offset ${currentFetchState.fetchOffset}", e)
+                      error(s"Error while processing data for partition $topicPartition " +
+                        s"at offset ${currentFetchState.fetchOffset}", e)
+                      markPartitionFailed(topicPartition)
+                    case t: Throwable =>
+                      // stop monitoring this partition and add it to the set of failed partitions
+                      error(s"Unexpected error occurred while processing data for partition $topicPartition " +
+                        s"at offset ${currentFetchState.fetchOffset}", t)
+                      markPartitionFailed(topicPartition)
                   }
                 case Errors.OFFSET_OUT_OF_RANGE =>
                   if (!handleOutOfRangeError(topicPartition, currentFetchState))
@@ -387,6 +395,16 @@ abstract class AbstractFetcherThread(name: String,
     } finally partitionMapLock.unlock()
   }
 
+  private def markPartitionFailed(topicPartition: TopicPartition): Unit = {
+    partitionMapLock.lock()
+    try {
+      failedPartitions.add(topicPartition)
+      removePartitions(Set(topicPartition))
+    } finally partitionMapLock.unlock()
+    warn(s"Partition $topicPartition marked as failed")
+  }
+
+
   def addPartitions(initialFetchStates: Map[TopicPartition, OffsetAndEpoch]) {
     partitionMapLock.lockInterruptibly()
     try {
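
Note that two failure paths now coexist in the thread: transient leader-side errors (the
"case error =>" branch during truncation, and the existing partitionsWithError handling
in the fetch path) only delay the partition by fetchBackOffMs, whereas storage and
unexpected exceptions remove it via markPartitionFailed until the next leader change. A
condensed sketch of the split; classifyFailure and RetriableLeaderError are hypothetical
(in the real code transient errors arrive as Errors codes rather than exceptions, and
the dispatch is inlined in maybeTruncate and processFetchRequest above):

    import org.apache.kafka.common.TopicPartition
    import org.apache.kafka.common.errors.KafkaStorageException
    import scala.collection.mutable

    // Placeholder for transient leader errors (Errors codes in the real code).
    final class RetriableLeaderError extends RuntimeException

    def classifyFailure(tp: TopicPartition,
                        failure: Throwable,
                        partitionsWithError: mutable.Set[TopicPartition],
                        markPartitionFailed: TopicPartition => Unit): Unit =
      failure match {
        case _: RetriableLeaderError =>
          partitionsWithError += tp // retried after fetchBackOffMs via delayPartitions
        case _: KafkaStorageException =>
          markPartitionFailed(tp)   // dropped; cleared once the log dir goes offline
        case _ =>
          markPartitionFailed(tp)   // previously escalated to a thread-killing exception
      }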
diff --git a/core/src/main/scala/kafka/server/ReplicaAlterLogDirsManager.scala b/core/src/main/scala/kafka/server/ReplicaAlterLogDirsManager.scala
index 1616b84..d5ab4d6 100644
--- a/core/src/main/scala/kafka/server/ReplicaAlterLogDirsManager.scala
+++ b/core/src/main/scala/kafka/server/ReplicaAlterLogDirsManager.scala
@@ -30,7 +30,7 @@ class ReplicaAlterLogDirsManager(brokerConfig: KafkaConfig,
 
   override def createFetcherThread(fetcherId: Int, sourceBroker: BrokerEndPoint): ReplicaAlterLogDirsThread = {
     val threadName = s"ReplicaAlterLogDirsThread-$fetcherId"
-    new ReplicaAlterLogDirsThread(threadName, sourceBroker, brokerConfig, replicaManager,
+    new ReplicaAlterLogDirsThread(threadName, sourceBroker, brokerConfig, failedPartitions, replicaManager,
       quotaManager, brokerTopicStats)
   }
 
diff --git a/core/src/main/scala/kafka/server/ReplicaAlterLogDirsThread.scala b/core/src/main/scala/kafka/server/ReplicaAlterLogDirsThread.scala
index a1f0134..4312a92 100644
--- a/core/src/main/scala/kafka/server/ReplicaAlterLogDirsThread.scala
+++ b/core/src/main/scala/kafka/server/ReplicaAlterLogDirsThread.scala
@@ -39,12 +39,14 @@ import scala.collection.{Map, Seq, Set, mutable}
 class ReplicaAlterLogDirsThread(name: String,
                                 sourceBroker: BrokerEndPoint,
                                 brokerConfig: KafkaConfig,
+                                failedPartitions: FailedPartitions,
                                 replicaMgr: ReplicaManager,
                                 quota: ReplicationQuotaManager,
                                 brokerTopicStats: BrokerTopicStats)
   extends AbstractFetcherThread(name = name,
                                 clientId = name,
                                 sourceBroker = sourceBroker,
+                                failedPartitions,
                                 fetchBackOffMs = brokerConfig.replicaFetchBackoffMs,
                                 isInterruptible = false) {
 
diff --git a/core/src/main/scala/kafka/server/ReplicaFetcherManager.scala b/core/src/main/scala/kafka/server/ReplicaFetcherManager.scala
index fa902b9..3426290 100644
--- a/core/src/main/scala/kafka/server/ReplicaFetcherManager.scala
+++ b/core/src/main/scala/kafka/server/ReplicaFetcherManager.scala
@@ -35,7 +35,7 @@ class ReplicaFetcherManager(brokerConfig: KafkaConfig,
   override def createFetcherThread(fetcherId: Int, sourceBroker: BrokerEndPoint): ReplicaFetcherThread = {
     val prefix = threadNamePrefix.map(tp => s"$tp:").getOrElse("")
     val threadName = s"${prefix}ReplicaFetcherThread-$fetcherId-${sourceBroker.id}"
-    new ReplicaFetcherThread(threadName, fetcherId, sourceBroker, brokerConfig, replicaManager,
+    new ReplicaFetcherThread(threadName, fetcherId, sourceBroker, brokerConfig, failedPartitions, replicaManager,
       metrics, time, quotaManager)
   }
 
diff --git a/core/src/main/scala/kafka/server/ReplicaFetcherThread.scala b/core/src/main/scala/kafka/server/ReplicaFetcherThread.scala
index 3e22653..82f51e6 100644
--- a/core/src/main/scala/kafka/server/ReplicaFetcherThread.scala
+++ b/core/src/main/scala/kafka/server/ReplicaFetcherThread.scala
@@ -40,6 +40,7 @@ class ReplicaFetcherThread(name: String,
                            fetcherId: Int,
                            sourceBroker: BrokerEndPoint,
                            brokerConfig: KafkaConfig,
+                           failedPartitions: FailedPartitions,
                            replicaMgr: ReplicaManager,
                            metrics: Metrics,
                            time: Time,
@@ -48,6 +49,7 @@ class ReplicaFetcherThread(name: String,
   extends AbstractFetcherThread(name = name,
                                 clientId = name,
                                 sourceBroker = sourceBroker,
+                                failedPartitions,
                                 fetchBackOffMs = brokerConfig.replicaFetchBackoffMs,
                                 isInterruptible = false) {
 
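
Operationally, the new FailedPartitionsCount gauge is what surfaces these drops. Below is
a hedged sketch of reading it in-process over JMX; the exact ObjectName is an assumption
based on Kafka's usual kafka.server:type=<Manager>,name=<Metric>,clientId=<id> pattern
(clientId is "Replica" for the replica fetcher manager), and Yammer's JmxReporter exposes
gauges through a "Value" attribute:

    import java.lang.management.ManagementFactory
    import javax.management.ObjectName

    val mbs = ManagementFactory.getPlatformMBeanServer
    // Assumed name; verify against your broker's JMX tree.
    val gauge = new ObjectName(
      "kafka.server:type=ReplicaFetcherManager,name=FailedPartitionsCount,clientId=Replica")
    val failedCount = mbs.getAttribute(gauge, "Value").asInstanceOf[Int]
    println(s"failed partitions: $failedCount")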
diff --git a/core/src/test/scala/unit/kafka/server/AbstractFetcherManagerTest.scala b/core/src/test/scala/unit/kafka/server/AbstractFetcherManagerTest.scala
index 0a4d7c1..ec57a0f 100644
--- a/core/src/test/scala/unit/kafka/server/AbstractFetcherManagerTest.scala
+++ b/core/src/test/scala/unit/kafka/server/AbstractFetcherManagerTest.scala
@@ -16,14 +16,29 @@
  */
 package kafka.server
 
+import com.yammer.metrics.Metrics
+import com.yammer.metrics.core.Gauge
 import kafka.cluster.BrokerEndPoint
 import org.apache.kafka.common.TopicPartition
 import org.easymock.EasyMock
-import org.junit.Test
+import org.junit.{Before, Test}
 import org.junit.Assert._
 
+import scala.collection.JavaConverters._
+
 class AbstractFetcherManagerTest {
 
+  @Before
+  def cleanMetricRegistry(): Unit = {
+    for (metricName <- Metrics.defaultRegistry().allMetrics().keySet().asScala)
+      Metrics.defaultRegistry().removeMetric(metricName)
+  }
+
+  private def getMetricValue(name: String): Any = {
+    Metrics.defaultRegistry.allMetrics.asScala.filterKeys(_.getName == name).values.head.
+      asInstanceOf[Gauge[Int]].value()
+  }
+
   @Test
   def testAddAndRemovePartition(): Unit = {
     val fetcher: AbstractFetcherThread = EasyMock.mock(classOf[AbstractFetcherThread])
@@ -58,4 +73,28 @@ class AbstractFetcherManagerTest {
     EasyMock.verify(fetcher)
   }
 
+  @Test
+  def testMetricFailedPartitionCount(): Unit = {
+    val fetcher: AbstractFetcherThread = EasyMock.mock(classOf[AbstractFetcherThread])
+    val fetcherManager = new AbstractFetcherManager[AbstractFetcherThread]("fetcher-manager", "fetcher-manager", 2) {
+      override def createFetcherThread(fetcherId: Int, sourceBroker: BrokerEndPoint): AbstractFetcherThread = {
+        fetcher
+      }
+    }
+
+    val tp = new TopicPartition("topic", 0)
+    val metricName = "FailedPartitionsCount"
+
+    // initial value for failed partition count
+    assertEquals(0, getMetricValue(metricName))
+
+    // partition marked as failed increments the count for failed partitions
+    fetcherManager.failedPartitions.add(tp)
+    assertEquals(1, getMetricValue(metricName))
+
+    // removing fetcher for the partition would remove the partition from set of failed partitions and decrement the
+    // count for failed partitions
+    fetcherManager.removeFetcherForPartitions(Set(tp))
+    assertEquals(0, getMetricValue(metricName))
+  }
 }
diff --git a/core/src/test/scala/unit/kafka/server/AbstractFetcherThreadTest.scala b/core/src/test/scala/unit/kafka/server/AbstractFetcherThreadTest.scala
index 1fc079d..1d8249a 100644
--- a/core/src/test/scala/unit/kafka/server/AbstractFetcherThreadTest.scala
+++ b/core/src/test/scala/unit/kafka/server/AbstractFetcherThreadTest.scala
@@ -27,6 +27,7 @@ import kafka.log.LogAppendInfo
 import kafka.message.NoCompressionCodec
 import kafka.server.AbstractFetcherThread.ResultWithPartitions
 import kafka.utils.TestUtils
+import org.apache.kafka.common.KafkaException
 import org.apache.kafka.common.TopicPartition
 import org.apache.kafka.common.errors.{FencedLeaderEpochException, UnknownLeaderEpochException}
 import org.apache.kafka.common.protocol.{ApiKeys, Errors}
@@ -45,6 +46,10 @@ import scala.collection.mutable.ArrayBuffer
 
 class AbstractFetcherThreadTest {
 
+  private val partition1 = new TopicPartition("topic1", 0)
+  private val partition2 = new TopicPartition("topic2", 0)
+  private val failedPartitions = new FailedPartitions
+
   @Before
   def cleanMetricRegistry(): Unit = {
     for (metricName <- Metrics.defaultRegistry().allMetrics().keySet().asScala)
@@ -147,8 +152,9 @@ class AbstractFetcherThreadTest {
     assertEquals(0L, replicaState.logEndOffset)
     assertEquals(0L, replicaState.highWatermark)
 
-    // After fencing, the fetcher should remove the partition from tracking
+    // After fencing, the fetcher should remove the partition from tracking and mark as failed
     assertTrue(fetcher.fetchState(partition).isEmpty)
+    assertTrue(failedPartitions.contains(partition))
   }
 
   @Test
@@ -176,8 +182,9 @@ class AbstractFetcherThreadTest {
 
     fetcher.doWork()
 
-    // After fencing, the fetcher should remove the partition from tracking
+    // After fencing, the fetcher should remove the partition from tracking and mark as failed
     assertTrue(fetcher.fetchState(partition).isEmpty)
+    assertTrue(failedPartitions.contains(partition))
   }
 
   @Test
@@ -480,11 +487,12 @@ class AbstractFetcherThreadTest {
     val leaderState = MockFetcherThread.PartitionState(leaderLog, leaderEpoch = 4, highWatermark = 2L)
     fetcher.setLeaderState(partition, leaderState)
 
-    // After the out of range error, we get a fenced error and remove the partition
+    // After the out of range error, we get a fenced error and remove the partition and mark as failed
     fetcher.doWork()
     assertEquals(0, replicaState.logEndOffset)
     assertTrue(fetchedEarliestOffset)
     assertTrue(fetcher.fetchState(partition).isEmpty)
+    assertTrue(failedPartitions.contains(partition))
   }
 
   @Test
@@ -722,6 +730,67 @@ class AbstractFetcherThreadTest {
     }
   }
 
+  @Test
+  def testFetcherThreadHandlingPartitionFailureDuringAppending(): Unit = {
+    val fetcherForAppend = new MockFetcherThread {
+      override def processPartitionData(topicPartition: TopicPartition, fetchOffset: Long, partitionData: FetchData): Option[LogAppendInfo] = {
+        if (topicPartition == partition1) {
+          throw new KafkaException()
+        } else {
+          super.processPartitionData(topicPartition, fetchOffset, partitionData)
+        }
+      }
+    }
+    verifyFetcherThreadHandlingPartitionFailure(fetcherForAppend)
+  }
+
+  @Test
+  def testFetcherThreadHandlingPartitionFailureDuringTruncation(): Unit = {
+    val fetcherForTruncation = new MockFetcherThread {
+      override def truncate(topicPartition: TopicPartition, truncationState: OffsetTruncationState): Unit = {
+        if (topicPartition == partition1)
+          throw new Exception()
+        else {
+          super.truncate(topicPartition, truncationState)
+        }
+      }
+    }
+    verifyFetcherThreadHandlingPartitionFailure(fetcherForTruncation)
+  }
+
+  private def verifyFetcherThreadHandlingPartitionFailure(fetcher: MockFetcherThread): Unit = {
+
+    fetcher.setReplicaState(partition1, MockFetcherThread.PartitionState(leaderEpoch = 0))
+    fetcher.addPartitions(Map(partition1 -> offsetAndEpoch(0L, leaderEpoch = 0)))
+    fetcher.setLeaderState(partition1, MockFetcherThread.PartitionState(leaderEpoch = 0))
+
+    fetcher.setReplicaState(partition2, MockFetcherThread.PartitionState(leaderEpoch = 0))
+    fetcher.addPartitions(Map(partition2 -> offsetAndEpoch(0L, leaderEpoch = 0)))
+    fetcher.setLeaderState(partition2, MockFetcherThread.PartitionState(leaderEpoch = 0))
+
+    // processing data fails for partition1
+    fetcher.doWork()
+
+    // partition1 marked as failed
+    assertTrue(failedPartitions.contains(partition1))
+    assertEquals(None, fetcher.fetchState(partition1))
+
+    // make sure the fetcher continues to work with rest of the partitions
+    fetcher.doWork()
+    assertEquals(Some(Fetching), fetcher.fetchState(partition2).map(_.state))
+    assertFalse(failedPartitions.contains(partition2))
+
+    // simulate a leader change
+    fetcher.removePartitions(Set(partition1))
+    failedPartitions.removeAll(Set(partition1))
+    fetcher.addPartitions(Map(partition1 -> offsetAndEpoch(0L, leaderEpoch = 1)))
+
+    // partition1 added back
+    assertEquals(Some(Truncating), fetcher.fetchState(partition1).map(_.state))
+    assertFalse(failedPartitions.contains(partition1))
+
+  }
+
   object MockFetcherThread {
     class PartitionState(var log: mutable.Buffer[RecordBatch],
                          var leaderEpoch: Int,
@@ -745,7 +814,8 @@ class AbstractFetcherThreadTest {
   class MockFetcherThread(val replicaId: Int = 0, val leaderId: Int = 1)
     extends AbstractFetcherThread("mock-fetcher",
       clientId = "mock-fetcher",
-      sourceBroker = new BrokerEndPoint(leaderId, host = "localhost", port = Random.nextInt())) {
+      sourceBroker = new BrokerEndPoint(leaderId, host = "localhost", port = Random.nextInt()),
+      failedPartitions) {
 
     import MockFetcherThread.PartitionState
 
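
The two new tests in this file share a single failure-injection pattern: subclass
MockFetcherThread, make exactly one overridden code path throw for one partition, and
assert that only that partition ends up in failedPartitions while the other keeps
fetching. In outline (using partition1/partition2 declared at the top of the file, and
assuming the setReplicaState/addPartitions setup shown in
verifyFetcherThreadHandlingPartitionFailure has already run):

    val fetcher = new MockFetcherThread {
      override def truncate(topicPartition: TopicPartition, truncationState: OffsetTruncationState): Unit =
        if (topicPartition == partition1) throw new KafkaException("injected failure")
        else super.truncate(topicPartition, truncationState)
    }
    fetcher.doWork()
    assertTrue(failedPartitions.contains(partition1))
    assertFalse(failedPartitions.contains(partition2))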
diff --git a/core/src/test/scala/unit/kafka/server/ReplicaAlterLogDirsThreadTest.scala b/core/src/test/scala/unit/kafka/server/ReplicaAlterLogDirsThreadTest.scala
index 9710cf2..4049504 100644
--- a/core/src/test/scala/unit/kafka/server/ReplicaAlterLogDirsThreadTest.scala
+++ b/core/src/test/scala/unit/kafka/server/ReplicaAlterLogDirsThreadTest.scala
@@ -39,6 +39,7 @@ class ReplicaAlterLogDirsThreadTest {
 
   private val t1p0 = new TopicPartition("topic1", 0)
   private val t1p1 = new TopicPartition("topic1", 1)
+  private val failedPartitions = new FailedPartitions
 
   private def offsetAndEpoch(fetchOffset: Long, leaderEpoch: Int = 1): OffsetAndEpoch = {
     OffsetAndEpoch(offset = fetchOffset, leaderEpoch = leaderEpoch)
@@ -78,6 +79,7 @@ class ReplicaAlterLogDirsThreadTest {
       "alter-logs-dirs-thread-test1",
       sourceBroker = endPoint,
       brokerConfig = config,
+      failedPartitions = failedPartitions,
       replicaMgr = replicaManager,
       quota = null,
       brokerTopicStats = null)
@@ -123,6 +125,7 @@ class ReplicaAlterLogDirsThreadTest {
       "alter-logs-dirs-thread-test1",
       sourceBroker = endPoint,
       brokerConfig = config,
+      failedPartitions = failedPartitions,
       replicaMgr = replicaManager,
       quota = null,
       brokerTopicStats = null)
@@ -204,6 +207,7 @@ class ReplicaAlterLogDirsThreadTest {
       "alter-logs-dirs-thread-test1",
       sourceBroker = endPoint,
       brokerConfig = config,
+      failedPartitions = failedPartitions,
       replicaMgr = replicaManager,
       quota = quotaManager,
       brokerTopicStats = null)
@@ -275,6 +279,7 @@ class ReplicaAlterLogDirsThreadTest {
       "alter-logs-dirs-thread-test1",
       sourceBroker = endPoint,
       brokerConfig = config,
+      failedPartitions = failedPartitions,
       replicaMgr = replicaManager,
       quota = quotaManager,
       brokerTopicStats = null)
@@ -330,6 +335,7 @@ class ReplicaAlterLogDirsThreadTest {
       "alter-logs-dirs-thread-test1",
       sourceBroker = endPoint,
       brokerConfig = config,
+      failedPartitions = failedPartitions,
       replicaMgr = replicaManager,
       quota = quotaManager,
       brokerTopicStats = null)
@@ -408,6 +414,7 @@ class ReplicaAlterLogDirsThreadTest {
       "alter-logs-dirs-thread-test1",
       sourceBroker = endPoint,
       brokerConfig = config,
+      failedPartitions = failedPartitions,
       replicaMgr = replicaManager,
       quota = quotaManager,
       brokerTopicStats = null)
@@ -467,6 +474,7 @@ class ReplicaAlterLogDirsThreadTest {
       "alter-logs-dirs-thread-test1",
       sourceBroker = endPoint,
       brokerConfig = config,
+      failedPartitions = failedPartitions,
       replicaMgr = replicaManager,
       quota = quotaManager,
       brokerTopicStats = null)
@@ -507,6 +515,7 @@ class ReplicaAlterLogDirsThreadTest {
       "alter-logs-dirs-thread-test1",
       sourceBroker = endPoint,
       brokerConfig = config,
+      failedPartitions = failedPartitions,
       replicaMgr = replicaManager,
       quota = quotaManager,
       brokerTopicStats = null)
@@ -556,6 +565,7 @@ class ReplicaAlterLogDirsThreadTest {
       "alter-logs-dirs-thread-test1",
       sourceBroker = endPoint,
       brokerConfig = config,
+      failedPartitions = failedPartitions,
       replicaMgr = replicaManager,
       quota = quotaManager,
       brokerTopicStats = null)
diff --git a/core/src/test/scala/unit/kafka/server/ReplicaFetcherThreadTest.scala b/core/src/test/scala/unit/kafka/server/ReplicaFetcherThreadTest.scala
index 2a89a29..3d7e86a 100644
--- a/core/src/test/scala/unit/kafka/server/ReplicaFetcherThreadTest.scala
+++ b/core/src/test/scala/unit/kafka/server/ReplicaFetcherThreadTest.scala
@@ -45,6 +45,7 @@ class ReplicaFetcherThreadTest {
   private val t2p1 = new TopicPartition("topic2", 1)
 
   private val brokerEndPoint = new BrokerEndPoint(0, "localhost", 1000)
+  private val failedPartitions = new FailedPartitions
 
   private def offsetAndEpoch(fetchOffset: Long, leaderEpoch: Int = 1): OffsetAndEpoch = {
     OffsetAndEpoch(offset = fetchOffset, leaderEpoch = leaderEpoch)
@@ -59,6 +60,7 @@ class ReplicaFetcherThreadTest {
       fetcherId = 0,
       sourceBroker = brokerEndPoint,
       brokerConfig = config,
+      failedPartitions = failedPartitions,
       replicaMgr = null,
       metrics =  new Metrics(),
       time = new SystemTime(),
@@ -107,7 +109,7 @@ class ReplicaFetcherThreadTest {
     //Create the fetcher thread
     val mockNetwork = new ReplicaFetcherMockBlockingSend(offsets, brokerEndPoint, new SystemTime())
 
-    val thread = new ReplicaFetcherThread("bob", 0, brokerEndPoint, config, replicaManager, new Metrics(), new SystemTime(), quota, Some(mockNetwork))
+    val thread = new ReplicaFetcherThread("bob", 0, brokerEndPoint, config, failedPartitions, replicaManager, new Metrics(), new SystemTime(), quota, Some(mockNetwork))
 
     // topic 1 supports epoch, t2 doesn't
     thread.addPartitions(Map(
@@ -181,6 +183,7 @@ class ReplicaFetcherThreadTest {
       fetcherId = 0,
       sourceBroker = brokerEndPoint,
       brokerConfig = config,
+      failedPartitions = failedPartitions,
       replicaMgr = null,
       metrics =  new Metrics(),
       time = new SystemTime(),
@@ -233,7 +236,7 @@ class ReplicaFetcherThreadTest {
 
     //Create the fetcher thread
     val mockNetwork = new ReplicaFetcherMockBlockingSend(offsets, brokerEndPoint, new SystemTime())
-    val thread = new ReplicaFetcherThread("bob", 0, brokerEndPoint, config, replicaManager,
+    val thread = new ReplicaFetcherThread("bob", 0, brokerEndPoint, config, failedPartitions, replicaManager,
       new Metrics, new SystemTime, UnboundedQuota, Some(mockNetwork))
     thread.addPartitions(Map(t1p0 -> offsetAndEpoch(0L), t1p1 -> offsetAndEpoch(0L)))
 
@@ -292,7 +295,7 @@ class ReplicaFetcherThreadTest {
 
     //Create the thread
     val mockNetwork = new ReplicaFetcherMockBlockingSend(offsetsReply, brokerEndPoint, new SystemTime())
-    val thread = new ReplicaFetcherThread("bob", 0, brokerEndPoint, configs(0), replicaManager, new Metrics(), new SystemTime(), quota, Some(mockNetwork))
+    val thread = new ReplicaFetcherThread("bob", 0, brokerEndPoint, configs(0), failedPartitions, replicaManager, new Metrics(), new SystemTime(), quota, Some(mockNetwork))
     thread.addPartitions(Map(t1p0 -> offsetAndEpoch(0L), t2p1 -> offsetAndEpoch(0L)))
 
     //Run it
@@ -341,7 +344,7 @@ class ReplicaFetcherThreadTest {
 
     //Create the thread
     val mockNetwork = new ReplicaFetcherMockBlockingSend(offsetsReply, brokerEndPoint, new SystemTime())
-    val thread = new ReplicaFetcherThread("bob", 0, brokerEndPoint, configs(0), replicaManager, new Metrics(), new SystemTime(), quota, Some(mockNetwork))
+    val thread = new ReplicaFetcherThread("bob", 0, brokerEndPoint, configs(0), failedPartitions, replicaManager, new Metrics(), new SystemTime(), quota, Some(mockNetwork))
     thread.addPartitions(Map(t1p0 -> offsetAndEpoch(0L), t2p1 -> offsetAndEpoch(0L)))
 
     //Run it
@@ -393,7 +396,7 @@ class ReplicaFetcherThreadTest {
 
     // Create the fetcher thread
     val mockNetwork = new ReplicaFetcherMockBlockingSend(offsets, brokerEndPoint, new SystemTime())
-    val thread = new ReplicaFetcherThread("bob", 0, brokerEndPoint, config, replicaManager, new Metrics(), new SystemTime(), quota, Some(mockNetwork))
+    val thread = new ReplicaFetcherThread("bob", 0, brokerEndPoint, config, failedPartitions, replicaManager, new Metrics(), new SystemTime(), quota, Some(mockNetwork))
     thread.addPartitions(Map(t1p0 -> offsetAndEpoch(0L), t1p1 -> offsetAndEpoch(0L)))
 
     // Loop 1 -- both topic partitions will need to fetch another leader epoch
@@ -465,7 +468,7 @@ class ReplicaFetcherThreadTest {
 
     // Create the fetcher thread
     val mockNetwork = new ReplicaFetcherMockBlockingSend(offsets, brokerEndPoint, new SystemTime())
-    val thread = new ReplicaFetcherThread("bob", 0, brokerEndPoint, config, replicaManager, new Metrics(), new SystemTime(), quota, Some(mockNetwork))
+    val thread = new ReplicaFetcherThread("bob", 0, brokerEndPoint, config, failedPartitions, replicaManager, new Metrics(), new SystemTime(), quota, Some(mockNetwork))
     thread.addPartitions(Map(t1p0 -> offsetAndEpoch(0L), t1p1 -> offsetAndEpoch(0L)))
 
     // Loop 1 -- both topic partitions will truncate to leader offset even though they don't know
@@ -521,7 +524,7 @@ class ReplicaFetcherThreadTest {
 
     //Create the thread
     val mockNetwork = new ReplicaFetcherMockBlockingSend(offsetsReply, brokerEndPoint, new SystemTime())
-    val thread = new ReplicaFetcherThread("bob", 0, brokerEndPoint, configs(0), replicaManager, new Metrics(), new SystemTime(), quota, Some(mockNetwork))
+    val thread = new ReplicaFetcherThread("bob", 0, brokerEndPoint, configs(0), failedPartitions, replicaManager, new Metrics(), new SystemTime(), quota, Some(mockNetwork))
     thread.addPartitions(Map(t1p0 -> offsetAndEpoch(initialFetchOffset)))
 
     //Run it
@@ -571,7 +574,7 @@ class ReplicaFetcherThreadTest {
 
     //Create the thread
     val mockNetwork = new ReplicaFetcherMockBlockingSend(offsetsReply, brokerEndPoint, new SystemTime())
-    val thread = new ReplicaFetcherThread("bob", 0, brokerEndPoint, configs(0), replicaManager, new Metrics(), new SystemTime(), quota, Some(mockNetwork))
+    val thread = new ReplicaFetcherThread("bob", 0, brokerEndPoint, configs(0), failedPartitions, replicaManager, new Metrics(), new SystemTime(), quota, Some(mockNetwork))
     thread.addPartitions(Map(t1p0 -> offsetAndEpoch(0L), t1p1 -> offsetAndEpoch(0L)))
 
     //Run thread 3 times
@@ -624,7 +627,7 @@ class ReplicaFetcherThreadTest {
 
     //Create the fetcher thread
     val mockNetwork = new ReplicaFetcherMockBlockingSend(offsetsReply, brokerEndPoint, new SystemTime())
-    val thread = new ReplicaFetcherThread("bob", 0, brokerEndPoint, config, replicaManager, new Metrics(), new SystemTime(), quota, Some(mockNetwork))
+    val thread = new ReplicaFetcherThread("bob", 0, brokerEndPoint, config, failedPartitions, replicaManager, new Metrics(), new SystemTime(), quota, Some(mockNetwork))
 
     //When
     thread.addPartitions(Map(t1p0 -> offsetAndEpoch(0L), t1p1 -> offsetAndEpoch(0L)))
@@ -674,7 +677,7 @@ class ReplicaFetcherThreadTest {
 
     //Create the fetcher thread
     val mockNetwork = new ReplicaFetcherMockBlockingSend(offsetsReply, brokerEndPoint, new SystemTime())
-    val thread = new ReplicaFetcherThread("bob", 0, brokerEndPoint, config, replicaManager, new Metrics(), new SystemTime(), quota, Some(mockNetwork))
+    val thread = new ReplicaFetcherThread("bob", 0, brokerEndPoint, config, failedPartitions, replicaManager, new Metrics(), new SystemTime(), quota, Some(mockNetwork))
 
     //When
     thread.addPartitions(Map(t1p0 -> offsetAndEpoch(0L), t1p1 -> offsetAndEpoch(0L)))
@@ -707,6 +710,7 @@ class ReplicaFetcherThreadTest {
       fetcherId = 0,
       sourceBroker = brokerEndPoint,
       brokerConfig = config,
+      failedPartitions = failedPartitions,
       replicaMgr = null,
       metrics =  new Metrics(),
       time = new SystemTime(),
diff --git a/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala b/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala
index 65b62d0..1c1cbd6 100644
--- a/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala
+++ b/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala
@@ -676,7 +676,7 @@ class ReplicaManagerTest {
 
           override def createFetcherThread(fetcherId: Int, sourceBroker: BrokerEndPoint): ReplicaFetcherThread = {
             new ReplicaFetcherThread(s"ReplicaFetcherThread-$fetcherId", fetcherId,
-              sourceBroker, config, replicaManager, metrics, time, quota.follower, Some(blockingSend)) {
+              sourceBroker, config, failedPartitions, replicaManager, metrics, time, quota.follower, Some(blockingSend)) {
 
               override def doWork() = {
                 // In case the thread starts before the partition is added by AbstractFetcherManager,