Posted to commits@kafka.apache.org by jg...@apache.org on 2019/04/25 05:23:37 UTC

[kafka] branch trunk updated: KAFKA-8237; Untangle TopicDeleteManager and add test cases (#6588)

This is an automated email from the ASF dual-hosted git repository.

jgus pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 17c8016  KAFKA-8237; Untangle TopicDeleteManager and add test cases (#6588)
17c8016 is described below

commit 17c80166461c0005b7603414db4d0a3541df0f82
Author: Jason Gustafson <ja...@confluent.io>
AuthorDate: Wed Apr 24 22:23:15 2019 -0700

    KAFKA-8237; Untangle TopicDeleteManager and add test cases (#6588)
    
    The controller maintains state across `ControllerContext`, `PartitionStateMachine`, `ReplicaStateMachine`, and `TopicDeletionManager`. None of this state is actually isolated from the rest. For example, topics undergoing deletion are intertwined with the partition and replica states. As a consequence, each of these components tends to depend on all the others, which makes testing and reasoning about the system difficult. This is a first step toward untangling all the state [...]
    
    Additionally, this patch adds mock objects to make testing easier: `MockReplicaStateMachine` and `MockPartitionStateMachine`. These use simplified logic for updating the current state and are used to build new test cases for `TopicDeletionManager`.
    
    Reviewers: José Armando García Sancio <js...@users.noreply.github.com>, Jun Rao <ju...@gmail.com>
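
For illustration, a minimal sketch (not part of this commit) of how the consolidated state in `ControllerContext` can now be exercised directly. It uses only methods that appear in the diff below; the topic name, assignment, and broker ids are made up:

    import kafka.controller._
    import org.apache.kafka.common.TopicPartition

    object ControllerContextSketch extends App {
      val context = new ControllerContext
      val tp = new TopicPartition("foo", 0)

      // Seed the assignment and state roughly the way the controller would on startup.
      context.allTopics = Set("foo")
      context.updatePartitionReplicaAssignment(tp, Seq(1, 2, 3))
      context.putPartitionState(tp, OnlinePartition)
      Seq(1, 2, 3).foreach { id =>
        context.putReplicaState(PartitionAndReplica(tp, id), OnlineReplica)
      }

      // Queue the topic for deletion, then mark deletion as actually started.
      context.topicsToBeDeleted += "foo"
      assert(context.isTopicEligibleForDeletion("foo"))
      context.beginTopicDeletion(Set("foo"))
      assert(context.isTopicDeletionInProgress("foo"))

      // Once deletion has started, offline transitions no longer move the metric.
      context.putPartitionState(tp, OfflinePartition)
      assert(context.offlinePartitionCount == 0)
    }

Because this bookkeeping now lives in one place, the new tests can drive it directly instead of standing up a full controller.
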
---
 .../controller/ControllerChannelManager.scala      |  21 +-
 .../scala/kafka/controller/ControllerContext.scala | 207 ++++++++++++++----
 .../src/main/scala/kafka/controller/Election.scala | 152 +++++++++++++
 .../scala/kafka/controller/KafkaController.scala   |  78 ++++---
 .../kafka/controller/PartitionStateMachine.scala   | 242 ++++++++-------------
 .../kafka/controller/ReplicaStateMachine.scala     | 186 ++++++++--------
 .../kafka/controller/TopicDeletionManager.scala    | 193 ++++++++--------
 .../scala/unit/kafka/admin/DeleteTopicTest.scala   |   4 +-
 .../kafka/controller/ControllerFailoverTest.scala  |   2 +-
 .../controller/MockPartitionStateMachine.scala     | 110 ++++++++++
 .../kafka/controller/MockReplicaStateMachine.scala |  36 +++
 .../controller/PartitionStateMachineTest.scala     | 118 ++++------
 .../kafka/controller/ReplicaStateMachineTest.scala |  61 +++---
 .../controller/TopicDeletionManagerTest.scala      | 232 ++++++++++++++++++++
 .../unit/kafka/server/LogDirFailureTest.scala      |   2 +-
 15 files changed, 1093 insertions(+), 551 deletions(-)
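
The diffstat above lists the new MockPartitionStateMachine, MockReplicaStateMachine, and TopicDeletionManagerTest files. Below is a rough, hypothetical sketch of how they could be wired together; only the five-argument `TopicDeletionManager` constructor is taken from the diff, while the mock constructors, the `DeletionClient` type, and the enqueue API are assumptions:

    import kafka.controller._
    import kafka.server.KafkaConfig

    class TopicDeletionManagerWiringSketch {
      // Everything marked "assumed" is an illustrative guess, not code from the patch.
      val config: KafkaConfig = ???                                      // supplied by the test harness
      val context = new ControllerContext                                // seeded with topics/replicas as needed
      val replicaStateMachine = new MockReplicaStateMachine(context)     // assumed constructor
      val partitionStateMachine = new MockPartitionStateMachine(context) // assumed constructor
      val deletionClient: DeletionClient = ???                           // assumed trait; stubbed in tests

      // This wiring mirrors how KafkaController constructs the manager in this patch.
      val deletionManager = new TopicDeletionManager(config, context,
        replicaStateMachine, partitionStateMachine, deletionClient)

      // A test would then enqueue a topic for deletion (API name not shown here)
      // and assert on context.topicsToBeDeleted and replica states, with no ZooKeeper involved.
    }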

diff --git a/core/src/main/scala/kafka/controller/ControllerChannelManager.scala b/core/src/main/scala/kafka/controller/ControllerChannelManager.scala
index 3776b69..ca6c00a 100755
--- a/core/src/main/scala/kafka/controller/ControllerChannelManager.scala
+++ b/core/src/main/scala/kafka/controller/ControllerChannelManager.scala
@@ -61,7 +61,7 @@ class ControllerChannelManager(controllerContext: ControllerContext, config: Kaf
     }
   )
 
-  controllerContext.liveBrokers.foreach(addNewBroker)
+  controllerContext.liveOrShuttingDownBrokers.foreach(addNewBroker)
 
   def startup() = {
     brokerLock synchronized {
@@ -351,13 +351,23 @@ class ControllerBrokerRequestBatch(controller: KafkaController, stateChangeLogge
     addUpdateMetadataRequestForBrokers(controllerContext.liveOrShuttingDownBrokerIds.toSeq, Set(topicPartition))
   }
 
-  def addStopReplicaRequestForBrokers(brokerIds: Seq[Int], topicPartition: TopicPartition, deletePartition: Boolean,
-                                      callback: (AbstractResponse, Int) => Unit) {
+  def addStopReplicaRequestForBrokers(brokerIds: Seq[Int],
+                                      topicPartition: TopicPartition,
+                                      deletePartition: Boolean): Unit = {
     brokerIds.filter(_ >= 0).foreach { brokerId =>
+      def topicDeletionCallback(stopReplicaResponse: AbstractResponse): Unit = {
+        controller.eventManager.put(controller.TopicDeletionStopReplicaResponseReceived(stopReplicaResponse, brokerId))
+      }
+
+      val responseReceivedCallback = if (deletePartition && controllerContext.isTopicDeletionInProgress(topicPartition.topic))
+        topicDeletionCallback _
+      else
+        null
+
       stopReplicaRequestMap.getOrElseUpdate(brokerId, Seq.empty[StopReplicaRequestInfo])
       val v = stopReplicaRequestMap(brokerId)
       stopReplicaRequestMap(brokerId) = v :+ StopReplicaRequestInfo(PartitionAndReplica(topicPartition, brokerId),
-        deletePartition, (r: AbstractResponse) => callback(r, brokerId))
+        deletePartition, responseReceivedCallback)
     }
   }
 
@@ -394,7 +404,7 @@ class ControllerBrokerRequestBatch(controller: KafkaController, stateChangeLogge
 
     updateMetadataRequestBrokerSet ++= brokerIds.filter(_ >= 0)
     partitions.foreach(partition => updateMetadataRequestPartitionInfo(partition,
-      beingDeleted = controller.topicDeletionManager.topicsToBeDeleted.contains(partition.topic)))
+      beingDeleted = controllerContext.topicsToBeDeleted.contains(partition.topic)))
   }
 
   def sendRequestsToBrokers(controllerEpoch: Int) {
@@ -525,4 +535,3 @@ case class ControllerBrokerStateInfo(networkClient: NetworkClient,
 
 case class StopReplicaRequestInfo(replica: PartitionAndReplica, deletePartition: Boolean, callback: AbstractResponse => Unit)
 
-class Callbacks(val stopReplicaResponseCallback: (AbstractResponse, Int) => Unit = (_, _ ) => ())
diff --git a/core/src/main/scala/kafka/controller/ControllerContext.scala b/core/src/main/scala/kafka/controller/ControllerContext.scala
index c3bcc52..3069024 100644
--- a/core/src/main/scala/kafka/controller/ControllerContext.scala
+++ b/core/src/main/scala/kafka/controller/ControllerContext.scala
@@ -24,47 +24,75 @@ import scala.collection.{Seq, Set, mutable}
 
 class ControllerContext {
   val stats = new ControllerStats
-
-  var controllerChannelManager: ControllerChannelManager = null
-
+  var offlinePartitionCount = 0
   var shuttingDownBrokerIds: mutable.Set[Int] = mutable.Set.empty
+  private var liveBrokers: Set[Broker] = Set.empty
+  private var liveBrokerEpochs: Map[Int, Long] = Map.empty
   var epoch: Int = KafkaController.InitialControllerEpoch
   var epochZkVersion: Int = KafkaController.InitialControllerEpochZkVersion
+
   var allTopics: Set[String] = Set.empty
-  private val partitionReplicaAssignmentUnderlying: mutable.Map[String, mutable.Map[Int, Seq[Int]]] = mutable.Map.empty
-  val partitionLeadershipInfo: mutable.Map[TopicPartition, LeaderIsrAndControllerEpoch] = mutable.Map.empty
-  val partitionsBeingReassigned: mutable.Map[TopicPartition, ReassignedPartitionsContext] = mutable.Map.empty
+  val partitionAssignments = mutable.Map.empty[String, mutable.Map[Int, Seq[Int]]]
+  val partitionLeadershipInfo = mutable.Map.empty[TopicPartition, LeaderIsrAndControllerEpoch]
+  val partitionsBeingReassigned = mutable.Map.empty[TopicPartition, ReassignedPartitionsContext]
+  val partitionStates = mutable.Map.empty[TopicPartition, PartitionState]
+  val replicaStates = mutable.Map.empty[PartitionAndReplica, ReplicaState]
   val replicasOnOfflineDirs: mutable.Map[Int, Set[TopicPartition]] = mutable.Map.empty
 
-  private var liveBrokersUnderlying: Set[Broker] = Set.empty
-  private var liveBrokerIdAndEpochsUnderlying: Map[Int, Long] = Map.empty
+  val topicsToBeDeleted = mutable.Set.empty[String]
+
+  /** The following topicsWithDeletionStarted variable is used to properly update the offlinePartitionCount metric.
+   * When a topic is going through deletion, we don't want to keep track of its partition state
+   * changes in the offlinePartitionCount metric. This goal means if some partitions of a topic are already
+   * in OfflinePartition state when deletion starts, we need to change the corresponding partition
+   * states to NonExistentPartition first before starting the deletion.
+   *
+   * However we can NOT change partition states to NonExistentPartition at the time of enqueuing topics
+   * for deletion. The reason is that when a topic is enqueued for deletion, it may be ineligible for
+   * deletion due to ongoing partition reassignments. Hence there might be a delay between enqueuing
+   * a topic for deletion and the actual start of deletion. In this delayed interval, partitions may still
+   * transition to or out of the OfflinePartition state.
+   *
+   * Hence we decide to change partition states to NonExistentPartition only when the actual deletion has started.
+   * For topics whose deletion has actually started, we keep track of them in the following topicsWithDeletionStarted
+   * variable. And once a topic is in the topicsWithDeletionStarted set, we are sure there will no longer
+   * be partition reassignments to any of its partitions, and only then it's safe to move its partitions to
+   * NonExistentPartition state. Once a topic is in the topicsWithDeletionStarted set, we will stop monitoring
+   * its partition state changes in the offlinePartitionCount metric
+   */
+  val topicsWithDeletionStarted = mutable.Set.empty[String]
+  val topicsIneligibleForDeletion = mutable.Set.empty[String]
+
 
   def partitionReplicaAssignment(topicPartition: TopicPartition): Seq[Int] = {
-    partitionReplicaAssignmentUnderlying.getOrElse(topicPartition.topic, mutable.Map.empty)
+    partitionAssignments.getOrElse(topicPartition.topic, mutable.Map.empty)
       .getOrElse(topicPartition.partition, Seq.empty)
   }
 
   private def clearTopicsState(): Unit = {
     allTopics = Set.empty
-    partitionReplicaAssignmentUnderlying.clear()
+    partitionAssignments.clear()
     partitionLeadershipInfo.clear()
     partitionsBeingReassigned.clear()
     replicasOnOfflineDirs.clear()
+    partitionStates.clear()
+    offlinePartitionCount = 0
+    replicaStates.clear()
   }
 
   def updatePartitionReplicaAssignment(topicPartition: TopicPartition, newReplicas: Seq[Int]): Unit = {
-    partitionReplicaAssignmentUnderlying.getOrElseUpdate(topicPartition.topic, mutable.Map.empty)
+    partitionAssignments.getOrElseUpdate(topicPartition.topic, mutable.Map.empty)
       .put(topicPartition.partition, newReplicas)
   }
 
   def partitionReplicaAssignmentForTopic(topic : String): Map[TopicPartition, Seq[Int]] = {
-    partitionReplicaAssignmentUnderlying.getOrElse(topic, Map.empty).map {
+    partitionAssignments.getOrElse(topic, Map.empty).map {
       case (partition, replicas) => (new TopicPartition(topic, partition), replicas)
     }.toMap
   }
 
   def allPartitions: Set[TopicPartition] = {
-    partitionReplicaAssignmentUnderlying.flatMap {
+    partitionAssignments.flatMap {
       case (topic, topicReplicaAssignment) => topicReplicaAssignment.map {
         case (partition, _) => new TopicPartition(topic, partition)
       }
@@ -72,37 +100,36 @@ class ControllerContext {
   }
 
   def setLiveBrokerAndEpochs(brokerAndEpochs: Map[Broker, Long]) {
-    liveBrokersUnderlying = brokerAndEpochs.keySet
-    liveBrokerIdAndEpochsUnderlying =
+    liveBrokers = brokerAndEpochs.keySet
+    liveBrokerEpochs =
       brokerAndEpochs map { case (broker, brokerEpoch) => (broker.id, brokerEpoch)}
   }
 
   def addLiveBrokersAndEpochs(brokerAndEpochs: Map[Broker, Long]): Unit = {
-    liveBrokersUnderlying = liveBrokersUnderlying ++ brokerAndEpochs.keySet
-    liveBrokerIdAndEpochsUnderlying = liveBrokerIdAndEpochsUnderlying ++
+    liveBrokers = liveBrokers ++ brokerAndEpochs.keySet
+    liveBrokerEpochs = liveBrokerEpochs ++
       (brokerAndEpochs map { case (broker, brokerEpoch) => (broker.id, brokerEpoch)})
   }
 
-  def removeLiveBrokersAndEpochs(brokerIds : Set[Int]): Unit = {
-    liveBrokersUnderlying = liveBrokersUnderlying.filter(broker => !brokerIds.contains(broker.id))
-    liveBrokerIdAndEpochsUnderlying = liveBrokerIdAndEpochsUnderlying.filterKeys(id => !brokerIds.contains(id))
+  def removeLiveBrokers(brokerIds: Set[Int]): Unit = {
+    liveBrokers = liveBrokers.filter(broker => !brokerIds.contains(broker.id))
+    liveBrokerEpochs = liveBrokerEpochs.filterKeys(id => !brokerIds.contains(id))
   }
 
-  def updateBrokerMetadata(oldMetadata: Option[Broker], newMetadata: Option[Broker]): Unit = {
-    liveBrokersUnderlying = liveBrokersUnderlying -- oldMetadata ++ newMetadata
+  def updateBrokerMetadata(oldMetadata: Broker, newMetadata: Broker): Unit = {
+    liveBrokers -= oldMetadata
+    liveBrokers += newMetadata
   }
 
   // getter
-  def liveBrokers = liveBrokersUnderlying.filter(broker => !shuttingDownBrokerIds.contains(broker.id))
-  def liveBrokerIds = liveBrokerIdAndEpochsUnderlying.keySet -- shuttingDownBrokerIds
-
-  def liveOrShuttingDownBrokerIds = liveBrokerIdAndEpochsUnderlying.keySet
-  def liveOrShuttingDownBrokers = liveBrokersUnderlying
-
-  def liveBrokerIdAndEpochs = liveBrokerIdAndEpochsUnderlying
+  def liveBrokerIds: Set[Int] = liveBrokerEpochs.keySet -- shuttingDownBrokerIds
+  def liveOrShuttingDownBrokerIds: Set[Int] = liveBrokerEpochs.keySet
+  def liveOrShuttingDownBrokers: Set[Broker] = liveBrokers
+  def liveBrokerIdAndEpochs: Map[Int, Long] = liveBrokerEpochs
+  def liveOrShuttingDownBroker(brokerId: Int): Option[Broker] = liveOrShuttingDownBrokers.find(_.id == brokerId)
 
   def partitionsOnBroker(brokerId: Int): Set[TopicPartition] = {
-    partitionReplicaAssignmentUnderlying.flatMap {
+    partitionAssignments.flatMap {
       case (topic, topicReplicaAssignment) => topicReplicaAssignment.filter {
         case (_, replicas) => replicas.contains(brokerId)
       }.map {
@@ -121,7 +148,7 @@ class ControllerContext {
 
   def replicasOnBrokers(brokerIds: Set[Int]): Set[PartitionAndReplica] = {
     brokerIds.flatMap { brokerId =>
-      partitionReplicaAssignmentUnderlying.flatMap {
+      partitionAssignments.flatMap {
         case (topic, topicReplicaAssignment) => topicReplicaAssignment.collect {
           case (partition, replicas)  if replicas.contains(brokerId) =>
             PartitionAndReplica(new TopicPartition(topic, partition), brokerId)
@@ -131,13 +158,13 @@ class ControllerContext {
   }
 
   def replicasForTopic(topic: String): Set[PartitionAndReplica] = {
-    partitionReplicaAssignmentUnderlying.getOrElse(topic, mutable.Map.empty).flatMap {
+    partitionAssignments.getOrElse(topic, mutable.Map.empty).flatMap {
       case (partition, replicas) => replicas.map(r => PartitionAndReplica(new TopicPartition(topic, partition), r))
     }.toSet
   }
 
   def partitionsForTopic(topic: String): collection.Set[TopicPartition] = {
-    partitionReplicaAssignmentUnderlying.getOrElse(topic, mutable.Map.empty).map {
+    partitionAssignments.getOrElse(topic, mutable.Map.empty).map {
       case (partition, _) => new TopicPartition(topic, partition)
     }.toSet
   }
@@ -156,10 +183,9 @@ class ControllerContext {
   }
 
   def resetContext(): Unit = {
-    if (controllerChannelManager != null) {
-      controllerChannelManager.shutdown()
-      controllerChannelManager = null
-    }
+    topicsToBeDeleted.clear()
+    topicsWithDeletionStarted.clear()
+    topicsIneligibleForDeletion.clear()
     shuttingDownBrokerIds.clear()
     epoch = 0
     epochZkVersion = 0
@@ -169,10 +195,115 @@ class ControllerContext {
 
   def removeTopic(topic: String): Unit = {
     allTopics -= topic
-    partitionReplicaAssignmentUnderlying.remove(topic)
+    partitionAssignments.remove(topic)
     partitionLeadershipInfo.foreach {
       case (topicPartition, _) if topicPartition.topic == topic => partitionLeadershipInfo.remove(topicPartition)
       case _ =>
     }
   }
+
+  def beginTopicDeletion(topics: Set[String]): Unit = {
+    topicsWithDeletionStarted ++= topics
+  }
+
+  def isTopicDeletionInProgress(topic: String): Boolean = {
+    topicsWithDeletionStarted.contains(topic)
+  }
+
+  def isTopicQueuedUpForDeletion(topic: String): Boolean = {
+    topicsToBeDeleted.contains(topic)
+  }
+
+  def isTopicEligibleForDeletion(topic: String): Boolean = {
+    topicsToBeDeleted.contains(topic) && !topicsIneligibleForDeletion.contains(topic)
+  }
+
+  def topicsQueuedForDeletion: Set[String] = {
+    topicsToBeDeleted
+  }
+
+  def replicasInState(topic: String, state: ReplicaState): Set[PartitionAndReplica] = {
+    replicasForTopic(topic).filter(replica => replicaStates(replica) == state).toSet
+  }
+
+  def areAllReplicasInState(topic: String, state: ReplicaState): Boolean = {
+    replicasForTopic(topic).forall(replica => replicaStates(replica) == state)
+  }
+
+  def isAnyReplicaInState(topic: String, state: ReplicaState): Boolean = {
+    replicasForTopic(topic).exists(replica => replicaStates(replica) == state)
+  }
+
+  def checkValidReplicaStateChange(replicas: Seq[PartitionAndReplica], targetState: ReplicaState): (Seq[PartitionAndReplica], Seq[PartitionAndReplica]) = {
+    replicas.partition(replica => isValidReplicaStateTransition(replica, targetState))
+  }
+
+  def checkValidPartitionStateChange(partitions: Seq[TopicPartition], targetState: PartitionState): (Seq[TopicPartition], Seq[TopicPartition]) = {
+    partitions.partition(p => isValidPartitionStateTransition(p, targetState))
+  }
+
+  def putReplicaState(replica: PartitionAndReplica, state: ReplicaState): Unit = {
+    replicaStates.put(replica, state)
+  }
+
+  def removeReplicaState(replica: PartitionAndReplica): Unit = {
+    replicaStates.remove(replica)
+  }
+
+  def putReplicaStateIfNotExists(replica: PartitionAndReplica, state: ReplicaState): Unit = {
+    replicaStates.getOrElseUpdate(replica, state)
+  }
+
+  def putPartitionState(partition: TopicPartition, targetState: PartitionState): Unit = {
+    val currentState = partitionStates.put(partition, targetState).getOrElse(NonExistentPartition)
+    updatePartitionStateMetrics(partition, currentState, targetState)
+  }
+
+  private def updatePartitionStateMetrics(partition: TopicPartition,
+                                          currentState: PartitionState,
+                                          targetState: PartitionState): Unit = {
+    if (!isTopicDeletionInProgress(partition.topic)) {
+      if (currentState != OfflinePartition && targetState == OfflinePartition) {
+        offlinePartitionCount = offlinePartitionCount + 1
+      } else if (currentState == OfflinePartition && targetState != OfflinePartition) {
+        offlinePartitionCount = offlinePartitionCount - 1
+      }
+    }
+  }
+
+  def putPartitionStateIfNotExists(partition: TopicPartition, state: PartitionState): Unit = {
+    if (partitionStates.getOrElseUpdate(partition, state) == state)
+      updatePartitionStateMetrics(partition, NonExistentPartition, state)
+  }
+
+  def replicaState(replica: PartitionAndReplica): ReplicaState = {
+    replicaStates(replica)
+  }
+
+  def partitionState(partition: TopicPartition): PartitionState = {
+    partitionStates(partition)
+  }
+
+  def partitionsInState(state: PartitionState): Set[TopicPartition] = {
+    partitionStates.filter { case (_, s) => s == state }.keySet.toSet
+  }
+
+  def partitionsInStates(states: Set[PartitionState]): Set[TopicPartition] = {
+    partitionStates.filter { case (_, s) => states.contains(s) }.keySet.toSet
+  }
+
+  def partitionsInState(topic: String, state: PartitionState): Set[TopicPartition] = {
+    partitionsForTopic(topic).filter { partition => state == partitionState(partition) }.toSet
+  }
+
+  def partitionsInStates(topic: String, states: Set[PartitionState]): Set[TopicPartition] = {
+    partitionsForTopic(topic).filter { partition => states.contains(partitionState(partition)) }.toSet
+  }
+
+  private def isValidReplicaStateTransition(replica: PartitionAndReplica, targetState: ReplicaState): Boolean =
+    targetState.validPreviousStates.contains(replicaStates(replica))
+
+  private def isValidPartitionStateTransition(partition: TopicPartition, targetState: PartitionState): Boolean =
+    targetState.validPreviousStates.contains(partitionStates(partition))
+
 }
diff --git a/core/src/main/scala/kafka/controller/Election.scala b/core/src/main/scala/kafka/controller/Election.scala
new file mode 100644
index 0000000..9209992
--- /dev/null
+++ b/core/src/main/scala/kafka/controller/Election.scala
@@ -0,0 +1,152 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package kafka.controller
+
+import kafka.api.LeaderAndIsr
+import org.apache.kafka.common.TopicPartition
+
+case class ElectionResult(topicPartition: TopicPartition, leaderAndIsr: Option[LeaderAndIsr], liveReplicas: Seq[Int])
+
+object Election {
+
+  private def leaderForOffline(partition: TopicPartition,
+                               leaderIsrAndControllerEpochOpt: Option[LeaderIsrAndControllerEpoch],
+                               uncleanLeaderElectionEnabled: Boolean,
+                               controllerContext: ControllerContext): ElectionResult = {
+
+    val assignment = controllerContext.partitionReplicaAssignment(partition)
+    val liveReplicas = assignment.filter(replica => controllerContext.isReplicaOnline(replica, partition))
+    leaderIsrAndControllerEpochOpt match {
+      case Some(leaderIsrAndControllerEpoch) =>
+        val isr = leaderIsrAndControllerEpoch.leaderAndIsr.isr
+        val leaderOpt = PartitionLeaderElectionAlgorithms.offlinePartitionLeaderElection(assignment, isr,
+          liveReplicas.toSet, uncleanLeaderElectionEnabled, controllerContext)
+        val newLeaderAndIsrOpt = leaderOpt.map { leader =>
+          val newIsr = if (isr.contains(leader)) isr.filter(replica => controllerContext.isReplicaOnline(replica, partition))
+          else List(leader)
+          leaderIsrAndControllerEpoch.leaderAndIsr.newLeaderAndIsr(leader, newIsr)
+        }
+        ElectionResult(partition, newLeaderAndIsrOpt, liveReplicas)
+
+      case None =>
+        ElectionResult(partition, None, liveReplicas)
+    }
+  }
+
+  /**
+   * Elect leaders for new or offline partitions.
+   *
+   * @param controllerContext Context with the current state of the cluster
+   * @param partitionsWithUncleanLeaderElectionState A sequence of tuples representing the partitions
+   *                                                 that need election, their leader/ISR state, and whether
+   *                                                 or not unclean leader election is enabled
+   *
+   * @return The election results
+   */
+  def leaderForOffline(controllerContext: ControllerContext,
+                       partitionsWithUncleanLeaderElectionState: Seq[(TopicPartition, Option[LeaderIsrAndControllerEpoch], Boolean)]): Seq[ElectionResult] = {
+    partitionsWithUncleanLeaderElectionState.map { case (partition, leaderIsrAndControllerEpochOpt, uncleanLeaderElectionEnabled) =>
+      leaderForOffline(partition, leaderIsrAndControllerEpochOpt, uncleanLeaderElectionEnabled, controllerContext)
+    }
+  }
+
+  private def leaderForReassign(partition: TopicPartition,
+                                leaderIsrAndControllerEpoch: LeaderIsrAndControllerEpoch,
+                                controllerContext: ControllerContext): ElectionResult = {
+    val reassignment = controllerContext.partitionsBeingReassigned(partition).newReplicas
+    val liveReplicas = reassignment.filter(replica => controllerContext.isReplicaOnline(replica, partition))
+    val isr = leaderIsrAndControllerEpoch.leaderAndIsr.isr
+    val leaderOpt = PartitionLeaderElectionAlgorithms.reassignPartitionLeaderElection(reassignment, isr, liveReplicas.toSet)
+    val newLeaderAndIsrOpt = leaderOpt.map(leader => leaderIsrAndControllerEpoch.leaderAndIsr.newLeader(leader))
+    ElectionResult(partition, newLeaderAndIsrOpt, reassignment)
+  }
+
+  /**
+   * Elect leaders for partitions that are undergoing reassignment.
+   *
+   * @param controllerContext Context with the current state of the cluster
+   * @param leaderIsrAndControllerEpochs A sequence of tuples representing the partitions that need election
+   *                                     and their respective leader/ISR states
+   *
+   * @return The election results
+   */
+  def leaderForReassign(controllerContext: ControllerContext,
+                        leaderIsrAndControllerEpochs: Seq[(TopicPartition, LeaderIsrAndControllerEpoch)]): Seq[ElectionResult] = {
+    leaderIsrAndControllerEpochs.map { case (partition, leaderIsrAndControllerEpoch) =>
+      leaderForReassign(partition, leaderIsrAndControllerEpoch, controllerContext)
+    }
+  }
+
+  private def leaderForPreferredReplica(partition: TopicPartition,
+                                        leaderIsrAndControllerEpoch: LeaderIsrAndControllerEpoch,
+                                        controllerContext: ControllerContext): ElectionResult = {
+    val assignment = controllerContext.partitionReplicaAssignment(partition)
+    val liveReplicas = assignment.filter(replica => controllerContext.isReplicaOnline(replica, partition))
+    val isr = leaderIsrAndControllerEpoch.leaderAndIsr.isr
+    val leaderOpt = PartitionLeaderElectionAlgorithms.preferredReplicaPartitionLeaderElection(assignment, isr, liveReplicas.toSet)
+    val newLeaderAndIsrOpt = leaderOpt.map(leader => leaderIsrAndControllerEpoch.leaderAndIsr.newLeader(leader))
+    ElectionResult(partition, newLeaderAndIsrOpt, assignment)
+  }
+
+  /**
+   * Elect preferred leaders.
+   *
+   * @param controllerContext Context with the current state of the cluster
+   * @param leaderIsrAndControllerEpochs A sequence of tuples representing the partitions that need election
+   *                                     and their respective leader/ISR states
+   *
+   * @return The election results
+   */
+  def leaderForPreferredReplica(controllerContext: ControllerContext,
+                                leaderIsrAndControllerEpochs: Seq[(TopicPartition, LeaderIsrAndControllerEpoch)]): Seq[ElectionResult] = {
+    leaderIsrAndControllerEpochs.map { case (partition, leaderIsrAndControllerEpoch) =>
+      leaderForPreferredReplica(partition, leaderIsrAndControllerEpoch, controllerContext)
+    }
+  }
+
+  private def leaderForControlledShutdown(partition: TopicPartition,
+                                          leaderIsrAndControllerEpoch: LeaderIsrAndControllerEpoch,
+                                          shuttingDownBrokerIds: Set[Int],
+                                          controllerContext: ControllerContext): ElectionResult = {
+    val assignment = controllerContext.partitionReplicaAssignment(partition)
+    val liveOrShuttingDownReplicas = assignment.filter(replica =>
+      controllerContext.isReplicaOnline(replica, partition, includeShuttingDownBrokers = true))
+    val isr = leaderIsrAndControllerEpoch.leaderAndIsr.isr
+    val leaderOpt = PartitionLeaderElectionAlgorithms.controlledShutdownPartitionLeaderElection(assignment, isr,
+      liveOrShuttingDownReplicas.toSet, shuttingDownBrokerIds)
+    val newIsr = isr.filter(replica => !shuttingDownBrokerIds.contains(replica))
+    val newLeaderAndIsrOpt = leaderOpt.map(leader => leaderIsrAndControllerEpoch.leaderAndIsr.newLeaderAndIsr(leader, newIsr))
+    ElectionResult(partition, newLeaderAndIsrOpt, liveOrShuttingDownReplicas)
+  }
+
+  /**
+   * Elect leaders for partitions whose current leaders are shutting down.
+   *
+   * @param controllerContext Context with the current state of the cluster
+   * @param leaderIsrAndControllerEpochs A sequence of tuples representing the partitions that need election
+   *                                     and their respective leader/ISR states
+   *
+   * @return The election results
+   */
+  def leaderForControlledShutdown(controllerContext: ControllerContext,
+                                  leaderIsrAndControllerEpochs: Seq[(TopicPartition, LeaderIsrAndControllerEpoch)]): Seq[ElectionResult] = {
+    val shuttingDownBrokerIds = controllerContext.shuttingDownBrokerIds.toSet
+    leaderIsrAndControllerEpochs.map { case (partition, leaderIsrAndControllerEpoch) =>
+      leaderForControlledShutdown(partition, leaderIsrAndControllerEpoch, shuttingDownBrokerIds, controllerContext)
+    }
+  }
+}
diff --git a/core/src/main/scala/kafka/controller/KafkaController.scala b/core/src/main/scala/kafka/controller/KafkaController.scala
index ea23beb..183ffaf 100644
--- a/core/src/main/scala/kafka/controller/KafkaController.scala
+++ b/core/src/main/scala/kafka/controller/KafkaController.scala
@@ -37,6 +37,7 @@ import org.apache.kafka.common.requests.{AbstractControlRequest, AbstractRespons
 import org.apache.kafka.common.utils.Time
 import org.apache.zookeeper.KeeperException
 import org.apache.zookeeper.KeeperException.Code
+import scala.collection.JavaConverters._
 
 import scala.collection._
 import scala.util.{Failure, Try}
@@ -72,6 +73,7 @@ class KafkaController(val config: KafkaConfig, zkClient: KafkaZkClient, time: Ti
 
   private val stateChangeLogger = new StateChangeLogger(config.brokerId, inControllerContext = true, None)
   val controllerContext = new ControllerContext
+  var controllerChannelManager: ControllerChannelManager = _
 
   // have a separate scheduler for the controller to be able to start and stop independently of the kafka server
   // visible for testing
@@ -81,11 +83,13 @@ class KafkaController(val config: KafkaConfig, zkClient: KafkaZkClient, time: Ti
   private[controller] val eventManager = new ControllerEventManager(config.brokerId,
     controllerContext.stats.rateAndTimeMetrics, _ => updateMetrics(), () => maybeResign())
 
-  val topicDeletionManager = new TopicDeletionManager(this, eventManager, zkClient)
   private val brokerRequestBatch = new ControllerBrokerRequestBatch(this, stateChangeLogger)
-  val replicaStateMachine = new ReplicaStateMachine(config, stateChangeLogger, controllerContext, topicDeletionManager, zkClient, mutable.Map.empty, new ControllerBrokerRequestBatch(this, stateChangeLogger))
-  val partitionStateMachine = new PartitionStateMachine(config, stateChangeLogger, controllerContext, zkClient, mutable.Map.empty, new ControllerBrokerRequestBatch(this, stateChangeLogger))
-  partitionStateMachine.setTopicDeletionManager(topicDeletionManager)
+  val replicaStateMachine: ReplicaStateMachine = new ZkReplicaStateMachine(config, stateChangeLogger, controllerContext, zkClient,
+    new ControllerBrokerRequestBatch(this, stateChangeLogger))
+  val partitionStateMachine: PartitionStateMachine = new ZkPartitionStateMachine(config, stateChangeLogger, controllerContext, zkClient,
+    new ControllerBrokerRequestBatch(this, stateChangeLogger))
+  val topicDeletionManager = new TopicDeletionManager(config, controllerContext, replicaStateMachine,
+    partitionStateMachine, new ControllerDeletionClient(this, zkClient))
 
   private val controllerChangeHandler = new ControllerChangeHandler(this, eventManager)
   private val brokerChangeHandler = new BrokerChangeHandler(this, eventManager)
@@ -304,9 +308,6 @@ class KafkaController(val config: KafkaConfig, zkClient: KafkaZkClient, time: Ti
     zkClient.unregisterZNodeChildChangeHandler(logDirEventNotificationHandler.path)
     unregisterBrokerModificationsHandler(brokerModificationsHandlers.keySet)
 
-    // reset topic deletion manager
-    topicDeletionManager.reset()
-
     // shutdown leader rebalance scheduler
     kafkaScheduler.shutdown()
     offlinePartitionCount = 0
@@ -329,6 +330,11 @@ class KafkaController(val config: KafkaConfig, zkClient: KafkaZkClient, time: Ti
     replicaStateMachine.shutdown()
     zkClient.unregisterZNodeChildChangeHandler(brokerChangeHandler.path)
 
+
+    if (controllerChannelManager != null) {
+      controllerChannelManager.shutdown()
+      controllerChannelManager = null
+    }
     controllerContext.resetContext()
 
     info("Resigned")
@@ -390,7 +396,7 @@ class KafkaController(val config: KafkaConfig, zkClient: KafkaZkClient, time: Ti
     val replicasForTopicsToBeDeleted = allReplicasOnNewBrokers.filter(p => topicDeletionManager.isTopicQueuedUpForDeletion(p.topic))
     if (replicasForTopicsToBeDeleted.nonEmpty) {
       info(s"Some replicas ${replicasForTopicsToBeDeleted.mkString(",")} for topics scheduled for deletion " +
-        s"${topicDeletionManager.topicsToBeDeleted.mkString(",")} are on the newly restarted brokers " +
+        s"${controllerContext.topicsToBeDeleted.mkString(",")} are on the newly restarted brokers " +
         s"${newBrokers.mkString(",")}. Signaling restart of topic deletion for these topics")
       topicDeletionManager.resumeDeletionForTopics(replicasForTopicsToBeDeleted.map(_.topic))
     }
@@ -486,7 +492,7 @@ class KafkaController(val config: KafkaConfig, zkClient: KafkaZkClient, time: Ti
     info(s"New partition creation callback for ${newPartitions.mkString(",")}")
     partitionStateMachine.handleStateChanges(newPartitions.toSeq, NewPartition)
     replicaStateMachine.handleStateChanges(controllerContext.replicasForPartition(newPartitions).toSeq, NewReplica)
-    partitionStateMachine.handleStateChanges(newPartitions.toSeq, OnlinePartition, Option(OfflinePartitionLeaderElectionStrategy))
+    partitionStateMachine.handleStateChanges(newPartitions.toSeq, OnlinePartition, Some(OfflinePartitionLeaderElectionStrategy))
     replicaStateMachine.handleStateChanges(controllerContext.replicasForPartition(newPartitions).toSeq, OnlineReplica)
   }
 
@@ -646,7 +652,7 @@ class KafkaController(val config: KafkaConfig, zkClient: KafkaZkClient, time: Ti
     info(s"Starting preferred replica leader election for partitions ${partitions.mkString(",")}")
     try {
       val results = partitionStateMachine.handleStateChanges(partitions.toSeq, OnlinePartition,
-        Option(PreferredReplicaPartitionLeaderElectionStrategy))
+        Some(PreferredReplicaPartitionLeaderElectionStrategy))
       if (electionType != AdminClientTriggered) {
         results.foreach { case (tp, throwable) =>
           if (throwable.isInstanceOf[ControllerMovedException]) {
@@ -677,7 +683,7 @@ class KafkaController(val config: KafkaConfig, zkClient: KafkaZkClient, time: Ti
     controllerContext.partitionLeadershipInfo.clear()
     controllerContext.shuttingDownBrokerIds = mutable.Set.empty[Int]
     // register broker modifications handlers
-    registerBrokerModificationsHandler(controllerContext.liveBrokers.map(_.id))
+    registerBrokerModificationsHandler(controllerContext.liveOrShuttingDownBrokerIds)
     // update the leader and isr cache for all existing partitions from Zookeeper
     updateLeaderAndIsrCache()
     // start the channel manager
@@ -733,9 +739,9 @@ class KafkaController(val config: KafkaConfig, zkClient: KafkaZkClient, time: Ti
   }
 
   private def startChannelManager() {
-    controllerContext.controllerChannelManager = new ControllerChannelManager(controllerContext, config, time, metrics,
+    controllerChannelManager = new ControllerChannelManager(controllerContext, config, time, metrics,
       stateChangeLogger, threadNamePrefix)
-    controllerContext.controllerChannelManager.startup()
+    controllerChannelManager.startup()
   }
 
   private def updateLeaderAndIsrCache(partitions: Seq[TopicPartition] = controllerContext.allPartitions.toSeq) {
@@ -763,7 +769,7 @@ class KafkaController(val config: KafkaConfig, zkClient: KafkaZkClient, time: Ti
       info(s"Leader $currentLeader for partition $topicPartition being reassigned, " +
         s"is not in the new list of replicas ${reassignedReplicas.mkString(",")}. Re-electing leader")
       // move the leader to one of the alive and caught up new replicas
-      partitionStateMachine.handleStateChanges(Seq(topicPartition), OnlinePartition, Option(ReassignPartitionLeaderElectionStrategy))
+      partitionStateMachine.handleStateChanges(Seq(topicPartition), OnlinePartition, Some(ReassignPartitionLeaderElectionStrategy))
     } else {
       // check if the leader is alive or not
       if (controllerContext.isReplicaOnline(currentLeader, topicPartition)) {
@@ -774,7 +780,7 @@ class KafkaController(val config: KafkaConfig, zkClient: KafkaZkClient, time: Ti
       } else {
         info(s"Leader $currentLeader for partition $topicPartition being reassigned, " +
           s"is already in the new list of replicas ${reassignedReplicas.mkString(",")} but is dead")
-        partitionStateMachine.handleStateChanges(Seq(topicPartition), OnlinePartition, Option(ReassignPartitionLeaderElectionStrategy))
+        partitionStateMachine.handleStateChanges(Seq(topicPartition), OnlinePartition, Some(ReassignPartitionLeaderElectionStrategy))
       }
     }
   }
@@ -912,7 +918,7 @@ class KafkaController(val config: KafkaConfig, zkClient: KafkaZkClient, time: Ti
 
   private[controller] def sendRequest(brokerId: Int, apiKey: ApiKeys, request: AbstractControlRequest.Builder[_ <: AbstractControlRequest],
                                       callback: AbstractResponse => Unit = null) = {
-    controllerContext.controllerChannelManager.sendRequest(brokerId, apiKey, request, callback)
+    controllerChannelManager.sendRequest(brokerId, apiKey, request, callback)
   }
 
   /**
@@ -1091,12 +1097,11 @@ class KafkaController(val config: KafkaConfig, zkClient: KafkaZkClient, time: Ti
       val (partitionsLedByBroker, partitionsFollowedByBroker) = partitionsToActOn.partition { partition =>
         controllerContext.partitionLeadershipInfo(partition).leaderAndIsr.leader == id
       }
-      partitionStateMachine.handleStateChanges(partitionsLedByBroker.toSeq, OnlinePartition, Option(ControlledShutdownPartitionLeaderElectionStrategy))
+      partitionStateMachine.handleStateChanges(partitionsLedByBroker.toSeq, OnlinePartition, Some(ControlledShutdownPartitionLeaderElectionStrategy))
       try {
         brokerRequestBatch.newBatch()
         partitionsFollowedByBroker.foreach { partition =>
-          brokerRequestBatch.addStopReplicaRequestForBrokers(Seq(id), partition, deletePartition = false,
-            (_, _) => ())
+          brokerRequestBatch.addStopReplicaRequestForBrokers(Seq(id), partition, deletePartition = false)
         }
         brokerRequestBatch.sendRequestsToBrokers(epoch)
       } catch {
@@ -1124,7 +1129,6 @@ class KafkaController(val config: KafkaConfig, zkClient: KafkaZkClient, time: Ti
     def state = ControllerState.LeaderAndIsrResponseReceived
 
     override def process(): Unit = {
-      import JavaConverters._
       if (!isActive) return
       val leaderAndIsrResponse = LeaderAndIsrResponseObj.asInstanceOf[LeaderAndIsrResponse]
 
@@ -1156,11 +1160,10 @@ class KafkaController(val config: KafkaConfig, zkClient: KafkaZkClient, time: Ti
     def state = ControllerState.TopicDeletion
 
     override def process(): Unit = {
-      import JavaConverters._
       if (!isActive) return
       val stopReplicaResponse = stopReplicaResponseObj.asInstanceOf[StopReplicaResponse]
-      debug(s"Delete topic callback invoked for $stopReplicaResponse")
       val responseMap = stopReplicaResponse.responses.asScala
+      debug(s"Delete topic callback invoked on StopReplica response received from broker $replicaId: $stopReplicaResponse")
       val partitionsInError =
         if (stopReplicaResponse.error != Errors.NONE) responseMap.keySet
         else responseMap.filter { case (_, error) => error != Errors.NONE }.keySet
@@ -1191,7 +1194,7 @@ class KafkaController(val config: KafkaConfig, zkClient: KafkaZkClient, time: Ti
       if (!isActive) {
         0
       } else {
-        partitionStateMachine.offlinePartitionCount
+        controllerContext.offlinePartitionCount
       }
 
     preferredReplicaImbalanceCount =
@@ -1309,22 +1312,22 @@ class KafkaController(val config: KafkaConfig, zkClient: KafkaZkClient, time: Ti
         s"bounced brokers: ${bouncedBrokerIdsSorted.mkString(",")}, " +
         s"all live brokers: ${liveBrokerIdsSorted.mkString(",")}")
 
-      newBrokerAndEpochs.keySet.foreach(controllerContext.controllerChannelManager.addBroker)
-      bouncedBrokerIds.foreach(controllerContext.controllerChannelManager.removeBroker)
-      bouncedBrokerAndEpochs.keySet.foreach(controllerContext.controllerChannelManager.addBroker)
-      deadBrokerIds.foreach(controllerContext.controllerChannelManager.removeBroker)
+      newBrokerAndEpochs.keySet.foreach(controllerChannelManager.addBroker)
+      bouncedBrokerIds.foreach(controllerChannelManager.removeBroker)
+      bouncedBrokerAndEpochs.keySet.foreach(controllerChannelManager.addBroker)
+      deadBrokerIds.foreach(controllerChannelManager.removeBroker)
       if (newBrokerIds.nonEmpty) {
         controllerContext.addLiveBrokersAndEpochs(newBrokerAndEpochs)
         onBrokerStartup(newBrokerIdsSorted)
       }
       if (bouncedBrokerIds.nonEmpty) {
-        controllerContext.removeLiveBrokersAndEpochs(bouncedBrokerIds)
+        controllerContext.removeLiveBrokers(bouncedBrokerIds)
         onBrokerFailure(bouncedBrokerIdsSorted)
         controllerContext.addLiveBrokersAndEpochs(bouncedBrokerAndEpochs)
         onBrokerStartup(bouncedBrokerIdsSorted)
       }
       if (deadBrokerIds.nonEmpty) {
-        controllerContext.removeLiveBrokersAndEpochs(deadBrokerIds)
+        controllerContext.removeLiveBrokers(deadBrokerIds)
         onBrokerFailure(deadBrokerIdsSorted)
       }
 
@@ -1339,13 +1342,16 @@ class KafkaController(val config: KafkaConfig, zkClient: KafkaZkClient, time: Ti
 
     override def process(): Unit = {
       if (!isActive) return
-      val newMetadata = zkClient.getBroker(brokerId)
-      val oldMetadata = controllerContext.liveBrokers.find(_.id == brokerId)
-      if (newMetadata.nonEmpty && oldMetadata.nonEmpty && newMetadata.map(_.endPoints) != oldMetadata.map(_.endPoints)) {
-        info(s"Updated broker: ${newMetadata.get}")
-
-        controllerContext.updateBrokerMetadata(oldMetadata, newMetadata)
-        onBrokerUpdate(brokerId)
+      val newMetadataOpt = zkClient.getBroker(brokerId)
+      val oldMetadataOpt = controllerContext.liveOrShuttingDownBroker(brokerId)
+      if (newMetadataOpt.nonEmpty && oldMetadataOpt.nonEmpty) {
+        val oldMetadata = oldMetadataOpt.get
+        val newMetadata = newMetadataOpt.get
+        if (newMetadata.endPoints != oldMetadata.endPoints) {
+          info(s"Updated broker metadata: $oldMetadata -> $newMetadata")
+          controllerContext.updateBrokerMetadata(oldMetadata, newMetadata)
+          onBrokerUpdate(brokerId)
+        }
       }
     }
   }
diff --git a/core/src/main/scala/kafka/controller/PartitionStateMachine.scala b/core/src/main/scala/kafka/controller/PartitionStateMachine.scala
index ad73979..637cea8 100755
--- a/core/src/main/scala/kafka/controller/PartitionStateMachine.scala
+++ b/core/src/main/scala/kafka/controller/PartitionStateMachine.scala
@@ -18,10 +18,11 @@ package kafka.controller
 
 import kafka.api.LeaderAndIsr
 import kafka.common.StateChangeFailedException
+import kafka.controller.Election._
 import kafka.server.KafkaConfig
 import kafka.utils.Logging
-import kafka.zk.{KafkaZkClient, TopicPartitionStateZNode}
 import kafka.zk.KafkaZkClient.UpdateLeaderAndIsrResult
+import kafka.zk.{KafkaZkClient, TopicPartitionStateZNode}
 import org.apache.kafka.common.TopicPartition
 import org.apache.kafka.common.errors.ControllerMovedException
 import org.apache.zookeeper.KeeperException
@@ -29,33 +30,7 @@ import org.apache.zookeeper.KeeperException.Code
 
 import scala.collection.mutable
 
-
-/**
- * This class represents the state machine for partitions. It defines the states that a partition can be in, and
- * transitions to move the partition to another legal state. The different states that a partition can be in are -
- * 1. NonExistentPartition: This state indicates that the partition was either never created or was created and then
- *                          deleted. Valid previous state, if one exists, is OfflinePartition
- * 2. NewPartition        : After creation, the partition is in the NewPartition state. In this state, the partition should have
- *                          replicas assigned to it, but no leader/isr yet. Valid previous states are NonExistentPartition
- * 3. OnlinePartition     : Once a leader is elected for a partition, it is in the OnlinePartition state.
- *                          Valid previous states are NewPartition/OfflinePartition
- * 4. OfflinePartition    : If, after successful leader election, the leader for partition dies, then the partition
- *                          moves to the OfflinePartition state. Valid previous states are NewPartition/OnlinePartition
- */
-class PartitionStateMachine(config: KafkaConfig,
-                            stateChangeLogger: StateChangeLogger,
-                            controllerContext: ControllerContext,
-                            zkClient: KafkaZkClient,
-                            partitionState: mutable.Map[TopicPartition, PartitionState],
-                            controllerBrokerRequestBatch: ControllerBrokerRequestBatch) extends Logging {
-  private val controllerId = config.brokerId
-
-  private var topicDeletionManager: TopicDeletionManager = _
-
-  this.logIdent = s"[PartitionStateMachine controllerId=$controllerId] "
-
-  var offlinePartitionCount = 0
-
+abstract class PartitionStateMachine(controllerContext: ControllerContext) extends Logging {
   /**
    * Invoked on successful controller election.
    */
@@ -64,20 +39,40 @@ class PartitionStateMachine(config: KafkaConfig,
     initializePartitionState()
     info("Triggering online partition state changes")
     triggerOnlinePartitionStateChange()
-    info(s"Started partition state machine with initial state -> $partitionState")
+    debug(s"Started partition state machine with initial state -> ${controllerContext.partitionStates}")
   }
 
   /**
    * Invoked on controller shutdown.
    */
   def shutdown() {
-    partitionState.clear()
-    offlinePartitionCount = 0
     info("Stopped partition state machine")
   }
 
-  def setTopicDeletionManager(topicDeletionManager: TopicDeletionManager) {
-    this.topicDeletionManager = topicDeletionManager
+  /**
+   * This API invokes the OnlinePartition state change on all partitions in either the NewPartition or OfflinePartition
+   * state. This is called on a successful controller election and on broker changes
+   */
+  def triggerOnlinePartitionStateChange(): Unit = {
+    val partitions = controllerContext.partitionsInStates(Set(OfflinePartition, NewPartition))
+    triggerOnlineStateChangeForPartitions(partitions)
+  }
+
+  def triggerOnlinePartitionStateChange(topic: String): Unit = {
+    val partitions = controllerContext.partitionsInStates(topic, Set(OfflinePartition, NewPartition))
+    triggerOnlineStateChangeForPartitions(partitions)
+  }
+
+  private def triggerOnlineStateChangeForPartitions(partitions: collection.Set[TopicPartition]): Unit = {
+    // try to move all partitions in NewPartition or OfflinePartition state to OnlinePartition state except partitions
+    // that belong to topics to be deleted
+    val partitionsToTrigger = partitions.filter { partition =>
+      !controllerContext.isTopicQueuedUpForDeletion(partition.topic)
+    }.toSeq
+
+    handleStateChanges(partitionsToTrigger, OnlinePartition, Some(OfflinePartitionLeaderElectionStrategy))
+    // TODO: If handleStateChanges catches an exception, it is not enough to bail out and log an error.
+    // It is important to trigger leader election for those partitions.
   }
 
   /**
@@ -92,38 +87,47 @@ class PartitionStateMachine(config: KafkaConfig,
           // else, check if the leader for partition is alive. If yes, it is in Online state, else it is in Offline state
           if (controllerContext.isReplicaOnline(currentLeaderIsrAndEpoch.leaderAndIsr.leader, topicPartition))
           // leader is alive
-            changeStateTo(topicPartition, NonExistentPartition, OnlinePartition)
+            controllerContext.putPartitionState(topicPartition, OnlinePartition)
           else
-            changeStateTo(topicPartition, NonExistentPartition, OfflinePartition)
+            controllerContext.putPartitionState(topicPartition, OfflinePartition)
         case None =>
-          changeStateTo(topicPartition, NonExistentPartition, NewPartition)
+          controllerContext.putPartitionState(topicPartition, NewPartition)
       }
     }
   }
 
-  /**
-   * This API invokes the OnlinePartition state change on all partitions in either the NewPartition or OfflinePartition
-   * state. This is called on a successful controller election and on broker changes
-   */
-  def triggerOnlinePartitionStateChange() {
-    triggerOnlinePartitionStateChange(partitionState.toMap)
+  def handleStateChanges(partitions: Seq[TopicPartition],
+                         targetState: PartitionState): Map[TopicPartition, Throwable] = {
+    handleStateChanges(partitions, targetState, None)
   }
 
-  def triggerOnlinePartitionStateChange(topic: String) {
-    triggerOnlinePartitionStateChange(partitionState.filterKeys(p => p.topic.equals(topic)).toMap)
-  }
+  def handleStateChanges(partitions: Seq[TopicPartition],
+                         targetState: PartitionState,
+                         leaderElectionStrategy: Option[PartitionLeaderElectionStrategy]): Map[TopicPartition, Throwable]
 
-  def triggerOnlinePartitionStateChange(partitionState: Map[TopicPartition, PartitionState]) {
-    // try to move all partitions in NewPartition or OfflinePartition state to OnlinePartition state except partitions
-    // that belong to topics to be deleted
-    val partitionsToTrigger = partitionState.filter { case (partition, partitionState) =>
-      !topicDeletionManager.isTopicQueuedUpForDeletion(partition.topic) &&
-        (partitionState.equals(OfflinePartition) || partitionState.equals(NewPartition))
-    }.keys.toSeq
-    handleStateChanges(partitionsToTrigger, OnlinePartition, Option(OfflinePartitionLeaderElectionStrategy))
-    // TODO: If handleStateChanges catches an exception, it is not enough to bail out and log an error.
-    // It is important to trigger leader election for those partitions.
-  }
+}
+
+/**
+ * This class represents the state machine for partitions. It defines the states that a partition can be in, and
+ * transitions to move the partition to another legal state. The different states that a partition can be in are -
+ * 1. NonExistentPartition: This state indicates that the partition was either never created or was created and then
+ *                          deleted. Valid previous state, if one exists, is OfflinePartition
+ * 2. NewPartition        : After creation, the partition is in the NewPartition state. In this state, the partition should have
+ *                          replicas assigned to it, but no leader/isr yet. Valid previous states are NonExistentPartition
+ * 3. OnlinePartition     : Once a leader is elected for a partition, it is in the OnlinePartition state.
+ *                          Valid previous states are NewPartition/OfflinePartition
+ * 4. OfflinePartition    : If, after successful leader election, the leader for partition dies, then the partition
+ *                          moves to the OfflinePartition state. Valid previous states are NewPartition/OnlinePartition
+ */
+class ZkPartitionStateMachine(config: KafkaConfig,
+                              stateChangeLogger: StateChangeLogger,
+                              controllerContext: ControllerContext,
+                              zkClient: KafkaZkClient,
+                              controllerBrokerRequestBatch: ControllerBrokerRequestBatch)
+  extends PartitionStateMachine(controllerContext) {
+
+  private val controllerId = config.brokerId
+  this.logIdent = s"[PartitionStateMachine controllerId=$controllerId] "
 
   /**
     * Try to change the state of the given partitions to the given targetState, using the given
@@ -133,8 +137,8 @@ class PartitionStateMachine(config: KafkaConfig,
     * @param partitionLeaderElectionStrategyOpt The leader election strategy if a leader election is required.
     * @return partitions and corresponding throwable for those partitions which could not transition to the given state
     */
-  def handleStateChanges(partitions: Seq[TopicPartition], targetState: PartitionState,
-                         partitionLeaderElectionStrategyOpt: Option[PartitionLeaderElectionStrategy] = None): Map[TopicPartition, Throwable] = {
+  override def handleStateChanges(partitions: Seq[TopicPartition], targetState: PartitionState,
+                         partitionLeaderElectionStrategyOpt: Option[PartitionLeaderElectionStrategy]): Map[TopicPartition, Throwable] = {
     if (partitions.nonEmpty) {
       try {
         controllerBrokerRequestBatch.newBatch()
@@ -154,24 +158,8 @@ class PartitionStateMachine(config: KafkaConfig,
     }
   }
 
-
-  def partitionsInState(state: PartitionState): Set[TopicPartition] = {
-    partitionState.filter { case (_, s) => s == state }.keySet.toSet
-  }
-
-  private def changeStateTo(partition: TopicPartition, currentState: PartitionState, targetState: PartitionState): Unit = {
-    partitionState.put(partition, targetState)
-    updateControllerMetrics(partition, currentState, targetState)
-  }
-
-  private def updateControllerMetrics(partition: TopicPartition, currentState: PartitionState, targetState: PartitionState) : Unit = {
-    if (!topicDeletionManager.isTopicWithDeletionStarted(partition.topic)) {
-      if (currentState != OfflinePartition && targetState == OfflinePartition) {
-        offlinePartitionCount = offlinePartitionCount + 1
-      } else if (currentState == OfflinePartition && targetState != OfflinePartition) {
-        offlinePartitionCount = offlinePartitionCount - 1
-      }
-    }
+  private def partitionState(partition: TopicPartition): PartitionState = {
+    controllerContext.partitionState(partition)
   }
 
   /**
@@ -196,18 +184,20 @@ class PartitionStateMachine(config: KafkaConfig,
    * @param partitions  The partitions for which the state transition is invoked
    * @param targetState The end state that the partition should be moved to
    */
-  private def doHandleStateChanges(partitions: Seq[TopicPartition], targetState: PartitionState,
-                           partitionLeaderElectionStrategyOpt: Option[PartitionLeaderElectionStrategy]): Map[TopicPartition, Throwable] = {
+  private def doHandleStateChanges(partitions: Seq[TopicPartition],
+                                   targetState: PartitionState,
+                                   partitionLeaderElectionStrategyOpt: Option[PartitionLeaderElectionStrategy]): Map[TopicPartition, Throwable] = {
     val stateChangeLog = stateChangeLogger.withControllerEpoch(controllerContext.epoch)
-    partitions.foreach(partition => partitionState.getOrElseUpdate(partition, NonExistentPartition))
-    val (validPartitions, invalidPartitions) = partitions.partition(partition => isValidTransition(partition, targetState))
+    partitions.foreach(partition => controllerContext.putPartitionStateIfNotExists(partition, NonExistentPartition))
+    val (validPartitions, invalidPartitions) = controllerContext.checkValidPartitionStateChange(partitions, targetState)
     invalidPartitions.foreach(partition => logInvalidTransition(partition, targetState))
+
     targetState match {
       case NewPartition =>
         validPartitions.foreach { partition =>
           stateChangeLog.trace(s"Changed partition $partition state from ${partitionState(partition)} to $targetState with " +
             s"assigned replicas ${controllerContext.partitionReplicaAssignment(partition).mkString(",")}")
-          changeStateTo(partition, partitionState(partition), NewPartition)
+          controllerContext.putPartitionState(partition, NewPartition)
         }
         Map.empty
       case OnlinePartition =>
@@ -218,7 +208,7 @@ class PartitionStateMachine(config: KafkaConfig,
           successfulInitializations.foreach { partition =>
             stateChangeLog.trace(s"Changed partition $partition from ${partitionState(partition)} to $targetState with state " +
               s"${controllerContext.partitionLeadershipInfo(partition).leaderAndIsr}")
-            changeStateTo(partition, partitionState(partition), OnlinePartition)
+            controllerContext.putPartitionState(partition, OnlinePartition)
           }
         }
         if (partitionsToElectLeader.nonEmpty) {
@@ -226,7 +216,7 @@ class PartitionStateMachine(config: KafkaConfig,
           successfulElections.foreach { partition =>
             stateChangeLog.trace(s"Changed partition $partition from ${partitionState(partition)} to $targetState with state " +
               s"${controllerContext.partitionLeadershipInfo(partition).leaderAndIsr}")
-            changeStateTo(partition, partitionState(partition), OnlinePartition)
+            controllerContext.putPartitionState(partition, OnlinePartition)
           }
           failedElections
         } else {
@@ -235,13 +225,13 @@ class PartitionStateMachine(config: KafkaConfig,
       case OfflinePartition =>
         validPartitions.foreach { partition =>
           stateChangeLog.trace(s"Changed partition $partition state from ${partitionState(partition)} to $targetState")
-          changeStateTo(partition, partitionState(partition), OfflinePartition)
+          controllerContext.putPartitionState(partition, OfflinePartition)
         }
         Map.empty
       case NonExistentPartition =>
         validPartitions.foreach { partition =>
           stateChangeLog.trace(s"Changed partition $partition state from ${partitionState(partition)} to $targetState")
-          changeStateTo(partition, partitionState(partition), NonExistentPartition)
+          controllerContext.putPartitionState(partition, NonExistentPartition)
         }
         Map.empty
     }
@@ -374,23 +364,24 @@ class PartitionStateMachine(config: KafkaConfig,
     if (validPartitionsForElection.isEmpty) {
       return (Seq.empty, Seq.empty, failedElections.toMap)
     }
-    val shuttingDownBrokers  = controllerContext.shuttingDownBrokerIds.toSet
     val (partitionsWithoutLeaders, partitionsWithLeaders) = partitionLeaderElectionStrategy match {
       case OfflinePartitionLeaderElectionStrategy =>
-        leaderForOffline(validPartitionsForElection).partition { case (_, newLeaderAndIsrOpt, _) => newLeaderAndIsrOpt.isEmpty }
+        val partitionsWithUncleanLeaderElectionState = collectUncleanLeaderElectionState(validPartitionsForElection)
+        leaderForOffline(controllerContext, partitionsWithUncleanLeaderElectionState).partition(_.leaderAndIsr.isEmpty)
       case ReassignPartitionLeaderElectionStrategy =>
-        leaderForReassign(validPartitionsForElection).partition { case (_, newLeaderAndIsrOpt, _) => newLeaderAndIsrOpt.isEmpty }
+        leaderForReassign(controllerContext, validPartitionsForElection).partition(_.leaderAndIsr.isEmpty)
       case PreferredReplicaPartitionLeaderElectionStrategy =>
-        leaderForPreferredReplica(validPartitionsForElection).partition { case (_, newLeaderAndIsrOpt, _) => newLeaderAndIsrOpt.isEmpty }
+        leaderForPreferredReplica(controllerContext, validPartitionsForElection).partition(_.leaderAndIsr.isEmpty)
       case ControlledShutdownPartitionLeaderElectionStrategy =>
-        leaderForControlledShutdown(validPartitionsForElection, shuttingDownBrokers).partition { case (_, newLeaderAndIsrOpt, _) => newLeaderAndIsrOpt.isEmpty }
+        leaderForControlledShutdown(controllerContext, validPartitionsForElection).partition(_.leaderAndIsr.isEmpty)
     }
-    partitionsWithoutLeaders.foreach { case (partition, _, _) =>
+    partitionsWithoutLeaders.foreach { electionResult =>
+      val partition = electionResult.topicPartition
       val failMsg = s"Failed to elect leader for partition $partition under strategy $partitionLeaderElectionStrategy"
       failedElections.put(partition, new StateChangeFailedException(failMsg))
     }
-    val recipientsPerPartition = partitionsWithLeaders.map { case (partition, _, recipients) => partition -> recipients }.toMap
-    val adjustedLeaderAndIsrs = partitionsWithLeaders.map { case (partition, leaderAndIsrOpt, _) => partition -> leaderAndIsrOpt.get }.toMap
+    val recipientsPerPartition = partitionsWithLeaders.map(result => result.topicPartition -> result.liveReplicas).toMap
+    val adjustedLeaderAndIsrs = partitionsWithLeaders.map(result => result.topicPartition -> result.leaderAndIsr.get).toMap
     val UpdateLeaderAndIsrResult(successfulUpdates, updatesToRetry, failedUpdates) = zkClient.updateLeaderAndIsr(
       adjustedLeaderAndIsrs, controllerContext.epoch, controllerContext.epochZkVersion)
     successfulUpdates.foreach { case (partition, leaderAndIsr) =>
@@ -403,14 +394,14 @@ class PartitionStateMachine(config: KafkaConfig,
     (successfulUpdates.keys.toSeq, updatesToRetry, failedElections.toMap ++ failedUpdates)
   }
 
-  private def leaderForOffline(leaderIsrAndControllerEpochs: Seq[(TopicPartition, LeaderIsrAndControllerEpoch)]):
-  Seq[(TopicPartition, Option[LeaderAndIsr], Seq[Int])] = {
+  private def collectUncleanLeaderElectionState(leaderIsrAndControllerEpochs: Seq[(TopicPartition, LeaderIsrAndControllerEpoch)]):
+  Seq[(TopicPartition, Option[LeaderIsrAndControllerEpoch], Boolean)] = {
     val (partitionsWithNoLiveInSyncReplicas, partitionsWithLiveInSyncReplicas) = leaderIsrAndControllerEpochs.partition { case (partition, leaderIsrAndControllerEpoch) =>
       val liveInSyncReplicas = leaderIsrAndControllerEpoch.leaderAndIsr.isr.filter(replica => controllerContext.isReplicaOnline(replica, partition))
       liveInSyncReplicas.isEmpty
     }
     val (logConfigs, failed) = zkClient.getLogConfigs(partitionsWithNoLiveInSyncReplicas.map { case (partition, _) => partition.topic }, config.originals())
-    val partitionsWithUncleanLeaderElectionState = partitionsWithNoLiveInSyncReplicas.map { case (partition, leaderIsrAndControllerEpoch) =>
+    partitionsWithNoLiveInSyncReplicas.map { case (partition, leaderIsrAndControllerEpoch) =>
       if (failed.contains(partition.topic)) {
         logFailedStateChange(partition, partitionState(partition), OnlinePartition, failed(partition.topic))
         (partition, None, false)
@@ -418,65 +409,8 @@ class PartitionStateMachine(config: KafkaConfig,
         (partition, Option(leaderIsrAndControllerEpoch), logConfigs(partition.topic).uncleanLeaderElectionEnable.booleanValue())
       }
     } ++ partitionsWithLiveInSyncReplicas.map { case (partition, leaderIsrAndControllerEpoch) => (partition, Option(leaderIsrAndControllerEpoch), false) }
-    partitionsWithUncleanLeaderElectionState.map { case (partition, leaderIsrAndControllerEpochOpt, uncleanLeaderElectionEnabled) =>
-      val assignment = controllerContext.partitionReplicaAssignment(partition)
-      val liveReplicas = assignment.filter(replica => controllerContext.isReplicaOnline(replica, partition))
-      if (leaderIsrAndControllerEpochOpt.nonEmpty) {
-        val leaderIsrAndControllerEpoch = leaderIsrAndControllerEpochOpt.get
-        val isr = leaderIsrAndControllerEpoch.leaderAndIsr.isr
-        val leaderOpt = PartitionLeaderElectionAlgorithms.offlinePartitionLeaderElection(assignment, isr, liveReplicas.toSet, uncleanLeaderElectionEnabled, controllerContext)
-        val newLeaderAndIsrOpt = leaderOpt.map { leader =>
-          val newIsr = if (isr.contains(leader)) isr.filter(replica => controllerContext.isReplicaOnline(replica, partition))
-          else List(leader)
-          leaderIsrAndControllerEpoch.leaderAndIsr.newLeaderAndIsr(leader, newIsr)
-        }
-        (partition, newLeaderAndIsrOpt, liveReplicas)
-      } else {
-        (partition, None, liveReplicas)
-      }
-    }
   }
 
-  private def leaderForReassign(leaderIsrAndControllerEpochs: Seq[(TopicPartition, LeaderIsrAndControllerEpoch)]):
-  Seq[(TopicPartition, Option[LeaderAndIsr], Seq[Int])] = {
-    leaderIsrAndControllerEpochs.map { case (partition, leaderIsrAndControllerEpoch) =>
-      val reassignment = controllerContext.partitionsBeingReassigned(partition).newReplicas
-      val liveReplicas = reassignment.filter(replica => controllerContext.isReplicaOnline(replica, partition))
-      val isr = leaderIsrAndControllerEpoch.leaderAndIsr.isr
-      val leaderOpt = PartitionLeaderElectionAlgorithms.reassignPartitionLeaderElection(reassignment, isr, liveReplicas.toSet)
-      val newLeaderAndIsrOpt = leaderOpt.map(leader => leaderIsrAndControllerEpoch.leaderAndIsr.newLeader(leader))
-      (partition, newLeaderAndIsrOpt, reassignment)
-    }
-  }
-
-  private def leaderForPreferredReplica(leaderIsrAndControllerEpochs: Seq[(TopicPartition, LeaderIsrAndControllerEpoch)]):
-  Seq[(TopicPartition, Option[LeaderAndIsr], Seq[Int])] = {
-    leaderIsrAndControllerEpochs.map { case (partition, leaderIsrAndControllerEpoch) =>
-      val assignment = controllerContext.partitionReplicaAssignment(partition)
-      val liveReplicas = assignment.filter(replica => controllerContext.isReplicaOnline(replica, partition))
-      val isr = leaderIsrAndControllerEpoch.leaderAndIsr.isr
-      val leaderOpt = PartitionLeaderElectionAlgorithms.preferredReplicaPartitionLeaderElection(assignment, isr, liveReplicas.toSet)
-      val newLeaderAndIsrOpt = leaderOpt.map(leader => leaderIsrAndControllerEpoch.leaderAndIsr.newLeader(leader))
-      (partition, newLeaderAndIsrOpt, assignment)
-    }
-  }
-
-  private def leaderForControlledShutdown(leaderIsrAndControllerEpochs: Seq[(TopicPartition, LeaderIsrAndControllerEpoch)], shuttingDownBrokers: Set[Int]):
-  Seq[(TopicPartition, Option[LeaderAndIsr], Seq[Int])] = {
-    leaderIsrAndControllerEpochs.map { case (partition, leaderIsrAndControllerEpoch) =>
-      val assignment = controllerContext.partitionReplicaAssignment(partition)
-      val liveOrShuttingDownReplicas = assignment.filter(replica => controllerContext.isReplicaOnline(replica, partition, includeShuttingDownBrokers = true))
-      val isr = leaderIsrAndControllerEpoch.leaderAndIsr.isr
-      val leaderOpt = PartitionLeaderElectionAlgorithms.controlledShutdownPartitionLeaderElection(assignment, isr, liveOrShuttingDownReplicas.toSet, shuttingDownBrokers)
-      val newIsr = isr.filter(replica => !controllerContext.shuttingDownBrokerIds.contains(replica))
-      val newLeaderAndIsrOpt = leaderOpt.map(leader => leaderIsrAndControllerEpoch.leaderAndIsr.newLeaderAndIsr(leader, newIsr))
-      (partition, newLeaderAndIsrOpt, liveOrShuttingDownReplicas)
-    }
-  }
-
-  private def isValidTransition(partition: TopicPartition, targetState: PartitionState) =
-    targetState.validPreviousStates.contains(partitionState(partition))
-
   private def logInvalidTransition(partition: TopicPartition, targetState: PartitionState): Unit = {
     val currState = partitionState(partition)
     val e = new IllegalStateException(s"Partition $partition should be in one of " +
@@ -501,7 +435,7 @@ object PartitionLeaderElectionAlgorithms {
     assignment.find(id => liveReplicas.contains(id) && isr.contains(id)).orElse {
       if (uncleanLeaderElectionEnabled) {
         val leaderOpt = assignment.find(liveReplicas.contains)
-        if (!leaderOpt.isEmpty)
+        if (leaderOpt.isDefined)
           controllerContext.stats.uncleanLeaderElectionRate.mark()
         leaderOpt
       } else {
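
A minimal, self-contained sketch of the offline leader selection rule exercised in the hunk above
(simplified: the real code also shrinks the ISR for the new leader and marks the unclean leader
election rate meter; the object and method names here are illustrative only):

    object OfflineElectionSketch {
      // Prefer the first assigned replica that is both alive and in the ISR; fall back to any
      // live assigned replica only when unclean leader election is enabled for the topic.
      def electLeader(assignment: Seq[Int],
                      isr: Seq[Int],
                      liveReplicas: Set[Int],
                      uncleanLeaderElectionEnabled: Boolean): Option[Int] = {
        assignment.find(id => liveReplicas.contains(id) && isr.contains(id)).orElse {
          if (uncleanLeaderElectionEnabled) assignment.find(liveReplicas.contains) else None
        }
      }

      def main(args: Array[String]): Unit = {
        // Replica 3 is alive but out of sync: a clean election fails, an unclean election picks it.
        println(electLeader(Seq(1, 2, 3), Seq(1, 2), Set(3), uncleanLeaderElectionEnabled = false)) // None
        println(electLeader(Seq(1, 2, 3), Seq(1, 2), Set(3), uncleanLeaderElectionEnabled = true))  // Some(3)
      }
    }
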
diff --git a/core/src/main/scala/kafka/controller/ReplicaStateMachine.scala b/core/src/main/scala/kafka/controller/ReplicaStateMachine.scala
index 433ab56..f7ec470 100644
--- a/core/src/main/scala/kafka/controller/ReplicaStateMachine.scala
+++ b/core/src/main/scala/kafka/controller/ReplicaStateMachine.scala
@@ -28,35 +28,7 @@ import org.apache.zookeeper.KeeperException.Code
 
 import scala.collection.mutable
 
-/**
- * This class represents the state machine for replicas. It defines the states that a replica can be in, and
- * transitions to move the replica to another legal state. The different states that a replica can be in are -
- * 1. NewReplica        : The controller can create new replicas during partition reassignment. In this state, a
- *                        replica can only get become follower state change request.  Valid previous
- *                        state is NonExistentReplica
- * 2. OnlineReplica     : Once a replica is started and part of the assigned replicas for its partition, it is in this
- *                        state. In this state, it can get either become leader or become follower state change requests.
- *                        Valid previous state are NewReplica, OnlineReplica or OfflineReplica
- * 3. OfflineReplica    : If a replica dies, it moves to this state. This happens when the broker hosting the replica
- *                        is down. Valid previous state are NewReplica, OnlineReplica
- * 4. ReplicaDeletionStarted: If replica deletion starts, it is moved to this state. Valid previous state is OfflineReplica
- * 5. ReplicaDeletionSuccessful: If replica responds with no error code in response to a delete replica request, it is
- *                        moved to this state. Valid previous state is ReplicaDeletionStarted
- * 6. ReplicaDeletionIneligible: If replica deletion fails, it is moved to this state. Valid previous state is ReplicaDeletionStarted
- * 7. NonExistentReplica: If a replica is deleted successfully, it is moved to this state. Valid previous state is
- *                        ReplicaDeletionSuccessful
- */
-class ReplicaStateMachine(config: KafkaConfig,
-                          stateChangeLogger: StateChangeLogger,
-                          controllerContext: ControllerContext,
-                          topicDeletionManager: TopicDeletionManager,
-                          zkClient: KafkaZkClient,
-                          replicaState: mutable.Map[PartitionAndReplica, ReplicaState],
-                          controllerBrokerRequestBatch: ControllerBrokerRequestBatch) extends Logging {
-  private val controllerId = config.brokerId
-
-  this.logIdent = s"[ReplicaStateMachine controllerId=$controllerId] "
-
+abstract class ReplicaStateMachine(controllerContext: ControllerContext) extends Logging {
   /**
    * Invoked on successful controller election.
    */
@@ -65,14 +37,13 @@ class ReplicaStateMachine(config: KafkaConfig,
     initializeReplicaState()
     info("Triggering online replica state changes")
     handleStateChanges(controllerContext.allLiveReplicas().toSeq, OnlineReplica)
-    info(s"Started replica state machine with initial state -> $replicaState")
+    debug(s"Started replica state machine with initial state -> ${controllerContext.replicaStates}")
   }
 
   /**
    * Invoked on controller shutdown.
    */
   def shutdown() {
-    replicaState.clear()
     info("Stopped replica state machine")
   }
 
@@ -85,25 +56,56 @@ class ReplicaStateMachine(config: KafkaConfig,
       val replicas = controllerContext.partitionReplicaAssignment(partition)
       replicas.foreach { replicaId =>
         val partitionAndReplica = PartitionAndReplica(partition, replicaId)
-        if (controllerContext.isReplicaOnline(replicaId, partition))
-          replicaState.put(partitionAndReplica, OnlineReplica)
-        else
-        // mark replicas on dead brokers as failed for topic deletion, if they belong to a topic to be deleted.
-        // This is required during controller failover since during controller failover a broker can go down,
-        // so the replicas on that broker should be moved to ReplicaDeletionIneligible to be on the safer side.
-          replicaState.put(partitionAndReplica, ReplicaDeletionIneligible)
+        if (controllerContext.isReplicaOnline(replicaId, partition)) {
+          controllerContext.putReplicaState(partitionAndReplica, OnlineReplica)
+        } else {
+          // mark replicas on dead brokers as failed for topic deletion, if they belong to a topic to be deleted.
+          // This is required during controller failover since during controller failover a broker can go down,
+          // so the replicas on that broker should be moved to ReplicaDeletionIneligible to be on the safer side.
+          controllerContext.putReplicaState(partitionAndReplica, ReplicaDeletionIneligible)
+        }
       }
     }
   }
 
-  def handleStateChanges(replicas: Seq[PartitionAndReplica], targetState: ReplicaState,
-                         callbacks: Callbacks = new Callbacks()): Unit = {
+  def handleStateChanges(replicas: Seq[PartitionAndReplica], targetState: ReplicaState): Unit
+}
+
+/**
+ * This class represents the state machine for replicas. It defines the states that a replica can be in, and
+ * transitions to move the replica to another legal state. The different states that a replica can be in are -
+ * 1. NewReplica        : The controller can create new replicas during partition reassignment. In this state, a
+ *                        replica can only receive become-follower state change requests. Valid previous
+ *                        state is NonExistentReplica
+ * 2. OnlineReplica     : Once a replica is started and is part of the assigned replicas for its partition, it is in this
+ *                        state. In this state, it can receive either become-leader or become-follower state change requests.
+ *                        Valid previous states are NewReplica, OnlineReplica or OfflineReplica
+ * 3. OfflineReplica    : If a replica dies, it moves to this state. This happens when the broker hosting the replica
+ *                        is down. Valid previous states are NewReplica and OnlineReplica
+ * 4. ReplicaDeletionStarted: If replica deletion starts, it is moved to this state. Valid previous state is OfflineReplica
+ * 5. ReplicaDeletionSuccessful: If replica responds with no error code in response to a delete replica request, it is
+ *                        moved to this state. Valid previous state is ReplicaDeletionStarted
+ * 6. ReplicaDeletionIneligible: If replica deletion fails, it is moved to this state. Valid previous states are
+ *                        ReplicaDeletionStarted and OfflineReplica
+ * 7. NonExistentReplica: If a replica is deleted successfully, it is moved to this state. Valid previous state is
+ *                        ReplicaDeletionSuccessful
+ */
+class ZkReplicaStateMachine(config: KafkaConfig,
+                            stateChangeLogger: StateChangeLogger,
+                            controllerContext: ControllerContext,
+                            zkClient: KafkaZkClient,
+                            controllerBrokerRequestBatch: ControllerBrokerRequestBatch)
+  extends ReplicaStateMachine(controllerContext) with Logging {
+
+  private val controllerId = config.brokerId
+  this.logIdent = s"[ReplicaStateMachine controllerId=$controllerId] "
+
+  override def handleStateChanges(replicas: Seq[PartitionAndReplica], targetState: ReplicaState): Unit = {
     if (replicas.nonEmpty) {
       try {
         controllerBrokerRequestBatch.newBatch()
-        replicas.groupBy(_.replica).map { case (replicaId, replicas) =>
-          val partitions = replicas.map(_.topicPartition)
-          doHandleStateChanges(replicaId, partitions, targetState, callbacks)
+        replicas.groupBy(_.replica).foreach { case (replicaId, replicas) =>
+          doHandleStateChanges(replicaId, replicas, targetState)
         }
         controllerBrokerRequestBatch.sendRequestsToBrokers(controllerContext.epoch)
       } catch {
@@ -150,39 +152,42 @@ class ReplicaStateMachine(config: KafkaConfig,
    * @param partitions The partitions on this replica for which the state transition is invoked
    * @param targetState The end state that the replica should be moved to
    */
-  private def doHandleStateChanges(replicaId: Int, partitions: Seq[TopicPartition], targetState: ReplicaState,
-                                   callbacks: Callbacks): Unit = {
-    val replicas = partitions.map(partition => PartitionAndReplica(partition, replicaId))
-    replicas.foreach(replica => replicaState.getOrElseUpdate(replica, NonExistentReplica))
-    val (validReplicas, invalidReplicas) = replicas.partition(replica => isValidTransition(replica, targetState))
+  private def doHandleStateChanges(replicaId: Int, replicas: Seq[PartitionAndReplica], targetState: ReplicaState): Unit = {
+    replicas.foreach(replica => controllerContext.putReplicaStateIfNotExists(replica, NonExistentReplica))
+    val (validReplicas, invalidReplicas) = controllerContext.checkValidReplicaStateChange(replicas, targetState)
     invalidReplicas.foreach(replica => logInvalidTransition(replica, targetState))
+
     targetState match {
       case NewReplica =>
         validReplicas.foreach { replica =>
           val partition = replica.topicPartition
+          val currentState = controllerContext.replicaState(replica)
+
           controllerContext.partitionLeadershipInfo.get(partition) match {
             case Some(leaderIsrAndControllerEpoch) =>
               if (leaderIsrAndControllerEpoch.leaderAndIsr.leader == replicaId) {
                 val exception = new StateChangeFailedException(s"Replica $replicaId for partition $partition cannot be moved to NewReplica state as it is being requested to become leader")
-                logFailedStateChange(replica, replicaState(replica), OfflineReplica, exception)
+                logFailedStateChange(replica, currentState, OfflineReplica, exception)
               } else {
                 controllerBrokerRequestBatch.addLeaderAndIsrRequestForBrokers(Seq(replicaId),
                   replica.topicPartition,
                   leaderIsrAndControllerEpoch,
                   controllerContext.partitionReplicaAssignment(replica.topicPartition),
                   isNew = true)
-                logSuccessfulTransition(replicaId, partition, replicaState(replica), NewReplica)
-                replicaState.put(replica, NewReplica)
+                logSuccessfulTransition(replicaId, partition, currentState, NewReplica)
+                controllerContext.putReplicaState(replica, NewReplica)
               }
             case None =>
-              logSuccessfulTransition(replicaId, partition, replicaState(replica), NewReplica)
-              replicaState.put(replica, NewReplica)
+              logSuccessfulTransition(replicaId, partition, currentState, NewReplica)
+              controllerContext.putReplicaState(replica, NewReplica)
           }
         }
       case OnlineReplica =>
         validReplicas.foreach { replica =>
           val partition = replica.topicPartition
-          replicaState(replica) match {
+          val currentState = controllerContext.replicaState(replica)
+
+          currentState match {
             case NewReplica =>
               val assignment = controllerContext.partitionReplicaAssignment(partition)
               if (!assignment.contains(replicaId)) {
@@ -198,20 +203,19 @@ class ReplicaStateMachine(config: KafkaConfig,
                 case None =>
               }
           }
-          logSuccessfulTransition(replicaId, partition, replicaState(replica), OnlineReplica)
-          replicaState.put(replica, OnlineReplica)
+          logSuccessfulTransition(replicaId, partition, currentState, OnlineReplica)
+          controllerContext.putReplicaState(replica, OnlineReplica)
         }
       case OfflineReplica =>
         validReplicas.foreach { replica =>
-          controllerBrokerRequestBatch.addStopReplicaRequestForBrokers(Seq(replicaId), replica.topicPartition,
-            deletePartition = false, (_, _) => ())
+          controllerBrokerRequestBatch.addStopReplicaRequestForBrokers(Seq(replicaId), replica.topicPartition, deletePartition = false)
         }
         val (replicasWithLeadershipInfo, replicasWithoutLeadershipInfo) = validReplicas.partition { replica =>
           controllerContext.partitionLeadershipInfo.contains(replica.topicPartition)
         }
         val updatedLeaderIsrAndControllerEpochs = removeReplicasFromIsr(replicaId, replicasWithLeadershipInfo.map(_.topicPartition))
         updatedLeaderIsrAndControllerEpochs.foreach { case (partition, leaderIsrAndControllerEpoch) =>
-          if (!topicDeletionManager.isTopicQueuedUpForDeletion(partition.topic)) {
+          if (!controllerContext.isTopicQueuedUpForDeletion(partition.topic)) {
             val recipients = controllerContext.partitionReplicaAssignment(partition).filterNot(_ == replicaId)
             controllerBrokerRequestBatch.addLeaderAndIsrRequestForBrokers(recipients,
               partition,
@@ -219,39 +223,43 @@ class ReplicaStateMachine(config: KafkaConfig,
               controllerContext.partitionReplicaAssignment(partition), isNew = false)
           }
           val replica = PartitionAndReplica(partition, replicaId)
-          logSuccessfulTransition(replicaId, partition, replicaState(replica), OfflineReplica)
-          replicaState.put(replica, OfflineReplica)
+          val currentState = controllerContext.replicaState(replica)
+          logSuccessfulTransition(replicaId, partition, currentState, OfflineReplica)
+          controllerContext.putReplicaState(replica, OfflineReplica)
         }
 
         replicasWithoutLeadershipInfo.foreach { replica =>
-          logSuccessfulTransition(replicaId, replica.topicPartition, replicaState(replica), OfflineReplica)
-          replicaState.put(replica, OfflineReplica)
+          val currentState = controllerContext.replicaState(replica)
+          logSuccessfulTransition(replicaId, replica.topicPartition, currentState, OfflineReplica)
+          controllerContext.putReplicaState(replica, OfflineReplica)
         }
       case ReplicaDeletionStarted =>
         validReplicas.foreach { replica =>
-          logSuccessfulTransition(replicaId, replica.topicPartition, replicaState(replica), ReplicaDeletionStarted)
-          replicaState.put(replica, ReplicaDeletionStarted)
-          controllerBrokerRequestBatch.addStopReplicaRequestForBrokers(Seq(replicaId),
-            replica.topicPartition,
-            deletePartition = true,
-            callbacks.stopReplicaResponseCallback)
+          val currentState = controllerContext.replicaState(replica)
+          logSuccessfulTransition(replicaId, replica.topicPartition, currentState, ReplicaDeletionStarted)
+          controllerContext.putReplicaState(replica, ReplicaDeletionStarted)
+          val topicDeletionInProgress = controllerContext.isTopicDeletionInProgress(replica.topicPartition.topic)
+          controllerBrokerRequestBatch.addStopReplicaRequestForBrokers(Seq(replicaId), replica.topicPartition, deletePartition = true)
         }
       case ReplicaDeletionIneligible =>
         validReplicas.foreach { replica =>
-          logSuccessfulTransition(replicaId, replica.topicPartition, replicaState(replica), ReplicaDeletionIneligible)
-          replicaState.put(replica, ReplicaDeletionIneligible)
+          val currentState = controllerContext.replicaState(replica)
+          logSuccessfulTransition(replicaId, replica.topicPartition, currentState, ReplicaDeletionIneligible)
+          controllerContext.putReplicaState(replica, ReplicaDeletionIneligible)
         }
       case ReplicaDeletionSuccessful =>
         validReplicas.foreach { replica =>
-          logSuccessfulTransition(replicaId, replica.topicPartition, replicaState(replica), ReplicaDeletionSuccessful)
-          replicaState.put(replica, ReplicaDeletionSuccessful)
+          val currentState = controllerContext.replicaState(replica)
+          logSuccessfulTransition(replicaId, replica.topicPartition, currentState, ReplicaDeletionSuccessful)
+          controllerContext.putReplicaState(replica, ReplicaDeletionSuccessful)
         }
       case NonExistentReplica =>
         validReplicas.foreach { replica =>
+          val currentState = controllerContext.replicaState(replica)
           val currentAssignedReplicas = controllerContext.partitionReplicaAssignment(replica.topicPartition)
           controllerContext.updatePartitionReplicaAssignment(replica.topicPartition, currentAssignedReplicas.filterNot(_ == replica.replica))
-          logSuccessfulTransition(replicaId, replica.topicPartition, replicaState(replica), NonExistentReplica)
-          replicaState.remove(replica)
+          logSuccessfulTransition(replicaId, replica.topicPartition, currentState, NonExistentReplica)
+          controllerContext.removeReplicaState(replica)
         }
     }
   }
@@ -273,7 +281,8 @@ class ReplicaStateMachine(config: KafkaConfig,
       remaining = removalsToRetry
       failedRemovals.foreach { case (partition, e) =>
         val replica = PartitionAndReplica(partition, replicaId)
-        logFailedStateChange(replica, replicaState(replica), OfflineReplica, e)
+        val currentState = controllerContext.replicaState(replica)
+        logFailedStateChange(replica, currentState, OfflineReplica, e)
       }
     }
     results
@@ -305,7 +314,7 @@ class ReplicaStateMachine(config: KafkaConfig,
     val UpdateLeaderAndIsrResult(successfulUpdates, updatesToRetry, failedUpdates) = zkClient.updateLeaderAndIsr(
       adjustedLeaderAndIsrs, controllerContext.epoch, controllerContext.epochZkVersion)
     val exceptionsForPartitionsWithNoLeaderAndIsrInZk = partitionsWithNoLeaderAndIsrInZk.flatMap { partition =>
-      if (!topicDeletionManager.isTopicQueuedUpForDeletion(partition.topic)) {
+      if (!controllerContext.isTopicQueuedUpForDeletion(partition.topic)) {
         val exception = new StateChangeFailedException(s"Failed to change state of replica $replicaId for partition $partition since the leader and isr path in zookeeper is empty")
         Option(partition -> exception)
       } else None
@@ -367,32 +376,13 @@ class ReplicaStateMachine(config: KafkaConfig,
     (leaderAndIsrs.toMap, partitionsWithNoLeaderAndIsrInZk, failed.toMap)
   }
 
-  def isAtLeastOneReplicaInDeletionStartedState(topic: String): Boolean = {
-    controllerContext.replicasForTopic(topic).exists(replica => replicaState(replica) == ReplicaDeletionStarted)
-  }
-
-  def replicasInState(topic: String, state: ReplicaState): Set[PartitionAndReplica] = {
-    replicaState.filter { case (replica, s) => replica.topic.equals(topic) && s == state }.keySet.toSet
-  }
-
-  def areAllReplicasForTopicDeleted(topic: String): Boolean = {
-    controllerContext.replicasForTopic(topic).forall(replica => replicaState(replica) == ReplicaDeletionSuccessful)
-  }
-
-  def isAnyReplicaInState(topic: String, state: ReplicaState): Boolean = {
-    replicaState.exists { case (replica, s) => replica.topic.equals(topic) && s == state}
-  }
-
-  private def isValidTransition(replica: PartitionAndReplica, targetState: ReplicaState) =
-    targetState.validPreviousStates.contains(replicaState(replica))
-
   private def logSuccessfulTransition(replicaId: Int, partition: TopicPartition, currState: ReplicaState, targetState: ReplicaState): Unit = {
     stateChangeLogger.withControllerEpoch(controllerContext.epoch)
       .trace(s"Changed state of replica $replicaId for partition $partition from $currState to $targetState")
   }
 
   private def logInvalidTransition(replica: PartitionAndReplica, targetState: ReplicaState): Unit = {
-    val currState = replicaState(replica)
+    val currState = controllerContext.replicaState(replica)
     val e = new IllegalStateException(s"Replica $replica should be in the ${targetState.validPreviousStates.mkString(",")} " +
       s"states before moving to $targetState state. Instead it is in $currState state")
     logFailedStateChange(replica, currState, targetState, e)
@@ -437,7 +427,7 @@ case object ReplicaDeletionSuccessful extends ReplicaState {
 
 case object ReplicaDeletionIneligible extends ReplicaState {
   val state: Byte = 6
-  val validPreviousStates: Set[ReplicaState] = Set(ReplicaDeletionStarted)
+  val validPreviousStates: Set[ReplicaState] = Set(OfflineReplica, ReplicaDeletionStarted)
 }
 
 case object NonExistentReplica extends ReplicaState {
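
The transition rules documented above boil down to a per-state set of valid previous states. A
hedged, stand-alone sketch of that check, using simplified state names and including the relaxed
rule for deletion-ineligible replicas introduced by this patch:

    object ReplicaTransitionSketch {
      sealed trait State { def validPreviousStates: Set[State] }
      case object NonExistent extends State { def validPreviousStates: Set[State] = Set(DeletionSuccessful) }
      case object New extends State { def validPreviousStates: Set[State] = Set(NonExistent) }
      case object Online extends State { def validPreviousStates: Set[State] = Set(New, Online, Offline) }
      case object Offline extends State { def validPreviousStates: Set[State] = Set(New, Online) }
      case object DeletionStarted extends State { def validPreviousStates: Set[State] = Set(Offline) }
      case object DeletionSuccessful extends State { def validPreviousStates: Set[State] = Set(DeletionStarted) }
      // After this patch, a replica may become deletion-ineligible directly from Offline as well.
      case object DeletionIneligible extends State { def validPreviousStates: Set[State] = Set(Offline, DeletionStarted) }

      // A transition is accepted only if the current state is a valid previous state of the target.
      def isValidTransition(current: State, target: State): Boolean =
        target.validPreviousStates.contains(current)

      def main(args: Array[String]): Unit = {
        println(isValidTransition(Offline, DeletionIneligible)) // true (newly allowed)
        println(isValidTransition(Online, DeletionStarted))     // false: the replica must go offline first
      }
    }
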
diff --git a/core/src/main/scala/kafka/controller/TopicDeletionManager.scala b/core/src/main/scala/kafka/controller/TopicDeletionManager.scala
index 1ef79be..0f56e3a 100755
--- a/core/src/main/scala/kafka/controller/TopicDeletionManager.scala
+++ b/core/src/main/scala/kafka/controller/TopicDeletionManager.scala
@@ -16,11 +16,39 @@
  */
 package kafka.controller
 
+import kafka.server.KafkaConfig
 import kafka.utils.Logging
 import kafka.zk.KafkaZkClient
 import org.apache.kafka.common.TopicPartition
 
-import scala.collection.{Set, mutable}
+import scala.collection.Set
+
+trait DeletionClient {
+  def deleteTopic(topic: String, epochZkVersion: Int): Unit
+  def deleteTopicDeletions(topics: Seq[String], epochZkVersion: Int): Unit
+  def mutePartitionModifications(topic: String): Unit
+  def sendMetadataUpdate(partitions: Set[TopicPartition]): Unit
+}
+
+class ControllerDeletionClient(controller: KafkaController, zkClient: KafkaZkClient) extends DeletionClient {
+  override def deleteTopic(topic: String, epochZkVersion: Int): Unit = {
+    zkClient.deleteTopicZNode(topic, epochZkVersion)
+    zkClient.deleteTopicConfigs(Seq(topic), epochZkVersion)
+    zkClient.deleteTopicDeletions(Seq(topic), epochZkVersion)
+  }
+
+  override def deleteTopicDeletions(topics: Seq[String], epochZkVersion: Int): Unit = {
+    zkClient.deleteTopicDeletions(topics, epochZkVersion)
+  }
+
+  override def mutePartitionModifications(topic: String): Unit = {
+    controller.unregisterPartitionModificationsHandlers(Seq(topic))
+  }
+
+  override def sendMetadataUpdate(partitions: Set[TopicPartition]): Unit = {
+    controller.sendUpdateMetadataRequest(controller.controllerContext.liveOrShuttingDownBrokerIds.toSeq, partitions)
+  }
+}
 
 /**
  * This manages the state machine for topic deletion.
@@ -55,44 +83,22 @@ import scala.collection.{Set, mutable}
  *    it marks the topic for deletion retry.
  * @param controller
  */
-class TopicDeletionManager(controller: KafkaController,
-                           eventManager: ControllerEventManager,
-                           zkClient: KafkaZkClient) extends Logging {
-  this.logIdent = s"[Topic Deletion Manager ${controller.config.brokerId}], "
-  val controllerContext = controller.controllerContext
-  val isDeleteTopicEnabled = controller.config.deleteTopicEnable
-  val topicsToBeDeleted = mutable.Set.empty[String]
-  /** The following topicsWithDeletionStarted variable is used to properly update the offlinePartitionCount metric.
-    * When a topic is going through deletion, we don't want to keep track of its partition state
-    * changes in the offlinePartitionCount metric, see the PartitionStateMachine#updateControllerMetrics
-    * for detailed logic. This goal means if some partitions of a topic are already
-    * in OfflinePartition state when deletion starts, we need to change the corresponding partition
-    * states to NonExistentPartition first before starting the deletion.
-    *
-    * However we can NOT change partition states to NonExistentPartition at the time of enqueuing topics
-    * for deletion. The reason is that when a topic is enqueued for deletion, it may be ineligible for
-    * deletion due to ongoing partition reassignments. Hence there might be a delay between enqueuing
-    * a topic for deletion and the actual start of deletion. In this delayed interval, partitions may still
-    * transition to or out of the OfflinePartition state.
-    *
-    * Hence we decide to change partition states to NonExistentPartition only when the actual deletion have started.
-    * For topics whose deletion have actually started, we keep track of them in the following topicsWithDeletionStarted
-    * variable. And once a topic is in the topicsWithDeletionStarted set, we are sure there will no longer
-    * be partition reassignments to any of its partitions, and only then it's safe to move its partitions to
-    * NonExistentPartition state. Once a topic is in the topicsWithDeletionStarted set, we will stop monitoring
-    * its partition state changes in the offlinePartitionCount metric
-    */
-  val topicsWithDeletionStarted = mutable.Set.empty[String]
-  val topicsIneligibleForDeletion = mutable.Set.empty[String]
+class TopicDeletionManager(config: KafkaConfig,
+                           controllerContext: ControllerContext,
+                           replicaStateMachine: ReplicaStateMachine,
+                           partitionStateMachine: PartitionStateMachine,
+                           client: DeletionClient) extends Logging {
+  this.logIdent = s"[Topic Deletion Manager ${config.brokerId}] "
+  val isDeleteTopicEnabled: Boolean = config.deleteTopicEnable
 
   def init(initialTopicsToBeDeleted: Set[String], initialTopicsIneligibleForDeletion: Set[String]): Unit = {
     if (isDeleteTopicEnabled) {
-      topicsToBeDeleted ++= initialTopicsToBeDeleted
-      topicsIneligibleForDeletion ++= initialTopicsIneligibleForDeletion & topicsToBeDeleted
+      controllerContext.topicsToBeDeleted ++= initialTopicsToBeDeleted
+      controllerContext.topicsIneligibleForDeletion ++= initialTopicsIneligibleForDeletion & controllerContext.topicsToBeDeleted
     } else {
       // if delete topic is disabled clean the topic entries under /admin/delete_topics
       info(s"Removing $initialTopicsToBeDeleted since delete topic is disabled")
-      zkClient.deleteTopicDeletions(initialTopicsToBeDeleted.toSeq, controllerContext.epochZkVersion)
+      client.deleteTopicDeletions(initialTopicsToBeDeleted.toSeq, controllerContext.epochZkVersion)
     }
   }
 
@@ -103,17 +109,6 @@ class TopicDeletionManager(controller: KafkaController,
   }
 
   /**
-   * Invoked when the current controller resigns. At this time, all state for topic deletion should be cleared.
-   */
-  def reset() {
-    if (isDeleteTopicEnabled) {
-      topicsToBeDeleted.clear()
-      topicsWithDeletionStarted.clear()
-      topicsIneligibleForDeletion.clear()
-    }
-  }
-
-  /**
    * Invoked by the child change listener on /admin/delete_topics to queue up the topics for deletion. The topic gets added
    * to the topicsToBeDeleted list and only gets removed from the list when the topic deletion has completed successfully
    * i.e. all replicas of all partitions of that topic are deleted successfully.
@@ -121,7 +116,7 @@ class TopicDeletionManager(controller: KafkaController,
    */
   def enqueueTopicsForDeletion(topics: Set[String]) {
     if (isDeleteTopicEnabled) {
-      topicsToBeDeleted ++= topics
+      controllerContext.topicsToBeDeleted ++= topics
       resumeDeletions()
     }
   }
@@ -134,9 +129,9 @@ class TopicDeletionManager(controller: KafkaController,
    */
   def resumeDeletionForTopics(topics: Set[String] = Set.empty) {
     if (isDeleteTopicEnabled) {
-      val topicsToResumeDeletion = topics & topicsToBeDeleted
+      val topicsToResumeDeletion = topics & controllerContext.topicsToBeDeleted
       if (topicsToResumeDeletion.nonEmpty) {
-        topicsIneligibleForDeletion --= topicsToResumeDeletion
+        controllerContext.topicsIneligibleForDeletion --= topicsToResumeDeletion
         resumeDeletions()
       }
     }
@@ -155,7 +150,7 @@ class TopicDeletionManager(controller: KafkaController,
       if (replicasThatFailedToDelete.nonEmpty) {
         val topics = replicasThatFailedToDelete.map(_.topic)
         debug(s"Deletion failed for replicas ${replicasThatFailedToDelete.mkString(",")}. Halting deletion for topics $topics")
-        controller.replicaStateMachine.handleStateChanges(replicasThatFailedToDelete.toSeq, ReplicaDeletionIneligible)
+        replicaStateMachine.handleStateChanges(replicasThatFailedToDelete.toSeq, ReplicaDeletionIneligible)
         markTopicIneligibleForDeletion(topics)
         resumeDeletions()
       }
@@ -168,10 +163,10 @@ class TopicDeletionManager(controller: KafkaController,
    * 2. partition reassignment in progress for some partitions of the topic
    * @param topics Topics that should be marked ineligible for deletion. No op if the topic was not previously queued up for deletion
    */
-  def markTopicIneligibleForDeletion(topics: Set[String]) {
+  def markTopicIneligibleForDeletion(topics: Set[String]): Unit = {
     if (isDeleteTopicEnabled) {
-      val newTopicsToHaltDeletion = topicsToBeDeleted & topics
-      topicsIneligibleForDeletion ++= newTopicsToHaltDeletion
+      val newTopicsToHaltDeletion = controllerContext.topicsToBeDeleted & topics
+      controllerContext.topicsIneligibleForDeletion ++= newTopicsToHaltDeletion
       if (newTopicsToHaltDeletion.nonEmpty)
         info(s"Halted deletion of topics ${newTopicsToHaltDeletion.mkString(",")}")
     }
@@ -179,28 +174,21 @@ class TopicDeletionManager(controller: KafkaController,
 
   private def isTopicIneligibleForDeletion(topic: String): Boolean = {
     if (isDeleteTopicEnabled) {
-      topicsIneligibleForDeletion.contains(topic)
+      controllerContext.topicsIneligibleForDeletion.contains(topic)
     } else
       true
   }
 
   private def isTopicDeletionInProgress(topic: String): Boolean = {
     if (isDeleteTopicEnabled) {
-      controller.replicaStateMachine.isAtLeastOneReplicaInDeletionStartedState(topic)
-    } else
-      false
-  }
-
-  def isTopicWithDeletionStarted(topic: String) = {
-    if (isDeleteTopicEnabled) {
-      topicsWithDeletionStarted.contains(topic)
+      controllerContext.isAnyReplicaInState(topic, ReplicaDeletionStarted)
     } else
       false
   }
 
   def isTopicQueuedUpForDeletion(topic: String): Boolean = {
     if (isDeleteTopicEnabled) {
-      topicsToBeDeleted.contains(topic)
+      controllerContext.isTopicQueuedUpForDeletion(topic)
     } else
       false
   }
@@ -214,7 +202,7 @@ class TopicDeletionManager(controller: KafkaController,
   def completeReplicaDeletion(replicas: Set[PartitionAndReplica]) {
     val successfullyDeletedReplicas = replicas.filter(r => isTopicQueuedUpForDeletion(r.topic))
     debug(s"Deletion successfully completed for replicas ${successfullyDeletedReplicas.mkString(",")}")
-    controller.replicaStateMachine.handleStateChanges(successfullyDeletedReplicas.toSeq, ReplicaDeletionSuccessful)
+    replicaStateMachine.handleStateChanges(successfullyDeletedReplicas.toSeq, ReplicaDeletionSuccessful)
     resumeDeletions()
   }
 
@@ -227,7 +215,9 @@ class TopicDeletionManager(controller: KafkaController,
    * @return Whether or not deletion can be retried for the topic
    */
   private def isTopicEligibleForDeletion(topic: String): Boolean = {
-    topicsToBeDeleted.contains(topic) && (!isTopicDeletionInProgress(topic) && !isTopicIneligibleForDeletion(topic))
+    controllerContext.isTopicQueuedUpForDeletion(topic) &&
+      !isTopicDeletionInProgress(topic) &&
+      !isTopicIneligibleForDeletion(topic)
   }
 
   /**
@@ -235,25 +225,23 @@ class TopicDeletionManager(controller: KafkaController,
    * To ensure a successful retry, reset states for respective replicas from ReplicaDeletionIneligible to OfflineReplica state
    * @param topic Topic for which deletion should be retried
    */
-  private def markTopicForDeletionRetry(topic: String) {
+  private def retryDeletionForIneligibleReplicas(topic: String): Unit = {
     // reset replica states from ReplicaDeletionIneligible to OfflineReplica
-    val failedReplicas = controller.replicaStateMachine.replicasInState(topic, ReplicaDeletionIneligible)
+    val failedReplicas = controllerContext.replicasInState(topic, ReplicaDeletionIneligible)
     info(s"Retrying delete topic for topic $topic since replicas ${failedReplicas.mkString(",")} were not successfully deleted")
-    controller.replicaStateMachine.handleStateChanges(failedReplicas.toSeq, OfflineReplica)
+    replicaStateMachine.handleStateChanges(failedReplicas.toSeq, OfflineReplica)
   }
 
   private def completeDeleteTopic(topic: String) {
     // deregister partition change listener on the deleted topic. This is to prevent the partition change listener
     // firing before the new topic listener when a deleted topic gets auto created
-    controller.unregisterPartitionModificationsHandlers(Seq(topic))
-    val replicasForDeletedTopic = controller.replicaStateMachine.replicasInState(topic, ReplicaDeletionSuccessful)
+    client.mutePartitionModifications(topic)
+    val replicasForDeletedTopic = controllerContext.replicasInState(topic, ReplicaDeletionSuccessful)
     // controller will remove this replica from the state machine as well as its partition assignment cache
-    controller.replicaStateMachine.handleStateChanges(replicasForDeletedTopic.toSeq, NonExistentReplica)
-    topicsToBeDeleted -= topic
-    topicsWithDeletionStarted -= topic
-    zkClient.deleteTopicZNode(topic, controllerContext.epochZkVersion)
-    zkClient.deleteTopicConfigs(Seq(topic), controllerContext.epochZkVersion)
-    zkClient.deleteTopicDeletions(Seq(topic), controllerContext.epochZkVersion)
+    replicaStateMachine.handleStateChanges(replicasForDeletedTopic.toSeq, NonExistentReplica)
+    controllerContext.topicsToBeDeleted -= topic
+    controllerContext.topicsWithDeletionStarted -= topic
+    client.deleteTopic(topic, controllerContext.epochZkVersion)
     controllerContext.removeTopic(topic)
   }
 
@@ -268,17 +256,17 @@ class TopicDeletionManager(controller: KafkaController,
     info(s"Topic deletion callback for ${topics.mkString(",")}")
     // send update metadata so that brokers stop serving data for topics to be deleted
     val partitions = topics.flatMap(controllerContext.partitionsForTopic)
-    val unseenTopicsForDeletion = topics -- topicsWithDeletionStarted
+    val unseenTopicsForDeletion = topics -- controllerContext.topicsWithDeletionStarted
     if (unseenTopicsForDeletion.nonEmpty) {
       val unseenPartitionsForDeletion = unseenTopicsForDeletion.flatMap(controllerContext.partitionsForTopic)
-      controller.partitionStateMachine.handleStateChanges(unseenPartitionsForDeletion.toSeq, OfflinePartition)
-      controller.partitionStateMachine.handleStateChanges(unseenPartitionsForDeletion.toSeq, NonExistentPartition)
-      // adding of unseenTopicsForDeletion to topicsBeingDeleted must be done after the partition state changes
-      // to make sure the offlinePartitionCount metric is properly updated
-      topicsWithDeletionStarted ++= unseenTopicsForDeletion
+      partitionStateMachine.handleStateChanges(unseenPartitionsForDeletion.toSeq, OfflinePartition)
+      partitionStateMachine.handleStateChanges(unseenPartitionsForDeletion.toSeq, NonExistentPartition)
+      // adding of unseenTopicsForDeletion to topics with deletion started must be done after the partition
+      // state changes to make sure the offlinePartitionCount metric is properly updated
+      controllerContext.beginTopicDeletion(unseenTopicsForDeletion)
     }
 
-    controller.sendUpdateMetadataRequest(controllerContext.liveOrShuttingDownBrokerIds.toSeq, partitions)
+    client.sendMetadataUpdate(partitions)
     topics.foreach { topic =>
       onPartitionDeletion(controllerContext.partitionsForTopic(topic))
     }
@@ -298,22 +286,20 @@ class TopicDeletionManager(controller: KafkaController,
    * 1. Move all dead replicas directly to ReplicaDeletionIneligible state. Also mark the respective topics ineligible
    *    for deletion if some replicas are dead since it won't complete successfully anyway
    * 2. Move all alive replicas to ReplicaDeletionStarted state so they can be deleted successfully
-   *@param replicasForTopicsToBeDeleted
+   * @param replicasForTopicsToBeDeleted
    */
   private def startReplicaDeletion(replicasForTopicsToBeDeleted: Set[PartitionAndReplica]) {
     replicasForTopicsToBeDeleted.groupBy(_.topic).keys.foreach { topic =>
       val aliveReplicasForTopic = controllerContext.allLiveReplicas().filter(p => p.topic == topic)
       val deadReplicasForTopic = replicasForTopicsToBeDeleted -- aliveReplicasForTopic
-      val successfullyDeletedReplicas = controller.replicaStateMachine.replicasInState(topic, ReplicaDeletionSuccessful)
+      val successfullyDeletedReplicas = controllerContext.replicasInState(topic, ReplicaDeletionSuccessful)
       val replicasForDeletionRetry = aliveReplicasForTopic -- successfullyDeletedReplicas
       // move dead replicas directly to failed state
-      controller.replicaStateMachine.handleStateChanges(deadReplicasForTopic.toSeq, ReplicaDeletionIneligible, new Callbacks())
+      replicaStateMachine.handleStateChanges(deadReplicasForTopic.toSeq, ReplicaDeletionIneligible)
       // send stop replica to all followers that are not in the OfflineReplica state so they stop sending fetch requests to the leader
-      controller.replicaStateMachine.handleStateChanges(replicasForDeletionRetry.toSeq, OfflineReplica, new Callbacks())
+      replicaStateMachine.handleStateChanges(replicasForDeletionRetry.toSeq, OfflineReplica)
       debug(s"Deletion started for replicas ${replicasForDeletionRetry.mkString(",")}")
-      controller.replicaStateMachine.handleStateChanges(replicasForDeletionRetry.toSeq, ReplicaDeletionStarted,
-        new Callbacks(stopReplicaResponseCallback = (stopReplicaResponseObj, replicaId) =>
-          eventManager.put(controller.TopicDeletionStopReplicaResponseReceived(stopReplicaResponseObj, replicaId))))
+      replicaStateMachine.handleStateChanges(replicasForDeletionRetry.toSeq, ReplicaDeletionStarted)
       if (deadReplicasForTopic.nonEmpty) {
         debug(s"Dead Replicas (${deadReplicasForTopic.mkString(",")}) found for topic $topic")
         markTopicIneligibleForDeletion(Set(topic))
@@ -339,34 +325,31 @@ class TopicDeletionManager(controller: KafkaController,
   }
 
   private def resumeDeletions(): Unit = {
-    val topicsQueuedForDeletion = Set.empty[String] ++ topicsToBeDeleted
-
+    val topicsQueuedForDeletion = Set.empty[String] ++ controllerContext.topicsToBeDeleted
     if (topicsQueuedForDeletion.nonEmpty)
       info(s"Handling deletion for topics ${topicsQueuedForDeletion.mkString(",")}")
 
     topicsQueuedForDeletion.foreach { topic =>
       // if all replicas are marked as deleted successfully, then topic deletion is done
-      if (controller.replicaStateMachine.areAllReplicasForTopicDeleted(topic)) {
+      if (controllerContext.areAllReplicasInState(topic, ReplicaDeletionSuccessful)) {
         // clear up all state for this topic from controller cache and zookeeper
         completeDeleteTopic(topic)
         info(s"Deletion of topic $topic successfully completed")
+      } else if (controllerContext.isAnyReplicaInState(topic, ReplicaDeletionStarted)) {
+        // ignore since topic deletion is in progress
+        val replicasInDeletionStartedState = controllerContext.replicasInState(topic, ReplicaDeletionStarted)
+        val replicaIds = replicasInDeletionStartedState.map(_.replica)
+        val partitions = replicasInDeletionStartedState.map(_.topicPartition)
+        info(s"Deletion for replicas ${replicaIds.mkString(",")} for partition ${partitions.mkString(",")} of topic $topic in progress")
       } else {
-        if (controller.replicaStateMachine.isAtLeastOneReplicaInDeletionStartedState(topic)) {
-          // ignore since topic deletion is in progress
-          val replicasInDeletionStartedState = controller.replicaStateMachine.replicasInState(topic, ReplicaDeletionStarted)
-          val replicaIds = replicasInDeletionStartedState.map(_.replica)
-          val partitions = replicasInDeletionStartedState.map(_.topicPartition)
-          info(s"Deletion for replicas ${replicaIds.mkString(",")} for partition ${partitions.mkString(",")} of topic $topic in progress")
-        } else {
-          // if you come here, then no replica is in TopicDeletionStarted and all replicas are not in
-          // TopicDeletionSuccessful. That means, that either given topic haven't initiated deletion
-          // or there is at least one failed replica (which means topic deletion should be retried).
-          if (controller.replicaStateMachine.isAnyReplicaInState(topic, ReplicaDeletionIneligible)) {
-            // mark topic for deletion retry
-            markTopicForDeletionRetry(topic)
-          }
+        // If we reach here, then no replica is in ReplicaDeletionStarted and not all replicas are in
+        // ReplicaDeletionSuccessful. That means the topic either has not initiated deletion yet
+        // or there is at least one failed replica (which means topic deletion should be retried).
+        if (controllerContext.isAnyReplicaInState(topic, ReplicaDeletionIneligible)) {
+          retryDeletionForIneligibleReplicas(topic)
         }
       }
+
       // Try delete topic if it is eligible for deletion.
       if (isTopicEligibleForDeletion(topic)) {
         info(s"Deletion of topic $topic (re)started")
diff --git a/core/src/test/scala/unit/kafka/admin/DeleteTopicTest.scala b/core/src/test/scala/unit/kafka/admin/DeleteTopicTest.scala
index 400920e..7a5b3e5 100644
--- a/core/src/test/scala/unit/kafka/admin/DeleteTopicTest.scala
+++ b/core/src/test/scala/unit/kafka/admin/DeleteTopicTest.scala
@@ -201,8 +201,8 @@ class DeleteTopicTest extends ZooKeeperTestHarness {
     val (controller, controllerId) = getController()
     val allReplicasForTopic = getAllReplicasFromAssignment(topic, expectedReplicaAssignment)
     TestUtils.waitUntilTrue(() => {
-      val replicasInDeletionSuccessful = controller.kafkaController.replicaStateMachine.replicasInState(topic, ReplicaDeletionSuccessful)
-      val offlineReplicas = controller.kafkaController.replicaStateMachine.replicasInState(topic, OfflineReplica)
+      val replicasInDeletionSuccessful = controller.kafkaController.controllerContext.replicasInState(topic, ReplicaDeletionSuccessful)
+      val offlineReplicas = controller.kafkaController.controllerContext.replicasInState(topic, OfflineReplica)
       allReplicasForTopic == (replicasInDeletionSuccessful union offlineReplicas)
     }, s"Not all replicas for topic $topic are in states of either ReplicaDeletionSuccessful or OfflineReplica")
 
diff --git a/core/src/test/scala/unit/kafka/controller/ControllerFailoverTest.scala b/core/src/test/scala/unit/kafka/controller/ControllerFailoverTest.scala
index 6cfa72c..283858c 100644
--- a/core/src/test/scala/unit/kafka/controller/ControllerFailoverTest.scala
+++ b/core/src/test/scala/unit/kafka/controller/ControllerFailoverTest.scala
@@ -62,7 +62,7 @@ class ControllerFailoverTest extends KafkaServerTestHarness with Logging {
     createTopic(topic, 1, 1)
     val topicPartition = new TopicPartition("topic1", 0)
     TestUtils.waitUntilTrue(() =>
-      initialController.partitionStateMachine.partitionsInState(OnlinePartition).contains(topicPartition),
+      initialController.controllerContext.partitionsInState(OnlinePartition).contains(topicPartition),
       s"Partition $topicPartition did not transition to online state")
 
     // Wait until we have verified that we have resigned
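
The test changes above query replica and partition states through ControllerContext rather than the
state machines. ControllerContext's query helpers are not part of these hunks, but judging from the
helpers removed from the state machines they are presumably simple filters over the context-owned
state maps, along these lines (a sketch over simplified stand-in types, not the actual
ControllerContext code):

    object StateQuerySketch {
      final case class PartitionAndReplica(topic: String, partition: Int, replica: Int)
      sealed trait ReplicaState
      case object OfflineReplica extends ReplicaState
      case object ReplicaDeletionSuccessful extends ReplicaState

      private val replicaStates = scala.collection.mutable.Map.empty[PartitionAndReplica, ReplicaState]

      def putReplicaState(replica: PartitionAndReplica, state: ReplicaState): Unit =
        replicaStates.put(replica, state)

      // Mirrors the replicasInState helper that used to live on ReplicaStateMachine.
      def replicasInState(topic: String, state: ReplicaState): Set[PartitionAndReplica] =
        replicaStates.filter { case (replica, s) => replica.topic == topic && s == state }.keySet.toSet

      // Mirrors isAnyReplicaInState, used by TopicDeletionManager to detect in-progress deletions.
      def isAnyReplicaInState(topic: String, state: ReplicaState): Boolean =
        replicaStates.exists { case (replica, s) => replica.topic == topic && s == state }
    }
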
diff --git a/core/src/test/scala/unit/kafka/controller/MockPartitionStateMachine.scala b/core/src/test/scala/unit/kafka/controller/MockPartitionStateMachine.scala
new file mode 100644
index 0000000..2578199
--- /dev/null
+++ b/core/src/test/scala/unit/kafka/controller/MockPartitionStateMachine.scala
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package kafka.controller
+
+import kafka.common.StateChangeFailedException
+import kafka.controller.Election._
+import org.apache.kafka.common.TopicPartition
+
+import scala.collection.mutable
+
+class MockPartitionStateMachine(controllerContext: ControllerContext,
+                                uncleanLeaderElectionEnabled: Boolean)
+  extends PartitionStateMachine(controllerContext) {
+
+  override def handleStateChanges(partitions: Seq[TopicPartition],
+                                  targetState: PartitionState,
+                                  leaderElectionStrategy: Option[PartitionLeaderElectionStrategy]): Map[TopicPartition, Throwable] = {
+    partitions.foreach(partition => controllerContext.putPartitionStateIfNotExists(partition, NonExistentPartition))
+    val (validPartitions, invalidPartitions) = controllerContext.checkValidPartitionStateChange(partitions, targetState)
+    if (invalidPartitions.nonEmpty) {
+      val currentStates = invalidPartitions.map(p => controllerContext.partitionStates.get(p))
+      throw new IllegalStateException(s"Invalid state transition to $targetState for partitions $currentStates")
+    }
+
+    if (targetState == OnlinePartition) {
+      val uninitializedPartitions = validPartitions.filter(partition => controllerContext.partitionState(partition) == NewPartition)
+      val partitionsToElectLeader = partitions.filter { partition =>
+        val currentState = controllerContext.partitionState(partition)
+        currentState == OfflinePartition || currentState == OnlinePartition
+      }
+
+      uninitializedPartitions.foreach { partition =>
+        controllerContext.putPartitionState(partition, targetState)
+      }
+
+      val failedElections = doLeaderElections(partitionsToElectLeader, leaderElectionStrategy.get)
+      val successfulElections = partitionsToElectLeader.filterNot(failedElections.keySet.contains)
+      successfulElections.foreach { partition =>
+        controllerContext.putPartitionState(partition, targetState)
+      }
+
+      failedElections
+    } else {
+      validPartitions.foreach { partition =>
+        controllerContext.putPartitionState(partition, targetState)
+      }
+      Map.empty
+    }
+  }
+
+  private def doLeaderElections(partitions: Seq[TopicPartition],
+                                leaderElectionStrategy: PartitionLeaderElectionStrategy): Map[TopicPartition, Throwable] = {
+    val failedElections = mutable.Map.empty[TopicPartition, Exception]
+    val leaderIsrAndControllerEpochPerPartition = partitions.map { partition =>
+      partition -> controllerContext.partitionLeadershipInfo(partition)
+    }
+
+    val (invalidPartitionsForElection, validPartitionsForElection) = leaderIsrAndControllerEpochPerPartition.partition { case (_, leaderIsrAndControllerEpoch) =>
+      leaderIsrAndControllerEpoch.controllerEpoch > controllerContext.epoch
+    }
+    invalidPartitionsForElection.foreach { case (partition, leaderIsrAndControllerEpoch) =>
+      val failMsg = s"aborted leader election for partition $partition since the LeaderAndIsr path was " +
+        s"already written by another controller. This probably means that the current controller went through " +
+        s"a soft failure and another controller was elected with epoch ${leaderIsrAndControllerEpoch.controllerEpoch}."
+      failedElections.put(partition, new StateChangeFailedException(failMsg))
+    }
+
+    val electionResults = leaderElectionStrategy match {
+      case OfflinePartitionLeaderElectionStrategy =>
+        val partitionsWithUncleanLeaderElectionState = validPartitionsForElection.map { case (partition, leaderIsrAndControllerEpoch) =>
+          (partition, Some(leaderIsrAndControllerEpoch), uncleanLeaderElectionEnabled)
+        }
+        leaderForOffline(controllerContext, partitionsWithUncleanLeaderElectionState)
+      case ReassignPartitionLeaderElectionStrategy =>
+        leaderForReassign(controllerContext, validPartitionsForElection)
+      case PreferredReplicaPartitionLeaderElectionStrategy =>
+        leaderForPreferredReplica(controllerContext, validPartitionsForElection)
+      case ControlledShutdownPartitionLeaderElectionStrategy =>
+        leaderForControlledShutdown(controllerContext, validPartitionsForElection)
+    }
+
+    for (electionResult <- electionResults) {
+      val partition = electionResult.topicPartition
+      electionResult.leaderAndIsr match {
+        case None =>
+          val failMsg = s"Failed to elect leader for partition $partition under strategy $leaderElectionStrategy"
+          failedElections.put(partition, new StateChangeFailedException(failMsg))
+        case Some(leaderAndIsr) =>
+          val leaderIsrAndControllerEpoch = LeaderIsrAndControllerEpoch(leaderAndIsr, controllerContext.epoch)
+          controllerContext.partitionLeadershipInfo.put(partition, leaderIsrAndControllerEpoch)
+      }
+    }
+    failedElections.toMap
+  }
+
+}
diff --git a/core/src/test/scala/unit/kafka/controller/MockReplicaStateMachine.scala b/core/src/test/scala/unit/kafka/controller/MockReplicaStateMachine.scala
new file mode 100644
index 0000000..248a5de
--- /dev/null
+++ b/core/src/test/scala/unit/kafka/controller/MockReplicaStateMachine.scala
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package kafka.controller
+
+class MockReplicaStateMachine(controllerContext: ControllerContext) extends ReplicaStateMachine(controllerContext) {
+
+  override def handleStateChanges(replicas: Seq[PartitionAndReplica], targetState: ReplicaState): Unit = {
+    replicas.foreach(replica => controllerContext.putReplicaStateIfNotExists(replica, NonExistentReplica))
+    val (validReplicas, invalidReplicas) = controllerContext.checkValidReplicaStateChange(replicas, targetState)
+    if (invalidReplicas.nonEmpty) {
+      val currentStates = invalidReplicas.map(replica => replica -> controllerContext.replicaStates.get(replica)).toMap
+      throw new IllegalStateException(s"Invalid state transition to $targetState for replicas $currentStates")
+    }
+    validReplicas.foreach { replica =>
+      if (targetState == NonExistentReplica)
+        controllerContext.removeReplicaState(replica)
+      else
+        controllerContext.putReplicaState(replica, targetState)
+    }
+  }
+
+}
diff --git a/core/src/test/scala/unit/kafka/controller/PartitionStateMachineTest.scala b/core/src/test/scala/unit/kafka/controller/PartitionStateMachineTest.scala
index d711ae0..ba90231 100644
--- a/core/src/test/scala/unit/kafka/controller/PartitionStateMachineTest.scala
+++ b/core/src/test/scala/unit/kafka/controller/PartitionStateMachineTest.scala
@@ -29,16 +29,13 @@ import org.apache.zookeeper.data.Stat
 import org.easymock.EasyMock
 import org.junit.Assert._
 import org.junit.{Before, Test}
+import org.mockito.Mockito
 import org.scalatest.junit.JUnitSuite
 
-import scala.collection.mutable
-
 class PartitionStateMachineTest extends JUnitSuite {
   private var controllerContext: ControllerContext = null
   private var mockZkClient: KafkaZkClient = null
   private var mockControllerBrokerRequestBatch: ControllerBrokerRequestBatch = null
-  private var mockTopicDeletionManager: TopicDeletionManager = null
-  private var partitionState: mutable.Map[TopicPartition, PartitionState] = null
   private var partitionStateMachine: PartitionStateMachine = null
 
   private val brokerId = 5
@@ -53,11 +50,12 @@ class PartitionStateMachineTest extends JUnitSuite {
     controllerContext.epoch = controllerEpoch
     mockZkClient = EasyMock.createMock(classOf[KafkaZkClient])
     mockControllerBrokerRequestBatch = EasyMock.createMock(classOf[ControllerBrokerRequestBatch])
-    mockTopicDeletionManager = EasyMock.createMock(classOf[TopicDeletionManager])
-    partitionState = mutable.Map.empty[TopicPartition, PartitionState]
-    partitionStateMachine = new PartitionStateMachine(config, new StateChangeLogger(brokerId, true, None), controllerContext,
-      mockZkClient, partitionState, mockControllerBrokerRequestBatch)
-    partitionStateMachine.setTopicDeletionManager(mockTopicDeletionManager)
+    partitionStateMachine = new ZkPartitionStateMachine(config, new StateChangeLogger(brokerId, true, None), controllerContext,
+      mockZkClient, mockControllerBrokerRequestBatch)
+  }
+
+  private def partitionState(partition: TopicPartition): PartitionState = {
+    controllerContext.partitionState(partition)
   }
 
   @Test
@@ -82,7 +80,7 @@ class PartitionStateMachineTest extends JUnitSuite {
   def testNewPartitionToOnlinePartitionTransition(): Unit = {
     controllerContext.setLiveBrokerAndEpochs(Map(TestUtils.createBrokerAndEpoch(brokerId, "host", 0)))
     controllerContext.updatePartitionReplicaAssignment(partition, Seq(brokerId))
-    partitionState.put(partition, NewPartition)
+    controllerContext.putPartitionState(partition, NewPartition)
     val leaderIsrAndControllerEpoch = LeaderIsrAndControllerEpoch(LeaderAndIsr(brokerId, List(brokerId)), controllerEpoch)
     EasyMock.expect(mockControllerBrokerRequestBatch.newBatch())
     EasyMock.expect(mockZkClient.createTopicPartitionStatesRaw(Map(partition -> leaderIsrAndControllerEpoch), controllerContext.epochZkVersion))
@@ -100,7 +98,7 @@ class PartitionStateMachineTest extends JUnitSuite {
   def testNewPartitionToOnlinePartitionTransitionZkUtilsExceptionFromCreateStates(): Unit = {
     controllerContext.setLiveBrokerAndEpochs(Map(TestUtils.createBrokerAndEpoch(brokerId, "host", 0)))
     controllerContext.updatePartitionReplicaAssignment(partition, Seq(brokerId))
-    partitionState.put(partition, NewPartition)
+    controllerContext.putPartitionState(partition, NewPartition)
     val leaderIsrAndControllerEpoch = LeaderIsrAndControllerEpoch(LeaderAndIsr(brokerId, List(brokerId)), controllerEpoch)
     EasyMock.expect(mockControllerBrokerRequestBatch.newBatch())
     EasyMock.expect(mockZkClient.createTopicPartitionStatesRaw(Map(partition -> leaderIsrAndControllerEpoch), controllerContext.epochZkVersion))
@@ -116,7 +114,7 @@ class PartitionStateMachineTest extends JUnitSuite {
   def testNewPartitionToOnlinePartitionTransitionErrorCodeFromCreateStates(): Unit = {
     controllerContext.setLiveBrokerAndEpochs(Map(TestUtils.createBrokerAndEpoch(brokerId, "host", 0)))
     controllerContext.updatePartitionReplicaAssignment(partition, Seq(brokerId))
-    partitionState.put(partition, NewPartition)
+    controllerContext.putPartitionState(partition, NewPartition)
     val leaderIsrAndControllerEpoch = LeaderIsrAndControllerEpoch(LeaderAndIsr(brokerId, List(brokerId)), controllerEpoch)
     EasyMock.expect(mockControllerBrokerRequestBatch.newBatch())
     EasyMock.expect(mockZkClient.createTopicPartitionStatesRaw(Map(partition -> leaderIsrAndControllerEpoch), controllerContext.epochZkVersion))
@@ -130,14 +128,14 @@ class PartitionStateMachineTest extends JUnitSuite {
 
   @Test
   def testNewPartitionToOfflinePartitionTransition(): Unit = {
-    partitionState.put(partition, NewPartition)
+    controllerContext.putPartitionState(partition, NewPartition)
     partitionStateMachine.handleStateChanges(partitions, OfflinePartition)
     assertEquals(OfflinePartition, partitionState(partition))
   }
 
   @Test
   def testInvalidNewPartitionToNonexistentPartitionTransition(): Unit = {
-    partitionState.put(partition, NewPartition)
+    controllerContext.putPartitionState(partition, NewPartition)
     partitionStateMachine.handleStateChanges(partitions, NonExistentPartition)
     assertEquals(NewPartition, partitionState(partition))
   }
@@ -146,7 +144,7 @@ class PartitionStateMachineTest extends JUnitSuite {
   def testOnlinePartitionToOnlineTransition(): Unit = {
     controllerContext.setLiveBrokerAndEpochs(Map(TestUtils.createBrokerAndEpoch(brokerId, "host", 0)))
     controllerContext.updatePartitionReplicaAssignment(partition, Seq(brokerId))
-    partitionState.put(partition, OnlinePartition)
+    controllerContext.putPartitionState(partition, OnlinePartition)
     val leaderAndIsr = LeaderAndIsr(brokerId, List(brokerId))
     val leaderIsrAndControllerEpoch = LeaderIsrAndControllerEpoch(leaderAndIsr, controllerEpoch)
     controllerContext.partitionLeadershipInfo.put(partition, leaderIsrAndControllerEpoch)
@@ -179,7 +177,7 @@ class PartitionStateMachineTest extends JUnitSuite {
       TestUtils.createBrokerAndEpoch(otherBrokerId, "host", 0)))
     controllerContext.shuttingDownBrokerIds.add(brokerId)
     controllerContext.updatePartitionReplicaAssignment(partition, Seq(brokerId, otherBrokerId))
-    partitionState.put(partition, OnlinePartition)
+    controllerContext.putPartitionState(partition, OnlinePartition)
     val leaderAndIsr = LeaderAndIsr(brokerId, List(brokerId, otherBrokerId))
     val leaderIsrAndControllerEpoch = LeaderIsrAndControllerEpoch(leaderAndIsr, controllerEpoch)
     controllerContext.partitionLeadershipInfo.put(partition, leaderIsrAndControllerEpoch)
@@ -209,21 +207,21 @@ class PartitionStateMachineTest extends JUnitSuite {
 
   @Test
   def testOnlinePartitionToOfflineTransition(): Unit = {
-    partitionState.put(partition, OnlinePartition)
+    controllerContext.putPartitionState(partition, OnlinePartition)
     partitionStateMachine.handleStateChanges(partitions, OfflinePartition)
     assertEquals(OfflinePartition, partitionState(partition))
   }
 
   @Test
   def testInvalidOnlinePartitionToNonexistentPartitionTransition(): Unit = {
-    partitionState.put(partition, OnlinePartition)
+    controllerContext.putPartitionState(partition, OnlinePartition)
     partitionStateMachine.handleStateChanges(partitions, NonExistentPartition)
     assertEquals(OnlinePartition, partitionState(partition))
   }
 
   @Test
   def testInvalidOnlinePartitionToNewPartitionTransition(): Unit = {
-    partitionState.put(partition, OnlinePartition)
+    controllerContext.putPartitionState(partition, OnlinePartition)
     partitionStateMachine.handleStateChanges(partitions, NewPartition)
     assertEquals(OnlinePartition, partitionState(partition))
   }
@@ -232,7 +230,7 @@ class PartitionStateMachineTest extends JUnitSuite {
   def testOfflinePartitionToOnlinePartitionTransition(): Unit = {
     controllerContext.setLiveBrokerAndEpochs(Map(TestUtils.createBrokerAndEpoch(brokerId, "host", 0)))
     controllerContext.updatePartitionReplicaAssignment(partition, Seq(brokerId))
-    partitionState.put(partition, OfflinePartition)
+    controllerContext.putPartitionState(partition, OfflinePartition)
     val leaderAndIsr = LeaderAndIsr(LeaderAndIsr.NoLeader, List(brokerId))
     val leaderIsrAndControllerEpoch = LeaderIsrAndControllerEpoch(leaderAndIsr, controllerEpoch)
     controllerContext.partitionLeadershipInfo.put(partition, leaderIsrAndControllerEpoch)
@@ -263,7 +261,7 @@ class PartitionStateMachineTest extends JUnitSuite {
   def testOfflinePartitionToOnlinePartitionTransitionZkUtilsExceptionFromStateLookup(): Unit = {
     controllerContext.setLiveBrokerAndEpochs(Map(TestUtils.createBrokerAndEpoch(brokerId, "host", 0)))
     controllerContext.updatePartitionReplicaAssignment(partition, Seq(brokerId))
-    partitionState.put(partition, OfflinePartition)
+    controllerContext.putPartitionState(partition, OfflinePartition)
     val leaderAndIsr = LeaderAndIsr(LeaderAndIsr.NoLeader, List(brokerId))
     val leaderIsrAndControllerEpoch = LeaderIsrAndControllerEpoch(leaderAndIsr, controllerEpoch)
     controllerContext.partitionLeadershipInfo.put(partition, leaderIsrAndControllerEpoch)
@@ -284,7 +282,7 @@ class PartitionStateMachineTest extends JUnitSuite {
   def testOfflinePartitionToOnlinePartitionTransitionErrorCodeFromStateLookup(): Unit = {
     controllerContext.setLiveBrokerAndEpochs(Map(TestUtils.createBrokerAndEpoch(brokerId, "host", 0)))
     controllerContext.updatePartitionReplicaAssignment(partition, Seq(brokerId))
-    partitionState.put(partition, OfflinePartition)
+    controllerContext.putPartitionState(partition, OfflinePartition)
     val leaderAndIsr = LeaderAndIsr(LeaderAndIsr.NoLeader, List(brokerId))
     val leaderIsrAndControllerEpoch = LeaderIsrAndControllerEpoch(leaderAndIsr, controllerEpoch)
     controllerContext.partitionLeadershipInfo.put(partition, leaderIsrAndControllerEpoch)
@@ -305,14 +303,14 @@ class PartitionStateMachineTest extends JUnitSuite {
 
   @Test
   def testOfflinePartitionToNonexistentPartitionTransition(): Unit = {
-    partitionState.put(partition, OfflinePartition)
+    controllerContext.putPartitionState(partition, OfflinePartition)
     partitionStateMachine.handleStateChanges(partitions, NonExistentPartition)
     assertEquals(NonExistentPartition, partitionState(partition))
   }
 
   @Test
   def testInvalidOfflinePartitionToNewPartitionTransition(): Unit = {
-    partitionState.put(partition, OfflinePartition)
+    controllerContext.putPartitionState(partition, OfflinePartition)
     partitionStateMachine.handleStateChanges(partitions, NewPartition)
     assertEquals(OfflinePartition, partitionState(partition))
   }
@@ -356,23 +354,21 @@ class PartitionStateMachineTest extends JUnitSuite {
 
     val partitionIds = Seq(0, 1, 2, 3)
     val topic = "test"
-    val partitions = partitionIds.map(new TopicPartition("test", _))
+    val partitions = partitionIds.map(new TopicPartition(topic, _))
 
     partitions.foreach { partition =>
       controllerContext.updatePartitionReplicaAssignment(partition, Seq(brokerId))
     }
 
-    EasyMock.expect(mockTopicDeletionManager.isTopicWithDeletionStarted(topic)).andReturn(false)
-    EasyMock.expectLastCall().anyTimes()
     prepareMockToElectLeaderForPartitions(partitions)
-    EasyMock.replay(mockZkClient, mockTopicDeletionManager)
+    EasyMock.replay(mockZkClient)
 
     partitionStateMachine.handleStateChanges(partitions, NewPartition)
     partitionStateMachine.handleStateChanges(partitions, OfflinePartition)
-    assertEquals(s"There should be ${partitions.size} offline partition(s)", partitions.size, partitionStateMachine.offlinePartitionCount)
+    assertEquals(s"There should be ${partitions.size} offline partition(s)", partitions.size, controllerContext.offlinePartitionCount)
 
     partitionStateMachine.handleStateChanges(partitions, OnlinePartition, Some(OfflinePartitionLeaderElectionStrategy))
-    assertEquals(s"There should be no offline partition(s)", 0, partitionStateMachine.offlinePartitionCount)
+    assertEquals(s"There should be no offline partition(s)", 0, controllerContext.offlinePartitionCount)
   }
 
   /**
@@ -383,15 +379,14 @@ class PartitionStateMachineTest extends JUnitSuite {
   def testNoOfflinePartitionsChangeForTopicsBeingDeleted() = {
     val partitionIds = Seq(0, 1, 2, 3)
     val topic = "test"
-    val partitions = partitionIds.map(new TopicPartition("test", _))
+    val partitions = partitionIds.map(new TopicPartition(topic, _))
 
-    EasyMock.expect(mockTopicDeletionManager.isTopicWithDeletionStarted(topic)).andReturn(true)
-    EasyMock.expectLastCall().anyTimes()
-    EasyMock.replay(mockTopicDeletionManager)
+    controllerContext.topicsToBeDeleted.add(topic)
+    controllerContext.topicsWithDeletionStarted.add(topic)
 
     partitionStateMachine.handleStateChanges(partitions, NewPartition)
     partitionStateMachine.handleStateChanges(partitions, OfflinePartition)
-    assertEquals(s"There should be no offline partition(s)", 0, partitionStateMachine.offlinePartitionCount)
+    assertEquals(s"There should be no offline partition(s)", 0, controllerContext.offlinePartitionCount)
   }
 
   /**
@@ -411,52 +406,21 @@ class PartitionStateMachineTest extends JUnitSuite {
       controllerContext.updatePartitionReplicaAssignment(partition, Seq(brokerId))
     }
 
-    val props = TestUtils.createBrokerConfig(brokerId, "zkConnect")
-    props.put(KafkaConfig.DeleteTopicEnableProp, "true")
-
-    val customConfig = KafkaConfig.fromProps(props)
-
-    def createMockReplicaStateMachine() = {
-      val replicaStateMachine: ReplicaStateMachine = EasyMock.createMock(classOf[ReplicaStateMachine])
-      EasyMock.expect(replicaStateMachine.areAllReplicasForTopicDeleted(topic)).andReturn(false).anyTimes()
-      EasyMock.expect(replicaStateMachine.isAtLeastOneReplicaInDeletionStartedState(topic)).andReturn(false).anyTimes()
-      EasyMock.expect(replicaStateMachine.isAnyReplicaInState(topic, ReplicaDeletionIneligible)).andReturn(false).anyTimes()
-      EasyMock.expect(replicaStateMachine.replicasInState(topic, ReplicaDeletionIneligible)).andReturn(Set.empty).anyTimes()
-      EasyMock.expect(replicaStateMachine.replicasInState(topic, ReplicaDeletionStarted)).andReturn(Set.empty).anyTimes()
-      EasyMock.expect(replicaStateMachine.replicasInState(topic, ReplicaDeletionSuccessful)).andReturn(Set.empty).anyTimes()
-      EasyMock.expect(replicaStateMachine.handleStateChanges(EasyMock.anyObject[Seq[PartitionAndReplica]],
-        EasyMock.anyObject[ReplicaState], EasyMock.anyObject[Callbacks]))
-
-      EasyMock.expectLastCall().anyTimes()
-      replicaStateMachine
-    }
-    val replicaStateMachine = createMockReplicaStateMachine()
-    partitionStateMachine = new PartitionStateMachine(customConfig, new StateChangeLogger(brokerId, true, None), controllerContext,
-      mockZkClient, partitionState, mockControllerBrokerRequestBatch)
-
-    def createMockController() = {
-      val mockController: KafkaController = EasyMock.createMock(classOf[KafkaController])
-      EasyMock.expect(mockController.controllerContext).andReturn(controllerContext).anyTimes()
-      EasyMock.expect(mockController.config).andReturn(customConfig).anyTimes()
-      EasyMock.expect(mockController.partitionStateMachine).andReturn(partitionStateMachine).anyTimes()
-      EasyMock.expect(mockController.replicaStateMachine).andReturn(replicaStateMachine).anyTimes()
-      EasyMock.expect(mockController.sendUpdateMetadataRequest(Seq.empty, partitions.toSet))
-      EasyMock.expectLastCall().anyTimes()
-      mockController
-    }
-
-    val mockController = createMockController()
-    val mockEventManager: ControllerEventManager = EasyMock.createMock(classOf[ControllerEventManager])
-    EasyMock.replay(mockController, replicaStateMachine, mockEventManager)
-
-    val topicDeletionManager = new TopicDeletionManager(mockController, mockEventManager, mockZkClient)
-    partitionStateMachine.setTopicDeletionManager(topicDeletionManager)
+    val partitionStateMachine = new MockPartitionStateMachine(controllerContext, uncleanLeaderElectionEnabled = false)
+    val replicaStateMachine = new MockReplicaStateMachine(controllerContext)
+    val deletionClient = Mockito.mock(classOf[DeletionClient])
+    val topicDeletionManager = new TopicDeletionManager(config, controllerContext,
+      replicaStateMachine, partitionStateMachine, deletionClient)
 
     partitionStateMachine.handleStateChanges(partitions, NewPartition)
     partitionStateMachine.handleStateChanges(partitions, OfflinePartition)
-    assertEquals(s"There should be ${partitions.size} offline partition(s)", partitions.size, mockController.partitionStateMachine.offlinePartitionCount)
+    partitions.foreach { partition =>
+      val replica = PartitionAndReplica(partition, brokerId)
+      controllerContext.putReplicaState(replica, OfflineReplica)
+    }
 
+    assertEquals(s"There should be ${partitions.size} offline partition(s)", partitions.size, controllerContext.offlinePartitionCount)
     topicDeletionManager.enqueueTopicsForDeletion(Set(topic))
-    assertEquals(s"There should be no offline partition(s)", 0, partitionStateMachine.offlinePartitionCount)
+    assertEquals(s"There should be no offline partition(s)", 0, controllerContext.offlinePartitionCount)
   }
 }
diff --git a/core/src/test/scala/unit/kafka/controller/ReplicaStateMachineTest.scala b/core/src/test/scala/unit/kafka/controller/ReplicaStateMachineTest.scala
index ef274fa..cfadfbe 100644
--- a/core/src/test/scala/unit/kafka/controller/ReplicaStateMachineTest.scala
+++ b/core/src/test/scala/unit/kafka/controller/ReplicaStateMachineTest.scala
@@ -30,14 +30,10 @@ import org.junit.Assert._
 import org.junit.{Before, Test}
 import org.scalatest.junit.JUnitSuite
 
-import scala.collection.mutable
-
 class ReplicaStateMachineTest extends JUnitSuite {
   private var controllerContext: ControllerContext = null
   private var mockZkClient: KafkaZkClient = null
   private var mockControllerBrokerRequestBatch: ControllerBrokerRequestBatch = null
-  private var mockTopicDeletionManager: TopicDeletionManager = null
-  private var replicaState: mutable.Map[PartitionAndReplica, ReplicaState] = null
   private var replicaStateMachine: ReplicaStateMachine = null
 
   private val brokerId = 5
@@ -54,10 +50,12 @@ class ReplicaStateMachineTest extends JUnitSuite {
     controllerContext.epoch = controllerEpoch
     mockZkClient = EasyMock.createMock(classOf[KafkaZkClient])
     mockControllerBrokerRequestBatch = EasyMock.createMock(classOf[ControllerBrokerRequestBatch])
-    mockTopicDeletionManager = EasyMock.createMock(classOf[TopicDeletionManager])
-    replicaState = mutable.Map.empty[PartitionAndReplica, ReplicaState]
-    replicaStateMachine = new ReplicaStateMachine(config, new StateChangeLogger(brokerId, true, None), controllerContext, mockTopicDeletionManager, mockZkClient,
-      replicaState, mockControllerBrokerRequestBatch)
+    replicaStateMachine = new ZkReplicaStateMachine(config, new StateChangeLogger(brokerId, true, None),
+      controllerContext, mockZkClient, mockControllerBrokerRequestBatch)
+  }
+
+  private def replicaState(replica: PartitionAndReplica): ReplicaState = {
+    controllerContext.replicaState(replica)
   }
 
   @Test
@@ -103,7 +101,7 @@ class ReplicaStateMachineTest extends JUnitSuite {
 
   @Test
   def testNewReplicaToOnlineReplicaTransition(): Unit = {
-    replicaState.put(replica, NewReplica)
+    controllerContext.putReplicaState(replica, NewReplica)
     controllerContext.updatePartitionReplicaAssignment(partition, Seq(brokerId))
     replicaStateMachine.handleStateChanges(replicas, OnlineReplica)
     assertEquals(OnlineReplica, replicaState(replica))
@@ -111,10 +109,9 @@ class ReplicaStateMachineTest extends JUnitSuite {
 
   @Test
   def testNewReplicaToOfflineReplicaTransition(): Unit = {
-    replicaState.put(replica, NewReplica)
+    controllerContext.putReplicaState(replica, NewReplica)
     EasyMock.expect(mockControllerBrokerRequestBatch.newBatch())
-    EasyMock.expect(mockControllerBrokerRequestBatch.addStopReplicaRequestForBrokers(EasyMock.eq(Seq(brokerId)),
-      EasyMock.eq(partition), EasyMock.eq(false), EasyMock.anyObject()))
+    EasyMock.expect(mockControllerBrokerRequestBatch.addStopReplicaRequestForBrokers(EasyMock.eq(Seq(brokerId)), EasyMock.eq(partition), EasyMock.eq(false)))
     EasyMock.expect(mockControllerBrokerRequestBatch.sendRequestsToBrokers(controllerEpoch))
     EasyMock.replay(mockControllerBrokerRequestBatch)
     replicaStateMachine.handleStateChanges(replicas, OfflineReplica)
@@ -149,7 +146,7 @@ class ReplicaStateMachineTest extends JUnitSuite {
 
   @Test
   def testOnlineReplicaToOnlineReplicaTransition(): Unit = {
-    replicaState.put(replica, OnlineReplica)
+    controllerContext.putReplicaState(replica, OnlineReplica)
     controllerContext.updatePartitionReplicaAssignment(partition, Seq(brokerId))
     val leaderIsrAndControllerEpoch = LeaderIsrAndControllerEpoch(LeaderAndIsr(brokerId, List(brokerId)), controllerEpoch)
     controllerContext.partitionLeadershipInfo.put(partition, leaderIsrAndControllerEpoch)
@@ -167,7 +164,7 @@ class ReplicaStateMachineTest extends JUnitSuite {
   def testOnlineReplicaToOfflineReplicaTransition(): Unit = {
     val otherBrokerId = brokerId + 1
     val replicaIds = List(brokerId, otherBrokerId)
-    replicaState.put(replica, OnlineReplica)
+    controllerContext.putReplicaState(replica, OnlineReplica)
     controllerContext.updatePartitionReplicaAssignment(partition, replicaIds)
     val leaderAndIsr = LeaderAndIsr(brokerId, replicaIds)
     val leaderIsrAndControllerEpoch = LeaderIsrAndControllerEpoch(leaderAndIsr, controllerEpoch)
@@ -175,8 +172,7 @@ class ReplicaStateMachineTest extends JUnitSuite {
 
     val stat = new Stat(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0)
     EasyMock.expect(mockControllerBrokerRequestBatch.newBatch())
-    EasyMock.expect(mockControllerBrokerRequestBatch.addStopReplicaRequestForBrokers(EasyMock.eq(Seq(brokerId)),
-      EasyMock.eq(partition), EasyMock.eq(false), EasyMock.anyObject()))
+    EasyMock.expect(mockControllerBrokerRequestBatch.addStopReplicaRequestForBrokers(EasyMock.eq(Seq(brokerId)), EasyMock.eq(partition), EasyMock.eq(false)))
     val adjustedLeaderAndIsr = leaderAndIsr.newLeaderAndIsr(LeaderAndIsr.NoLeader, List(otherBrokerId))
     val updatedLeaderAndIsr = adjustedLeaderAndIsr.withZkVersion(adjustedLeaderAndIsr .zkVersion + 1)
     val updatedLeaderIsrAndControllerEpoch = LeaderIsrAndControllerEpoch(updatedLeaderAndIsr, controllerEpoch)
@@ -185,14 +181,13 @@ class ReplicaStateMachineTest extends JUnitSuite {
         TopicPartitionStateZNode.encode(leaderIsrAndControllerEpoch), stat, ResponseMetadata(0, 0))))
     EasyMock.expect(mockZkClient.updateLeaderAndIsr(Map(partition -> adjustedLeaderAndIsr), controllerEpoch, controllerContext.epochZkVersion))
       .andReturn(UpdateLeaderAndIsrResult(Map(partition -> updatedLeaderAndIsr), Seq.empty, Map.empty))
-    EasyMock.expect(mockTopicDeletionManager.isTopicQueuedUpForDeletion(partition.topic)).andReturn(false)
     EasyMock.expect(mockControllerBrokerRequestBatch.addLeaderAndIsrRequestForBrokers(Seq(otherBrokerId),
       partition, updatedLeaderIsrAndControllerEpoch, replicaIds, isNew = false))
     EasyMock.expect(mockControllerBrokerRequestBatch.sendRequestsToBrokers(controllerEpoch))
 
-    EasyMock.replay(mockZkClient, mockControllerBrokerRequestBatch, mockTopicDeletionManager)
+    EasyMock.replay(mockZkClient, mockControllerBrokerRequestBatch)
     replicaStateMachine.handleStateChanges(replicas, OfflineReplica)
-    EasyMock.verify(mockZkClient, mockControllerBrokerRequestBatch, mockTopicDeletionManager)
+    EasyMock.verify(mockZkClient, mockControllerBrokerRequestBatch)
     assertEquals(updatedLeaderIsrAndControllerEpoch, controllerContext.partitionLeadershipInfo(partition))
     assertEquals(OfflineReplica, replicaState(replica))
   }
@@ -224,7 +219,7 @@ class ReplicaStateMachineTest extends JUnitSuite {
 
   @Test
   def testOfflineReplicaToOnlineReplicaTransition(): Unit = {
-    replicaState.put(replica, OfflineReplica)
+    controllerContext.putReplicaState(replica, OfflineReplica)
     controllerContext.updatePartitionReplicaAssignment(partition, Seq(brokerId))
     val leaderIsrAndControllerEpoch = LeaderIsrAndControllerEpoch(LeaderAndIsr(brokerId, List(brokerId)), controllerEpoch)
     controllerContext.partitionLeadershipInfo.put(partition, leaderIsrAndControllerEpoch)
@@ -240,21 +235,21 @@ class ReplicaStateMachineTest extends JUnitSuite {
 
   @Test
   def testOfflineReplicaToReplicaDeletionStartedTransition(): Unit = {
-    val callbacks = new Callbacks()
-    replicaState.put(replica, OfflineReplica)
+    controllerContext.putReplicaState(replica, OfflineReplica)
     EasyMock.expect(mockControllerBrokerRequestBatch.newBatch())
-    EasyMock.expect(mockControllerBrokerRequestBatch.addStopReplicaRequestForBrokers(Seq(brokerId),
-      partition, true, callbacks.stopReplicaResponseCallback))
+    EasyMock.expect(mockControllerBrokerRequestBatch.addStopReplicaRequestForBrokers(Seq(brokerId), partition, true))
     EasyMock.expect(mockControllerBrokerRequestBatch.sendRequestsToBrokers(controllerEpoch))
     EasyMock.replay(mockZkClient, mockControllerBrokerRequestBatch)
-    replicaStateMachine.handleStateChanges(replicas, ReplicaDeletionStarted, callbacks)
+    replicaStateMachine.handleStateChanges(replicas, ReplicaDeletionStarted)
     EasyMock.verify(mockZkClient, mockControllerBrokerRequestBatch)
     assertEquals(ReplicaDeletionStarted, replicaState(replica))
   }
 
   @Test
-  def testInvalidOfflineReplicaToReplicaDeletionIneligibleTransition(): Unit = {
-    testInvalidTransition(OfflineReplica, ReplicaDeletionIneligible)
+  def testOfflineReplicaToReplicaDeletionIneligibleTransition(): Unit = {
+    controllerContext.putReplicaState(replica, OfflineReplica)
+    replicaStateMachine.handleStateChanges(replicas, ReplicaDeletionIneligible)
+    assertEquals(ReplicaDeletionIneligible, replicaState(replica))
   }
 
   @Test
@@ -284,25 +279,25 @@ class ReplicaStateMachineTest extends JUnitSuite {
 
   @Test
   def testReplicaDeletionStartedToReplicaDeletionIneligibleTransition(): Unit = {
-    replicaState.put(replica, ReplicaDeletionStarted)
+    controllerContext.putReplicaState(replica, ReplicaDeletionStarted)
     replicaStateMachine.handleStateChanges(replicas, ReplicaDeletionIneligible)
     assertEquals(ReplicaDeletionIneligible, replicaState(replica))
   }
 
   @Test
   def testReplicaDeletionStartedToReplicaDeletionSuccessfulTransition(): Unit = {
-    replicaState.put(replica, ReplicaDeletionStarted)
+    controllerContext.putReplicaState(replica, ReplicaDeletionStarted)
     replicaStateMachine.handleStateChanges(replicas, ReplicaDeletionSuccessful)
     assertEquals(ReplicaDeletionSuccessful, replicaState(replica))
   }
 
   @Test
   def testReplicaDeletionSuccessfulToNonexistentReplicaTransition(): Unit = {
-    replicaState.put(replica, ReplicaDeletionSuccessful)
+    controllerContext.putReplicaState(replica, ReplicaDeletionSuccessful)
     controllerContext.updatePartitionReplicaAssignment(partition, Seq(brokerId))
     replicaStateMachine.handleStateChanges(replicas, NonExistentReplica)
     assertEquals(Seq.empty, controllerContext.partitionReplicaAssignment(partition))
-    assertEquals(None, replicaState.get(replica))
+    assertEquals(None, controllerContext.replicaStates.get(replica))
   }
 
   @Test
@@ -342,7 +337,7 @@ class ReplicaStateMachineTest extends JUnitSuite {
 
   @Test
   def testReplicaDeletionIneligibleToOnlineReplicaTransition(): Unit = {
-    replicaState.put(replica, ReplicaDeletionIneligible)
+    controllerContext.putReplicaState(replica, ReplicaDeletionIneligible)
     controllerContext.updatePartitionReplicaAssignment(partition, Seq(brokerId))
     val leaderIsrAndControllerEpoch = LeaderIsrAndControllerEpoch(LeaderAndIsr(brokerId, List(brokerId)), controllerEpoch)
     controllerContext.partitionLeadershipInfo.put(partition, leaderIsrAndControllerEpoch)
@@ -367,7 +362,7 @@ class ReplicaStateMachineTest extends JUnitSuite {
   }
 
   private def testInvalidTransition(fromState: ReplicaState, toState: ReplicaState): Unit = {
-    replicaState.put(replica, fromState)
+    controllerContext.putReplicaState(replica, fromState)
     replicaStateMachine.handleStateChanges(replicas, toState)
     assertEquals(fromState, replicaState(replica))
   }
diff --git a/core/src/test/scala/unit/kafka/controller/TopicDeletionManagerTest.scala b/core/src/test/scala/unit/kafka/controller/TopicDeletionManagerTest.scala
new file mode 100644
index 0000000..e6297c0
--- /dev/null
+++ b/core/src/test/scala/unit/kafka/controller/TopicDeletionManagerTest.scala
@@ -0,0 +1,232 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package kafka.controller
+
+import kafka.cluster.{Broker, EndPoint}
+import kafka.server.KafkaConfig
+import kafka.utils.TestUtils
+import org.apache.kafka.common.TopicPartition
+import org.apache.kafka.common.network.ListenerName
+import org.apache.kafka.common.security.auth.SecurityProtocol
+import org.junit.Assert._
+import org.junit.Test
+import org.mockito.Mockito._
+
+class TopicDeletionManagerTest {
+
+  private val brokerId = 1
+  private val config = KafkaConfig.fromProps(TestUtils.createBrokerConfig(brokerId, "zkConnect"))
+  private val deletionClient = mock(classOf[DeletionClient])
+
+  @Test
+  def testBasicDeletion(): Unit = {
+    val controllerContext = initContext(
+      brokers = Seq(1, 2, 3),
+      topics = Set("foo", "bar"),
+      numPartitions = 2,
+      replicationFactor = 3)
+    val replicaStateMachine = new MockReplicaStateMachine(controllerContext)
+    replicaStateMachine.startup()
+
+    val partitionStateMachine = new MockPartitionStateMachine(controllerContext, uncleanLeaderElectionEnabled = false)
+    partitionStateMachine.startup()
+
+    val deletionManager = new TopicDeletionManager(config, controllerContext, replicaStateMachine,
+      partitionStateMachine, deletionClient)
+    assertTrue(deletionManager.isDeleteTopicEnabled)
+    deletionManager.init(Set.empty, Set.empty)
+
+    val fooPartitions = controllerContext.partitionsForTopic("foo")
+    val fooReplicas = controllerContext.replicasForPartition(fooPartitions).toSet
+
+    // Queue the topic for deletion
+    deletionManager.enqueueTopicsForDeletion(Set("foo"))
+
+    assertEquals(fooPartitions, controllerContext.partitionsInState("foo", NonExistentPartition))
+    assertEquals(fooReplicas, controllerContext.replicasInState("foo", ReplicaDeletionStarted))
+    verify(deletionClient).sendMetadataUpdate(fooPartitions)
+    assertEquals(Set("foo"), controllerContext.topicsToBeDeleted)
+    assertEquals(Set("foo"), controllerContext.topicsWithDeletionStarted)
+    assertEquals(Set(), controllerContext.topicsIneligibleForDeletion)
+
+    // Complete the deletion
+    deletionManager.completeReplicaDeletion(fooReplicas)
+
+    assertEquals(Set.empty, controllerContext.partitionsForTopic("foo"))
+    assertEquals(Set.empty[PartitionAndReplica], controllerContext.replicaStates.keySet.filter(_.topic == "foo"))
+    assertEquals(Set(), controllerContext.topicsToBeDeleted)
+    assertEquals(Set(), controllerContext.topicsWithDeletionStarted)
+    assertEquals(Set(), controllerContext.topicsIneligibleForDeletion)
+  }
+
+  @Test
+  def testDeletionWithBrokerOffline(): Unit = {
+    val controllerContext = initContext(
+      brokers = Seq(1, 2, 3),
+      topics = Set("foo", "bar"),
+      numPartitions = 2,
+      replicationFactor = 3)
+
+    val replicaStateMachine = new MockReplicaStateMachine(controllerContext)
+    replicaStateMachine.startup()
+
+    val partitionStateMachine = new MockPartitionStateMachine(controllerContext, uncleanLeaderElectionEnabled = false)
+    partitionStateMachine.startup()
+
+    val deletionManager = new TopicDeletionManager(config, controllerContext, replicaStateMachine,
+      partitionStateMachine, deletionClient)
+    assertTrue(deletionManager.isDeleteTopicEnabled)
+    deletionManager.init(Set.empty, Set.empty)
+
+    val fooPartitions = controllerContext.partitionsForTopic("foo")
+    val fooReplicas = controllerContext.replicasForPartition(fooPartitions).toSet
+
+    // Broker 2 is taken offline
+    val failedBrokerId = 2
+    val offlineBroker = controllerContext.liveOrShuttingDownBroker(failedBrokerId).get
+    val lastEpoch = controllerContext.liveBrokerIdAndEpochs(failedBrokerId)
+    controllerContext.removeLiveBrokers(Set(failedBrokerId))
+    assertEquals(Set(1, 3), controllerContext.liveBrokerIds)
+
+    val (offlineReplicas, onlineReplicas) = fooReplicas.partition(_.replica == failedBrokerId)
+    replicaStateMachine.handleStateChanges(offlineReplicas.toSeq, OfflineReplica)
+
+    // Start topic deletion
+    deletionManager.enqueueTopicsForDeletion(Set("foo"))
+    assertEquals(fooPartitions, controllerContext.partitionsInState("foo", NonExistentPartition))
+    verify(deletionClient).sendMetadataUpdate(fooPartitions)
+    assertEquals(onlineReplicas, controllerContext.replicasInState("foo", ReplicaDeletionStarted))
+    assertEquals(offlineReplicas, controllerContext.replicasInState("foo", ReplicaDeletionIneligible))
+
+    assertEquals(Set("foo"), controllerContext.topicsToBeDeleted)
+    assertEquals(Set("foo"), controllerContext.topicsWithDeletionStarted)
+    assertEquals(Set("foo"), controllerContext.topicsIneligibleForDeletion)
+
+    // Deletion succeeds for online replicas
+    deletionManager.completeReplicaDeletion(onlineReplicas)
+
+    assertEquals(fooPartitions, controllerContext.partitionsInState("foo", NonExistentPartition))
+    assertEquals(Set("foo"), controllerContext.topicsToBeDeleted)
+    assertEquals(Set("foo"), controllerContext.topicsWithDeletionStarted)
+    assertEquals(Set("foo"), controllerContext.topicsIneligibleForDeletion)
+    assertEquals(onlineReplicas, controllerContext.replicasInState("foo", ReplicaDeletionSuccessful))
+    assertEquals(offlineReplicas, controllerContext.replicasInState("foo", OfflineReplica))
+
+    // Broker 2 comes back online and deletion is resumed
+    controllerContext.addLiveBrokersAndEpochs(Map(offlineBroker -> (lastEpoch + 1L)))
+    deletionManager.resumeDeletionForTopics(Set("foo"))
+
+    assertEquals(onlineReplicas, controllerContext.replicasInState("foo", ReplicaDeletionSuccessful))
+    assertEquals(offlineReplicas, controllerContext.replicasInState("foo", ReplicaDeletionStarted))
+
+    deletionManager.completeReplicaDeletion(offlineReplicas)
+    assertEquals(Set.empty, controllerContext.partitionsForTopic("foo"))
+    assertEquals(Set.empty[PartitionAndReplica], controllerContext.replicaStates.keySet.filter(_.topic == "foo"))
+    assertEquals(Set(), controllerContext.topicsToBeDeleted)
+    assertEquals(Set(), controllerContext.topicsWithDeletionStarted)
+    assertEquals(Set(), controllerContext.topicsIneligibleForDeletion)
+  }
+
+  @Test
+  def testBrokerFailureAfterDeletionStarted(): Unit = {
+    val controllerContext = initContext(
+      brokers = Seq(1, 2, 3),
+      topics = Set("foo", "bar"),
+      numPartitions = 2,
+      replicationFactor = 3)
+
+    val replicaStateMachine = new MockReplicaStateMachine(controllerContext)
+    replicaStateMachine.startup()
+
+    val partitionStateMachine = new MockPartitionStateMachine(controllerContext, uncleanLeaderElectionEnabled = false)
+    partitionStateMachine.startup()
+
+    val deletionManager = new TopicDeletionManager(config, controllerContext, replicaStateMachine,
+      partitionStateMachine, deletionClient)
+    deletionManager.init(Set.empty, Set.empty)
+
+    val fooPartitions = controllerContext.partitionsForTopic("foo")
+    val fooReplicas = controllerContext.replicasForPartition(fooPartitions).toSet
+
+    // Queue the topic for deletion
+    deletionManager.enqueueTopicsForDeletion(Set("foo"))
+    assertEquals(fooPartitions, controllerContext.partitionsInState("foo", NonExistentPartition))
+    assertEquals(fooReplicas, controllerContext.replicasInState("foo", ReplicaDeletionStarted))
+
+    // Broker 2 fails
+    val failedBrokerId = 2
+    val offlineBroker = controllerContext.liveOrShuttingDownBroker(failedBrokerId).get
+    val lastEpoch = controllerContext.liveBrokerIdAndEpochs(failedBrokerId)
+    controllerContext.removeLiveBrokers(Set(failedBrokerId))
+    assertEquals(Set(1, 3), controllerContext.liveBrokerIds)
+    val (offlineReplicas, onlineReplicas) = fooReplicas.partition(_.replica == failedBrokerId)
+
+    // Fail replica deletion
+    deletionManager.failReplicaDeletion(offlineReplicas)
+    assertEquals(Set("foo"), controllerContext.topicsToBeDeleted)
+    assertEquals(Set("foo"), controllerContext.topicsWithDeletionStarted)
+    assertEquals(Set("foo"), controllerContext.topicsIneligibleForDeletion)
+    assertEquals(offlineReplicas, controllerContext.replicasInState("foo", ReplicaDeletionIneligible))
+    assertEquals(onlineReplicas, controllerContext.replicasInState("foo", ReplicaDeletionStarted))
+
+    // Broker 2 is restarted. The offline replicas remain ineligible
+    // (TODO: this is probably not desired)
+    controllerContext.addLiveBrokersAndEpochs(Map(offlineBroker -> (lastEpoch + 1L)))
+    deletionManager.resumeDeletionForTopics(Set("foo"))
+    assertEquals(Set("foo"), controllerContext.topicsToBeDeleted)
+    assertEquals(Set("foo"), controllerContext.topicsWithDeletionStarted)
+    assertEquals(Set(), controllerContext.topicsIneligibleForDeletion)
+    assertEquals(onlineReplicas, controllerContext.replicasInState("foo", ReplicaDeletionStarted))
+    assertEquals(offlineReplicas, controllerContext.replicasInState("foo", ReplicaDeletionIneligible))
+
+    // When deletion completes for the replicas which started, then deletion begins for the remaining ones
+    deletionManager.completeReplicaDeletion(onlineReplicas)
+    assertEquals(Set("foo"), controllerContext.topicsToBeDeleted)
+    assertEquals(Set("foo"), controllerContext.topicsWithDeletionStarted)
+    assertEquals(Set(), controllerContext.topicsIneligibleForDeletion)
+    assertEquals(onlineReplicas, controllerContext.replicasInState("foo", ReplicaDeletionSuccessful))
+    assertEquals(offlineReplicas, controllerContext.replicasInState("foo", ReplicaDeletionStarted))
+
+  }
+
+  def initContext(brokers: Seq[Int],
+                  topics: Set[String],
+                  numPartitions: Int,
+                  replicationFactor: Int): ControllerContext = {
+    val context = new ControllerContext
+    val brokerEpochs = brokers.map { brokerId =>
+      val endpoint = new EndPoint("localhost", 9900 + brokerId, new ListenerName("blah"),
+        SecurityProtocol.PLAINTEXT)
+      Broker(brokerId, Seq(endpoint), rack = None) -> 1L
+    }.toMap
+    context.setLiveBrokerAndEpochs(brokerEpochs)
+
+    // Simple round-robin replica assignment
+    var leaderIndex = 0
+    for (topic <- topics; partitionId <- 0 until numPartitions) {
+      val partition = new TopicPartition(topic, partitionId)
+      val replicas = (0 until replicationFactor).map { i =>
+        val replica = brokers((i + leaderIndex) % brokers.size)
+        replica
+      }
+      context.updatePartitionReplicaAssignment(partition, replicas)
+      leaderIndex += 1
+    }
+    context
+  }
+
+}
diff --git a/core/src/test/scala/unit/kafka/server/LogDirFailureTest.scala b/core/src/test/scala/unit/kafka/server/LogDirFailureTest.scala
index 3eff38f..f8c56cb 100644
--- a/core/src/test/scala/unit/kafka/server/LogDirFailureTest.scala
+++ b/core/src/test/scala/unit/kafka/server/LogDirFailureTest.scala
@@ -193,7 +193,7 @@ class LogDirFailureTest extends IntegrationTestHarness {
 
     // The controller should have marked the replica on the original leader as offline
     val controllerServer = servers.find(_.kafkaController.isActive).get
-    val offlineReplicas = controllerServer.kafkaController.replicaStateMachine.replicasInState(topic, OfflineReplica)
+    val offlineReplicas = controllerServer.kafkaController.controllerContext.replicasInState(topic, OfflineReplica)
     assertTrue(offlineReplicas.contains(PartitionAndReplica(new TopicPartition(topic, 0), leaderServerId)))
   }