Posted to commits@kafka.apache.org by jg...@apache.org on 2020/01/16 22:51:22 UTC

[kafka] branch 2.4 updated: KAFKA-9235; Ensure transaction coordinator is stopped after replica deletion (#7963)

This is an automated email from the ASF dual-hosted git repository.

jgus pushed a commit to branch 2.4
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/2.4 by this push:
     new 65f2796  KAFKA-9235; Ensure transaction coordinator is stopped after replica deletion (#7963)
65f2796 is described below

commit 65f27964ad63e124296844cc35c7b2bea0dd52bd
Author: Jason Gustafson <ja...@confluent.io>
AuthorDate: Thu Jan 16 13:50:24 2020 -0800

    KAFKA-9235; Ensure transaction coordinator is stopped after replica deletion (#7963)
    
    During a reassignment, the current leader of a partition can be demoted and removed from the replica set at the same time. In this case, we rely on the StopReplica request to stop replica fetchers and to clear the group coordinator cache. This patch adds similar logic to ensure that the transaction coordinator state cache is also cleared.
    
    Reviewers: Rajini Sivaram <ra...@googlemail.com>
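
A condensed sketch of the resulting dispatch in kafka.server.KafkaApis.handleStopReplicaRequest
(taken from the diff below; request plumbing and response handling elided):

    result.foreach { case (topicPartition, error) =>
      if (error == Errors.NONE && stopReplicaRequest.deletePartitions) {
        if (topicPartition.topic == GROUP_METADATA_TOPIC_NAME) {
          // This broker may have been the group coordinator for this partition
          groupCoordinator.onResignation(topicPartition.partition)
        } else if (topicPartition.topic == TRANSACTION_STATE_TOPIC_NAME) {
          // The StopReplica API does not pass through the leader epoch,
          // so the transaction coordinator resigns without an epoch fence
          txnCoordinator.onResignation(topicPartition.partition, coordinatorEpoch = None)
        }
      }
    }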
---
 .../kafka/coordinator/group/GroupCoordinator.scala |  14 ++-
 .../transaction/TransactionCoordinator.scala       |  89 ++++++++-----
 .../transaction/TransactionStateManager.scala      |  61 ++++-----
 core/src/main/scala/kafka/server/KafkaApis.scala   |  25 ++--
 .../TransactionCoordinatorConcurrencyTest.scala    |   4 +-
 .../transaction/TransactionCoordinatorTest.scala   |   8 +-
 .../transaction/TransactionStateManagerTest.scala  | 137 +++++++++++++++++----
 .../scala/unit/kafka/server/KafkaApisTest.scala    |  37 +++++-
 8 files changed, 264 insertions(+), 111 deletions(-)

diff --git a/core/src/main/scala/kafka/coordinator/group/GroupCoordinator.scala b/core/src/main/scala/kafka/coordinator/group/GroupCoordinator.scala
index 24a1780..2e2358d 100644
--- a/core/src/main/scala/kafka/coordinator/group/GroupCoordinator.scala
+++ b/core/src/main/scala/kafka/coordinator/group/GroupCoordinator.scala
@@ -860,11 +860,21 @@ class GroupCoordinator(val brokerId: Int,
     }
   }
 
-  def handleGroupImmigration(offsetTopicPartitionId: Int): Unit = {
+  /**
+   * Load cached state from the given partition and begin handling requests for groups which map to it.
+   *
+   * @param offsetTopicPartitionId The partition we are now leading
+   */
+  def onElection(offsetTopicPartitionId: Int): Unit = {
     groupManager.scheduleLoadGroupAndOffsets(offsetTopicPartitionId, onGroupLoaded)
   }
 
-  def handleGroupEmigration(offsetTopicPartitionId: Int): Unit = {
+  /**
+   * Unload cached state for the given partition and stop handling requests for groups which map to it.
+   *
+   * @param offsetTopicPartitionId The partition we are no longer leading
+   */
+  def onResignation(offsetTopicPartitionId: Int): Unit = {
     groupManager.removeGroupsForPartition(offsetTopicPartitionId, onGroupUnloaded)
   }
 
diff --git a/core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala b/core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala
index d646757..c7601ef 100644
--- a/core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala
+++ b/core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala
@@ -129,7 +129,7 @@ class TransactionCoordinator(brokerId: Int,
             state = Empty,
             topicPartitions = collection.mutable.Set.empty[TopicPartition],
             txnLastUpdateTimestamp = time.milliseconds())
-          txnManager.putTransactionStateIfNotExists(transactionalId, createdMetadata)
+          txnManager.putTransactionStateIfNotExists(createdMetadata)
 
         case Some(epochAndTxnMetadata) => Right(epochAndTxnMetadata)
       }
@@ -274,7 +274,13 @@ class TransactionCoordinator(brokerId: Int,
     }
   }
 
-  def handleTxnImmigration(txnTopicPartitionId: Int, coordinatorEpoch: Int): Unit = {
+  /**
+   * Load state from the given partition and begin handling requests for transactional ids which map to this partition.
+   *
+   * @param txnTopicPartitionId The partition that we are now leading
+   * @param coordinatorEpoch The partition coordinator (or leader) epoch from the received LeaderAndIsr request
+   */
+  def onElection(txnTopicPartitionId: Int, coordinatorEpoch: Int): Unit = {
     // The operations performed during immigration must be resilient to any previous errors we saw or partial state we
     // left behind during the unloading phase. Ensure we remove all associated state for this partition before we continue
     // loading it.
@@ -284,8 +290,20 @@ class TransactionCoordinator(brokerId: Int,
     txnManager.loadTransactionsForTxnTopicPartition(txnTopicPartitionId, coordinatorEpoch, txnMarkerChannelManager.addTxnMarkersToSend)
   }
 
-  def handleTxnEmigration(txnTopicPartitionId: Int, coordinatorEpoch: Int): Unit = {
-    txnManager.removeTransactionsForTxnTopicPartition(txnTopicPartitionId, coordinatorEpoch)
+  /**
+   * Clear coordinator caches for the given partition after giving up leadership.
+   *
+   * @param txnTopicPartitionId The partition that we are no longer leading
+   * @param coordinatorEpoch The partition coordinator (or leader) epoch, which may be absent if we
+   *                         are resigning after receiving a StopReplica request from the controller
+   */
+  def onResignation(txnTopicPartitionId: Int, coordinatorEpoch: Option[Int]): Unit = {
+    coordinatorEpoch match {
+      case Some(epoch) =>
+        txnManager.removeTransactionsForTxnTopicPartition(txnTopicPartitionId, epoch)
+      case None =>
+        txnManager.removeTransactionsForTxnTopicPartition(txnTopicPartitionId)
+    }
     txnMarkerChannelManager.removeMarkersForTxnTopicPartition(txnTopicPartitionId)
   }
 
@@ -450,47 +468,52 @@ class TransactionCoordinator(brokerId: Int,
   def partitionFor(transactionalId: String): Int = txnManager.partitionFor(transactionalId)
 
   private def abortTimedOutTransactions(): Unit = {
+    def onComplete(txnIdAndPidEpoch: TransactionalIdAndProducerIdEpoch)(error: Errors): Unit = {
+      error match {
+        case Errors.NONE =>
+          info("Completed rollback of ongoing transaction for transactionalId " +
+            s"${txnIdAndPidEpoch.transactionalId} due to timeout")
+
+        case error@(Errors.INVALID_PRODUCER_ID_MAPPING |
+                    Errors.INVALID_PRODUCER_EPOCH |
+                    Errors.CONCURRENT_TRANSACTIONS) =>
+          debug(s"Rollback of ongoing transaction for transactionalId ${txnIdAndPidEpoch.transactionalId} " +
+            s"has been cancelled due to error $error")
+
+        case error =>
+          warn(s"Rollback of ongoing transaction for transactionalId ${txnIdAndPidEpoch.transactionalId} " +
+            s"failed due to error $error")
+      }
+    }
+
     txnManager.timedOutTransactions().foreach { txnIdAndPidEpoch =>
-      txnManager.getTransactionState(txnIdAndPidEpoch.transactionalId).right.flatMap {
+      txnManager.getTransactionState(txnIdAndPidEpoch.transactionalId).right.foreach {
         case None =>
-          error(s"Could not find transaction metadata when trying to timeout transaction with transactionalId " +
-            s"${txnIdAndPidEpoch.transactionalId}. ProducerId: ${txnIdAndPidEpoch.producerId}. ProducerEpoch: " +
-            s"${txnIdAndPidEpoch.producerEpoch}")
-          Left(Errors.INVALID_TXN_STATE)
+          error(s"Could not find transaction metadata when trying to timeout transaction for $txnIdAndPidEpoch")
 
         case Some(epochAndTxnMetadata) =>
           val txnMetadata = epochAndTxnMetadata.transactionMetadata
-          val transitMetadata = txnMetadata.inLock {
+          val transitMetadataOpt = txnMetadata.inLock {
             if (txnMetadata.producerId != txnIdAndPidEpoch.producerId) {
               error(s"Found incorrect producerId when expiring transactionalId: ${txnIdAndPidEpoch.transactionalId}. " +
                 s"Expected producerId: ${txnIdAndPidEpoch.producerId}. Found producerId: " +
                 s"${txnMetadata.producerId}")
-              Left(Errors.INVALID_PRODUCER_ID_MAPPING)
+              None
             } else if (txnMetadata.pendingTransitionInProgress) {
-              Left(Errors.CONCURRENT_TRANSACTIONS)
+              debug(s"Skipping abort of timed out transaction $txnIdAndPidEpoch since there is a " +
+                "pending state transition")
+              None
             } else {
-              Right(txnMetadata.prepareFenceProducerEpoch())
+              Some(txnMetadata.prepareFenceProducerEpoch())
             }
           }
-          transitMetadata match {
-            case Right(txnTransitMetadata) =>
-              handleEndTransaction(txnMetadata.transactionalId,
-                txnTransitMetadata.producerId,
-                txnTransitMetadata.producerEpoch,
-                TransactionResult.ABORT,
-                {
-                  case Errors.NONE =>
-                    info(s"Completed rollback ongoing transaction of transactionalId: ${txnIdAndPidEpoch.transactionalId} due to timeout")
-                  case e @ (Errors.INVALID_PRODUCER_ID_MAPPING |
-                            Errors.INVALID_PRODUCER_EPOCH |
-                            Errors.CONCURRENT_TRANSACTIONS) =>
-                    debug(s"Rolling back ongoing transaction of transactionalId: ${txnIdAndPidEpoch.transactionalId} has aborted due to ${e.exceptionName}")
-                  case e =>
-                    warn(s"Rolling back ongoing transaction of transactionalId: ${txnIdAndPidEpoch.transactionalId} failed due to ${e.exceptionName}")
-                })
-              Right(txnTransitMetadata)
-            case (error) =>
-              Left(error)
+
+          transitMetadataOpt.foreach { txnTransitMetadata =>
+            handleEndTransaction(txnMetadata.transactionalId,
+              txnTransitMetadata.producerId,
+              txnTransitMetadata.producerEpoch,
+              TransactionResult.ABORT,
+              onComplete(txnIdAndPidEpoch))
           }
       }
     }
@@ -503,7 +526,7 @@ class TransactionCoordinator(brokerId: Int,
     info("Starting up.")
     scheduler.startup()
     scheduler.schedule("transaction-abort",
-      () => abortTimedOutTransactions,
+      abortTimedOutTransactions,
       txnConfig.abortTimedOutTransactionsIntervalMs,
       txnConfig.abortTimedOutTransactionsIntervalMs
     )
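
For illustration, the two call paths into the new onResignation signature (a sketch;
txnCoordinator, txnTopicPartitionId and leaderEpoch stand in for values taken from the
triggering request):

    // Become-follower transition driven by a LeaderAndIsr request: the
    // coordinator (leader) epoch is known, so removal can be fenced on it.
    txnCoordinator.onResignation(txnTopicPartitionId, Some(leaderEpoch))

    // Replica deletion driven by a StopReplica request: the request does not
    // carry the leader epoch, so cached state is removed unconditionally.
    txnCoordinator.onResignation(txnTopicPartitionId, coordinatorEpoch = None)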
diff --git a/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala b/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala
index 7d731a1..36bc965 100644
--- a/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala
+++ b/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala
@@ -88,9 +88,6 @@ class TransactionStateManager(brokerId: Int,
   /** partitions of transaction topic that are being loaded; the state lock must be acquired BEFORE accessing this set */
   private[transaction] val loadingPartitions: mutable.Set[TransactionPartitionAndLeaderEpoch] = mutable.Set()
 
-  /** partitions of transaction topic that are being removed, state lock should be called BEFORE accessing this set */
-  private[transaction] val leavingPartitions: mutable.Set[TransactionPartitionAndLeaderEpoch] = mutable.Set()
-
   /** transaction metadata cache indexed by assigned transaction topic partition ids */
   private[transaction] val transactionMetadataCache: mutable.Map[Int, TxnMetadataCacheEntry] = mutable.Map()
 
@@ -110,9 +107,7 @@ class TransactionStateManager(brokerId: Int,
   // visible for testing only
   private[transaction] def addLoadingPartition(partitionId: Int, coordinatorEpoch: Int): Unit = {
     val partitionAndLeaderEpoch = TransactionPartitionAndLeaderEpoch(partitionId, coordinatorEpoch)
-
     inWriteLock(stateLock) {
-      leavingPartitions.remove(partitionAndLeaderEpoch)
       loadingPartitions.add(partitionAndLeaderEpoch)
     }
   }
@@ -125,9 +120,7 @@ class TransactionStateManager(brokerId: Int,
   def timedOutTransactions(): Iterable[TransactionalIdAndProducerIdEpoch] = {
     val now = time.milliseconds()
     inReadLock(stateLock) {
-      transactionMetadataCache.filter { case (txnPartitionId, _) =>
-        !leavingPartitions.exists(_.txnPartitionId == txnPartitionId)
-      }.flatMap { case (_, entry) =>
+      transactionMetadataCache.flatMap { case (_, entry) =>
         entry.metadataPerTransactionalId.filter { case (_, txnMetadata) =>
           if (txnMetadata.pendingTransitionInProgress) {
             false
@@ -224,10 +217,10 @@ class TransactionStateManager(brokerId: Int,
   def getTransactionState(transactionalId: String): Either[Errors, Option[CoordinatorEpochAndTxnMetadata]] =
     getAndMaybeAddTransactionState(transactionalId, None)
 
-  def putTransactionStateIfNotExists(transactionalId: String,
-                                     txnMetadata: TransactionMetadata): Either[Errors, CoordinatorEpochAndTxnMetadata] =
-    getAndMaybeAddTransactionState(transactionalId, Some(txnMetadata))
+  def putTransactionStateIfNotExists(txnMetadata: TransactionMetadata): Either[Errors, CoordinatorEpochAndTxnMetadata] = {
+    getAndMaybeAddTransactionState(txnMetadata.transactionalId, Some(txnMetadata))
       .right.map(_.getOrElse(throw new IllegalStateException(s"Unexpected empty transaction metadata returned while putting $txnMetadata")))
+  }
 
   /**
    * Get the transaction metadata associated with the given transactional id, or an error if
@@ -242,8 +235,6 @@ class TransactionStateManager(brokerId: Int,
       val partitionId = partitionFor(transactionalId)
       if (loadingPartitions.exists(_.txnPartitionId == partitionId))
         Left(Errors.COORDINATOR_LOAD_IN_PROGRESS)
-      else if (leavingPartitions.exists(_.txnPartitionId == partitionId))
-        Left(Errors.NOT_COORDINATOR)
       else {
         transactionMetadataCache.get(partitionId) match {
           case Some(cacheEntry) =>
@@ -396,7 +387,6 @@ class TransactionStateManager(brokerId: Int,
     val partitionAndLeaderEpoch = TransactionPartitionAndLeaderEpoch(partitionId, coordinatorEpoch)
 
     inWriteLock(stateLock) {
-      leavingPartitions.remove(partitionAndLeaderEpoch)
       loadingPartitions.add(partitionAndLeaderEpoch)
     }
 
@@ -423,7 +413,7 @@ class TransactionStateManager(brokerId: Int,
                     transactionsPendingForCompletion +=
                       TransactionalIdCoordinatorEpochAndTransitMetadata(transactionalId, coordinatorEpoch, TransactionResult.COMMIT, txnMetadata, txnMetadata.prepareComplete(time.milliseconds()))
                   case _ =>
-                    // nothing need to be done
+                    // nothing needs to be done
                 }
               }
           }
@@ -442,7 +432,18 @@ class TransactionStateManager(brokerId: Int,
       info(s"Completed loading transaction metadata from $topicPartition for coordinator epoch $coordinatorEpoch")
     }
 
-    scheduler.schedule(s"load-txns-for-partition-$topicPartition", () => loadTransactions)
+    scheduler.schedule(s"load-txns-for-partition-$topicPartition", loadTransactions)
+  }
+
+  def removeTransactionsForTxnTopicPartition(partitionId: Int): Unit = {
+    val topicPartition = new TopicPartition(Topic.TRANSACTION_STATE_TOPIC_NAME, partitionId)
+    inWriteLock(stateLock) {
+      loadingPartitions.retain(_.txnPartitionId != partitionId)
+      transactionMetadataCache.remove(partitionId).foreach { txnMetadataCacheEntry =>
+        info(s"Unloaded transaction metadata $txnMetadataCacheEntry for $topicPartition following " +
+          s"local partition deletion")
+      }
+    }
   }
 
   /**
@@ -455,26 +456,14 @@ class TransactionStateManager(brokerId: Int,
 
     inWriteLock(stateLock) {
       loadingPartitions.remove(partitionAndLeaderEpoch)
-      leavingPartitions.add(partitionAndLeaderEpoch)
-    }
+      transactionMetadataCache.remove(partitionId) match {
+        case Some(txnMetadataCacheEntry) =>
+          info(s"Unloaded transaction metadata $txnMetadataCacheEntry for $topicPartition on become-follower transition")
 
-    def removeTransactions(): Unit = {
-      inWriteLock(stateLock) {
-        if (leavingPartitions.contains(partitionAndLeaderEpoch)) {
-          transactionMetadataCache.remove(partitionId) match {
-            case Some(txnMetadataCacheEntry) =>
-              info(s"Unloaded transaction metadata $txnMetadataCacheEntry for $topicPartition on become-follower transition")
-
-            case None =>
-              info(s"No cached transaction metadata found for $topicPartition during become-follower transition")
-          }
-
-          leavingPartitions.remove(partitionAndLeaderEpoch)
-        }
+        case None =>
+          info(s"No cached transaction metadata found for $topicPartition during become-follower transition")
       }
     }
-
-    scheduler.schedule(s"remove-txns-for-partition-$topicPartition", () => removeTransactions)
   }
 
   private def validateTransactionTopicPartitionCountIsStable(): Unit = {
@@ -679,7 +668,11 @@ private[transaction] case class TransactionConfig(transactionalIdExpirationMs: I
                                                   removeExpiredTransactionalIdsIntervalMs: Int = TransactionStateManager.DefaultRemoveExpiredTransactionalIdsIntervalMs,
                                                   requestTimeoutMs: Int = Defaults.RequestTimeoutMs)
 
-case class TransactionalIdAndProducerIdEpoch(transactionalId: String, producerId: Long, producerEpoch: Short)
+case class TransactionalIdAndProducerIdEpoch(transactionalId: String, producerId: Long, producerEpoch: Short) {
+  override def toString: String = {
+    s"(transactionalId=$transactionalId, producerId=$producerId, producerEpoch=$producerEpoch)"
+  }
+}
 
 case class TransactionPartitionAndLeaderEpoch(txnPartitionId: Int, coordinatorEpoch: Int)
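
With the leavingPartitions set gone, unloading happens synchronously under the state
write lock rather than on a scheduled task. A minimal, self-contained sketch of that
locking pattern (hypothetical names; TransactionStateManager uses the equivalent
inReadLock/inWriteLock helpers from kafka.utils.CoreUtils):

    import java.util.concurrent.locks.ReentrantReadWriteLock
    import scala.collection.mutable

    object CacheUnloadSketch {
      private val stateLock = new ReentrantReadWriteLock()
      private val cache = mutable.Map[Int, String]()

      private def inWriteLock[T](body: => T): T = {
        stateLock.writeLock().lock()
        try body finally stateLock.writeLock().unlock()
      }

      private def inReadLock[T](body: => T): T = {
        stateLock.readLock().lock()
        try body finally stateLock.readLock().unlock()
      }

      // All mutations complete under the write lock, so a reader can never
      // observe a partition that is halfway through being unloaded and no
      // separate "leaving" bookkeeping is required.
      def unload(partitionId: Int): Unit = inWriteLock {
        cache.remove(partitionId).foreach(entry => println(s"Unloaded $entry"))
      }

      def read(partitionId: Int): Option[String] = inReadLock {
        cache.get(partitionId)
      }
    }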
 
diff --git a/core/src/main/scala/kafka/server/KafkaApis.scala b/core/src/main/scala/kafka/server/KafkaApis.scala
index e3d53a1..124b2e3 100644
--- a/core/src/main/scala/kafka/server/KafkaApis.scala
+++ b/core/src/main/scala/kafka/server/KafkaApis.scala
@@ -194,16 +194,16 @@ class KafkaApis(val requestChannel: RequestChannel,
       // leadership changes
       updatedLeaders.foreach { partition =>
         if (partition.topic == GROUP_METADATA_TOPIC_NAME)
-          groupCoordinator.handleGroupImmigration(partition.partitionId)
+          groupCoordinator.onElection(partition.partitionId)
         else if (partition.topic == TRANSACTION_STATE_TOPIC_NAME)
-          txnCoordinator.handleTxnImmigration(partition.partitionId, partition.getLeaderEpoch)
+          txnCoordinator.onElection(partition.partitionId, partition.getLeaderEpoch)
       }
 
       updatedFollowers.foreach { partition =>
         if (partition.topic == GROUP_METADATA_TOPIC_NAME)
-          groupCoordinator.handleGroupEmigration(partition.partitionId)
+          groupCoordinator.onResignation(partition.partitionId)
         else if (partition.topic == TRANSACTION_STATE_TOPIC_NAME)
-          txnCoordinator.handleTxnEmigration(partition.partitionId, partition.getLeaderEpoch)
+          txnCoordinator.onResignation(partition.partitionId, Some(partition.getLeaderEpoch))
       }
     }
 
@@ -234,15 +234,16 @@ class KafkaApis(val requestChannel: RequestChannel,
       sendResponseExemptThrottle(request, new StopReplicaResponse(new StopReplicaResponseData().setErrorCode(Errors.STALE_BROKER_EPOCH.code)))
     } else {
       val (result, error) = replicaManager.stopReplicas(stopReplicaRequest)
-      // Clearing out the cache for groups that belong to an offsets topic partition for which this broker was the leader,
-      // since this broker is no longer a replica for that offsets topic partition.
-      // This is required to handle the following scenario :
-      // Consider old replicas : {[1,2,3], Leader = 1} is reassigned to new replicas : {[2,3,4], Leader = 2}, broker 1 does not receive a LeaderAndIsr
-      // request to become a follower due to which cache for groups that belong to an offsets topic partition for which broker 1 was the leader,
-      // is not cleared.
+      // Clear the coordinator caches in case we were the leader. In the case of a reassignment, we
+      // cannot rely on the LeaderAndIsr API for this since it is only sent to active replicas.
       result.foreach { case (topicPartition, error) =>
-        if (error == Errors.NONE && stopReplicaRequest.deletePartitions && topicPartition.topic == GROUP_METADATA_TOPIC_NAME) {
-          groupCoordinator.handleGroupEmigration(topicPartition.partition)
+        if (error == Errors.NONE && stopReplicaRequest.deletePartitions) {
+          if (topicPartition.topic == GROUP_METADATA_TOPIC_NAME) {
+            groupCoordinator.onResignation(topicPartition.partition)
+          } else if (topicPartition.topic == TRANSACTION_STATE_TOPIC_NAME) {
+            // The StopReplica API does not pass through the leader epoch
+            txnCoordinator.onResignation(topicPartition.partition, coordinatorEpoch = None)
+          }
         }
       }
 
diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala
index 3031671..79611cb 100644
--- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala
+++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala
@@ -344,7 +344,7 @@ class TransactionCoordinatorConcurrencyTest extends AbstractCoordinatorConcurren
 
   class LoadTxnPartitionAction(txnTopicPartitionId: Int) extends Action {
     override def run(): Unit = {
-      transactionCoordinator.handleTxnImmigration(txnTopicPartitionId, coordinatorEpoch)
+      transactionCoordinator.onElection(txnTopicPartitionId, coordinatorEpoch)
     }
     override def await(): Unit = {
       allTransactions.foreach { txn =>
@@ -358,7 +358,7 @@ class TransactionCoordinatorConcurrencyTest extends AbstractCoordinatorConcurren
   class UnloadTxnPartitionAction(txnTopicPartitionId: Int) extends Action {
     val txnRecords: mutable.ArrayBuffer[SimpleRecord] = mutable.ArrayBuffer[SimpleRecord]()
     override def run(): Unit = {
-      transactionCoordinator.handleTxnEmigration(txnTopicPartitionId, coordinatorEpoch)
+      transactionCoordinator.onResignation(txnTopicPartitionId, Some(coordinatorEpoch))
     }
     override def await(): Unit = {
       allTransactions.foreach { txn =>
diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala
index 75f06d5..a637d16 100644
--- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala
+++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala
@@ -107,7 +107,7 @@ class TransactionCoordinatorTest {
       .andReturn(Right(None))
       .once()
 
-    EasyMock.expect(transactionManager.putTransactionStateIfNotExists(EasyMock.eq(transactionalId), EasyMock.capture(capturedTxn)))
+    EasyMock.expect(transactionManager.putTransactionStateIfNotExists(EasyMock.capture(capturedTxn)))
       .andAnswer(new IAnswer[Either[Errors, CoordinatorEpochAndTxnMetadata]] {
         override def answer(): Either[Errors, CoordinatorEpochAndTxnMetadata] = {
           assertTrue(capturedTxn.hasCaptured)
@@ -512,7 +512,7 @@ class TransactionCoordinatorTest {
     EasyMock.expect(transactionManager.validateTransactionTimeoutMs(EasyMock.anyInt()))
       .andReturn(true)
 
-    EasyMock.expect(transactionManager.putTransactionStateIfNotExists(EasyMock.eq(transactionalId), EasyMock.anyObject[TransactionMetadata]()))
+    EasyMock.expect(transactionManager.putTransactionStateIfNotExists(EasyMock.anyObject[TransactionMetadata]()))
       .andReturn(Right(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))
       .anyTimes()
 
@@ -551,7 +551,7 @@ class TransactionCoordinatorTest {
     EasyMock.expect(transactionManager.validateTransactionTimeoutMs(EasyMock.anyInt()))
       .andReturn(true)
 
-    EasyMock.expect(transactionManager.putTransactionStateIfNotExists(EasyMock.eq(transactionalId), EasyMock.anyObject[TransactionMetadata]()))
+    EasyMock.expect(transactionManager.putTransactionStateIfNotExists(EasyMock.anyObject[TransactionMetadata]()))
       .andReturn(Right(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))
       .anyTimes()
 
@@ -593,7 +593,7 @@ class TransactionCoordinatorTest {
     EasyMock.expect(transactionMarkerChannelManager.removeMarkersForTxnTopicPartition(0))
     EasyMock.replay(transactionManager, transactionMarkerChannelManager)
 
-    coordinator.handleTxnEmigration(0, coordinatorEpoch)
+    coordinator.onResignation(0, Some(coordinatorEpoch))
 
     EasyMock.verify(transactionManager, transactionMarkerChannelManager)
   }
diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala
index 5ab3f82..7f6ccfb 100644
--- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala
+++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala
@@ -18,12 +18,13 @@ package kafka.coordinator.transaction
 
 import java.lang.management.ManagementFactory
 import java.nio.ByteBuffer
+import java.util.concurrent.CountDownLatch
 import java.util.concurrent.locks.ReentrantLock
 
 import javax.management.ObjectName
 import kafka.log.{AppendOrigin, Log}
 import kafka.server.{FetchDataInfo, FetchLogEnd, LogOffsetMetadata, ReplicaManager}
-import kafka.utils.{MockScheduler, Pool}
+import kafka.utils.{MockScheduler, Pool, TestUtils}
 import kafka.zk.KafkaZkClient
 import org.apache.kafka.common.TopicPartition
 import org.apache.kafka.common.internals.Topic.TRANSACTION_STATE_TOPIC_NAME
@@ -104,11 +105,103 @@ class TransactionStateManagerTest {
 
     assertEquals(Right(None), transactionManager.getTransactionState(transactionalId1))
     assertEquals(Right(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1)),
-      transactionManager.putTransactionStateIfNotExists(transactionalId1, txnMetadata1))
+      transactionManager.putTransactionStateIfNotExists(txnMetadata1))
     assertEquals(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1))),
       transactionManager.getTransactionState(transactionalId1))
-    assertEquals(Right(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1)),
-      transactionManager.putTransactionStateIfNotExists(transactionalId1, txnMetadata2))
+    assertEquals(Right(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata2)),
+      transactionManager.putTransactionStateIfNotExists(txnMetadata2))
+  }
+
+  @Test
+  def testDeletePartition(): Unit = {
+    val metadata1 = transactionMetadata("b", 5L)
+    val metadata2 = transactionMetadata("a", 10L)
+
+    assertEquals(0, transactionManager.partitionFor(metadata1.transactionalId))
+    assertEquals(1, transactionManager.partitionFor(metadata2.transactionalId))
+
+    transactionManager.addLoadedTransactionsToCache(0, coordinatorEpoch, new Pool[String, TransactionMetadata]())
+    transactionManager.putTransactionStateIfNotExists(metadata1)
+
+    transactionManager.addLoadedTransactionsToCache(1, coordinatorEpoch, new Pool[String, TransactionMetadata]())
+    transactionManager.putTransactionStateIfNotExists(metadata2)
+
+    def cachedProducerEpoch(transactionalId: String): Option[Short] = {
+      transactionManager.getTransactionState(transactionalId).toOption.flatten
+        .map(_.transactionMetadata.producerEpoch)
+    }
+
+    assertEquals(Some(metadata1.producerEpoch), cachedProducerEpoch(metadata1.transactionalId))
+    assertEquals(Some(metadata2.producerEpoch), cachedProducerEpoch(metadata2.transactionalId))
+
+    transactionManager.removeTransactionsForTxnTopicPartition(0)
+
+    assertEquals(None, cachedProducerEpoch(metadata1.transactionalId))
+    assertEquals(Some(metadata2.producerEpoch), cachedProducerEpoch(metadata2.transactionalId))
+  }
+
+  @Test
+  def testDeleteLoadingPartition(): Unit = {
+    // Verify the handling of a call to delete state for a partition while it is in the
+    // process of being loaded. Basically, this should be treated as a no-op.
+
+    val startOffset = 0L
+    val endOffset = 1L
+
+    val fileRecordsMock = EasyMock.mock[FileRecords](classOf[FileRecords])
+    val logMock = EasyMock.mock[Log](classOf[Log])
+    EasyMock.expect(replicaManager.getLog(topicPartition)).andStubReturn(Some(logMock))
+    EasyMock.expect(logMock.logStartOffset).andStubReturn(startOffset)
+    EasyMock.expect(logMock.read(EasyMock.eq(startOffset),
+      maxLength = EasyMock.anyInt(),
+      isolation = EasyMock.eq(FetchLogEnd),
+      minOneMessage = EasyMock.eq(true))
+    ).andReturn(FetchDataInfo(LogOffsetMetadata(startOffset), fileRecordsMock))
+    EasyMock.expect(replicaManager.getLogEndOffset(topicPartition)).andStubReturn(Some(endOffset))
+
+    txnMetadata1.state = PrepareCommit
+    txnMetadata1.addPartitions(Set[TopicPartition](
+      new TopicPartition("topic1", 0),
+      new TopicPartition("topic1", 1)))
+    val records = MemoryRecords.withRecords(startOffset, CompressionType.NONE,
+      new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit())))
+
+    // We create a latch which is awaited while the log is loading. This ensures that the deletion
+    // is triggered before the loading returns
+    val latch = new CountDownLatch(1)
+
+    EasyMock.expect(fileRecordsMock.sizeInBytes()).andStubReturn(records.sizeInBytes)
+    val bufferCapture = EasyMock.newCapture[ByteBuffer]
+    fileRecordsMock.readInto(EasyMock.capture(bufferCapture), EasyMock.anyInt())
+    EasyMock.expectLastCall().andAnswer(new IAnswer[Unit] {
+      override def answer: Unit = {
+        latch.await()
+        val buffer = bufferCapture.getValue
+        buffer.put(records.buffer.duplicate)
+        buffer.flip()
+      }
+    })
+
+    EasyMock.replay(logMock, fileRecordsMock, replicaManager)
+
+    val coordinatorEpoch = 0
+    val partitionAndLeaderEpoch = TransactionPartitionAndLeaderEpoch(partitionId, coordinatorEpoch)
+
+    val loadingThread = new Thread(() => {
+      transactionManager.loadTransactionsForTxnTopicPartition(partitionId, coordinatorEpoch, (_, _, _, _, _) => ())
+    })
+    loadingThread.start()
+    TestUtils.waitUntilTrue(() => transactionManager.loadingPartitions.contains(partitionAndLeaderEpoch),
+      "Timed out waiting for loading partition", pause = 10)
+
+    transactionManager.removeTransactionsForTxnTopicPartition(partitionId)
+    assertFalse(transactionManager.loadingPartitions.contains(partitionAndLeaderEpoch))
+
+    latch.countDown()
+    loadingThread.join()
+
+    // Verify that transaction state was not loaded
+    assertEquals(Left(Errors.NOT_COORDINATOR), transactionManager.getTransactionState(txnMetadata1.transactionalId))
   }
 
   @Test
@@ -216,7 +309,7 @@ class TransactionStateManagerTest {
     transactionManager.addLoadedTransactionsToCache(partitionId, coordinatorEpoch, new Pool[String, TransactionMetadata]())
 
     // first insert the initial transaction metadata
-    transactionManager.putTransactionStateIfNotExists(transactionalId1, txnMetadata1)
+    transactionManager.putTransactionStateIfNotExists(txnMetadata1)
 
     prepareForTxnMessageAppend(Errors.NONE)
     expectedError = Errors.NONE
@@ -235,7 +328,7 @@ class TransactionStateManagerTest {
   @Test
   def testAppendFailToCoordinatorNotAvailableError(): Unit = {
     transactionManager.addLoadedTransactionsToCache(partitionId, coordinatorEpoch, new Pool[String, TransactionMetadata]())
-    transactionManager.putTransactionStateIfNotExists(transactionalId1, txnMetadata1)
+    transactionManager.putTransactionStateIfNotExists(txnMetadata1)
 
     expectedError = Errors.COORDINATOR_NOT_AVAILABLE
     var failedMetadata = txnMetadata1.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic2", 0)), time.milliseconds())
@@ -267,7 +360,7 @@ class TransactionStateManagerTest {
   @Test
   def testAppendFailToNotCoordinatorError(): Unit = {
     transactionManager.addLoadedTransactionsToCache(partitionId, coordinatorEpoch, new Pool[String, TransactionMetadata]())
-    transactionManager.putTransactionStateIfNotExists(transactionalId1, txnMetadata1)
+    transactionManager.putTransactionStateIfNotExists(txnMetadata1)
 
     expectedError = Errors.NOT_COORDINATOR
     var failedMetadata = txnMetadata1.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic2", 0)), time.milliseconds())
@@ -285,7 +378,7 @@ class TransactionStateManagerTest {
     prepareForTxnMessageAppend(Errors.NONE)
     transactionManager.removeTransactionsForTxnTopicPartition(partitionId, coordinatorEpoch)
     transactionManager.addLoadedTransactionsToCache(partitionId, coordinatorEpoch + 1, new Pool[String, TransactionMetadata]())
-    transactionManager.putTransactionStateIfNotExists(transactionalId1, txnMetadata1)
+    transactionManager.putTransactionStateIfNotExists(txnMetadata1)
     transactionManager.appendTransactionToLog(transactionalId1, coordinatorEpoch = 10, failedMetadata, assertCallback)
 
     prepareForTxnMessageAppend(Errors.NONE)
@@ -297,7 +390,7 @@ class TransactionStateManagerTest {
   @Test
   def testAppendFailToCoordinatorLoadingError(): Unit = {
     transactionManager.addLoadedTransactionsToCache(partitionId, coordinatorEpoch, new Pool[String, TransactionMetadata]())
-    transactionManager.putTransactionStateIfNotExists(transactionalId1, txnMetadata1)
+    transactionManager.putTransactionStateIfNotExists(txnMetadata1)
 
     expectedError = Errors.COORDINATOR_LOAD_IN_PROGRESS
     val failedMetadata = txnMetadata1.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic2", 0)), time.milliseconds())
@@ -311,7 +404,7 @@ class TransactionStateManagerTest {
   @Test
   def testAppendFailToUnknownError(): Unit = {
     transactionManager.addLoadedTransactionsToCache(partitionId, coordinatorEpoch, new Pool[String, TransactionMetadata]())
-    transactionManager.putTransactionStateIfNotExists(transactionalId1, txnMetadata1)
+    transactionManager.putTransactionStateIfNotExists(txnMetadata1)
 
     expectedError = Errors.UNKNOWN_SERVER_ERROR
     var failedMetadata = txnMetadata1.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic2", 0)), time.milliseconds())
@@ -331,7 +424,7 @@ class TransactionStateManagerTest {
   @Test
   def testPendingStateNotResetOnRetryAppend(): Unit = {
     transactionManager.addLoadedTransactionsToCache(partitionId, coordinatorEpoch, new Pool[String, TransactionMetadata]())
-    transactionManager.putTransactionStateIfNotExists(transactionalId1, txnMetadata1)
+    transactionManager.putTransactionStateIfNotExists(txnMetadata1)
 
     expectedError = Errors.COORDINATOR_NOT_AVAILABLE
     val failedMetadata = txnMetadata1.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic2", 0)), time.milliseconds())
@@ -347,7 +440,7 @@ class TransactionStateManagerTest {
     transactionManager.addLoadedTransactionsToCache(partitionId, 0, new Pool[String, TransactionMetadata]())
 
     // first insert the initial transaction metadata
-    transactionManager.putTransactionStateIfNotExists(transactionalId1, txnMetadata1)
+    transactionManager.putTransactionStateIfNotExists(txnMetadata1)
 
     prepareForTxnMessageAppend(Errors.NONE)
     expectedError = Errors.NOT_COORDINATOR
@@ -366,7 +459,7 @@ class TransactionStateManagerTest {
   def testAppendTransactionToLogWhilePendingStateChanged() = {
     // first insert the initial transaction metadata
     transactionManager.addLoadedTransactionsToCache(partitionId, coordinatorEpoch, new Pool[String, TransactionMetadata]())
-    transactionManager.putTransactionStateIfNotExists(transactionalId1, txnMetadata1)
+    transactionManager.putTransactionStateIfNotExists(txnMetadata1)
 
     prepareForTxnMessageAppend(Errors.NONE)
     expectedError = Errors.INVALID_PRODUCER_EPOCH
@@ -395,12 +488,12 @@ class TransactionStateManagerTest {
       transactionManager.addLoadedTransactionsToCache(partitionId, 0, new Pool[String, TransactionMetadata]())
     }
 
-    transactionManager.putTransactionStateIfNotExists("ongoing", transactionMetadata("ongoing", producerId = 0, state = Ongoing))
-    transactionManager.putTransactionStateIfNotExists("not-expiring", transactionMetadata("not-expiring", producerId = 1, state = Ongoing, txnTimeout = 10000))
-    transactionManager.putTransactionStateIfNotExists("prepare-commit", transactionMetadata("prepare-commit", producerId = 2, state = PrepareCommit))
-    transactionManager.putTransactionStateIfNotExists("prepare-abort", transactionMetadata("prepare-abort", producerId = 3, state = PrepareAbort))
-    transactionManager.putTransactionStateIfNotExists("complete-commit", transactionMetadata("complete-commit", producerId = 4, state = CompleteCommit))
-    transactionManager.putTransactionStateIfNotExists("complete-abort", transactionMetadata("complete-abort", producerId = 5, state = CompleteAbort))
+    transactionManager.putTransactionStateIfNotExists(transactionMetadata("ongoing", producerId = 0, state = Ongoing))
+    transactionManager.putTransactionStateIfNotExists(transactionMetadata("not-expiring", producerId = 1, state = Ongoing, txnTimeout = 10000))
+    transactionManager.putTransactionStateIfNotExists(transactionMetadata("prepare-commit", producerId = 2, state = PrepareCommit))
+    transactionManager.putTransactionStateIfNotExists(transactionMetadata("prepare-abort", producerId = 3, state = PrepareAbort))
+    transactionManager.putTransactionStateIfNotExists(transactionMetadata("complete-commit", producerId = 4, state = CompleteCommit))
+    transactionManager.putTransactionStateIfNotExists(transactionMetadata("complete-abort", producerId = 5, state = CompleteAbort))
 
     time.sleep(2000)
     val expiring = transactionManager.timedOutTransactions()
@@ -481,13 +574,11 @@ class TransactionStateManagerTest {
     // immigrate partition at epoch 0
     transactionManager.loadTransactionsForTxnTopicPartition(partitionId, coordinatorEpoch = 0, (_, _, _, _, _) => ())
     assertEquals(0, transactionManager.loadingPartitions.size)
-    assertEquals(0, transactionManager.leavingPartitions.size)
 
     // Re-immigrate partition at epoch 1. This should be successful even though we didn't get to emigrate the partition.
     prepareTxnLog(topicPartition, 0, records)
     transactionManager.loadTransactionsForTxnTopicPartition(partitionId, coordinatorEpoch = 1, (_, _, _, _, _) => ())
     assertEquals(0, transactionManager.loadingPartitions.size)
-    assertEquals(0, transactionManager.leavingPartitions.size)
     assertTrue(transactionManager.transactionMetadataCache.get(partitionId).isDefined)
     assertEquals(1, transactionManager.transactionMetadataCache.get(partitionId).get.coordinatorEpoch)
   }
@@ -578,10 +669,10 @@ class TransactionStateManagerTest {
 
     txnMetadata1.txnLastUpdateTimestamp = time.milliseconds() - txnConfig.transactionalIdExpirationMs
     txnMetadata1.state = txnState
-    transactionManager.putTransactionStateIfNotExists(transactionalId1, txnMetadata1)
+    transactionManager.putTransactionStateIfNotExists(txnMetadata1)
 
     txnMetadata2.txnLastUpdateTimestamp = time.milliseconds()
-    transactionManager.putTransactionStateIfNotExists(transactionalId2, txnMetadata2)
+    transactionManager.putTransactionStateIfNotExists(txnMetadata2)
 
     transactionManager.enableTransactionalIdExpiration()
     time.sleep(txnConfig.removeExpiredTransactionalIdsIntervalMs)
diff --git a/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala b/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala
index 773ea29..a3f58e9 100644
--- a/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala
+++ b/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala
@@ -36,6 +36,7 @@ import kafka.utils.{MockTime, TestUtils}
 import kafka.zk.KafkaZkClient
 import org.apache.kafka.common.TopicPartition
 import org.apache.kafka.common.errors.UnsupportedVersionException
+import org.apache.kafka.common.internals.Topic
 import org.apache.kafka.common.memory.MemoryPool
 import org.apache.kafka.common.metrics.Metrics
 import org.apache.kafka.common.network.ListenerName
@@ -59,7 +60,7 @@ import org.junit.Assert.{assertArrayEquals, assertEquals, assertNull, assertTrue
 import org.junit.{After, Test}
 
 import scala.collection.JavaConverters._
-import scala.collection.{Map, Seq}
+import scala.collection.{Map, Seq, mutable}
 
 class KafkaApisTest {
 
@@ -316,6 +317,40 @@ class KafkaApisTest {
   }
 
   @Test
+  def shouldResignCoordinatorsIfStopReplicaReceivedWithDeleteFlag(): Unit = {
+    val controllerId = 0
+    val controllerEpoch = 5
+    val brokerEpoch = 230498320L
+
+    val groupMetadataPartition = new TopicPartition(Topic.GROUP_METADATA_TOPIC_NAME, 0)
+    val txnStatePartition = new TopicPartition(Topic.TRANSACTION_STATE_TOPIC_NAME, 0)
+
+    val (_, request) = buildRequest(new StopReplicaRequest.Builder(
+      ApiKeys.STOP_REPLICA.latestVersion,
+      controllerId,
+      controllerEpoch,
+      brokerEpoch,
+      true,
+      Set(groupMetadataPartition, txnStatePartition).asJava))
+
+    EasyMock.expect(replicaManager.stopReplicas(anyObject())).andReturn(
+      (mutable.Map(groupMetadataPartition -> Errors.NONE, txnStatePartition -> Errors.NONE), Errors.NONE))
+    EasyMock.expect(controller.brokerEpoch).andStubReturn(brokerEpoch)
+
+    txnCoordinator.onResignation(txnStatePartition.partition, None)
+    EasyMock.expectLastCall()
+
+    groupCoordinator.onResignation(groupMetadataPartition.partition)
+    EasyMock.expectLastCall()
+
+    EasyMock.replay(controller, replicaManager, txnCoordinator, groupCoordinator)
+
+    createKafkaApis().handleStopReplicaRequest(request)
+
+    EasyMock.verify(txnCoordinator, groupCoordinator)
+  }
+
+  @Test
   def shouldRespondWithUnknownTopicOrPartitionForBadPartitionAndNoErrorsForGoodPartition(): Unit = {
     val tp1 = new TopicPartition("t", 0)
     val tp2 = new TopicPartition("t1", 0)