You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@kafka.apache.org by gu...@apache.org on 2017/05/25 18:01:13 UTC

kafka git commit: KAFKA-5279: TransactionCoordinator must expire transactionalIds

Repository: kafka
Updated Branches:
  refs/heads/trunk 64fc1a7ca -> 20e200878


KAFKA-5279: TransactionCoordinator must expire transactionalIds

remove transactions that have not been updated for at least `transactional.id.expiration.ms`

Author: Damian Guy <da...@gmail.com>

Reviewers: Apurva Mehta, Guozhang Wang

Closes #3101 from dguy/kafka-5279


Project: http://git-wip-us.apache.org/repos/asf/kafka/repo
Commit: http://git-wip-us.apache.org/repos/asf/kafka/commit/20e20087
Tree: http://git-wip-us.apache.org/repos/asf/kafka/tree/20e20087
Diff: http://git-wip-us.apache.org/repos/asf/kafka/diff/20e20087

Branch: refs/heads/trunk
Commit: 20e2008785d46aa0500b02d8737380c50d66da3b
Parents: 64fc1a7
Author: Damian Guy <da...@gmail.com>
Authored: Thu May 25 11:01:10 2017 -0700
Committer: Guozhang Wang <wa...@gmail.com>
Committed: Thu May 25 11:01:10 2017 -0700

----------------------------------------------------------------------
 .../transaction/TransactionCoordinator.scala    |  22 ++--
 .../transaction/TransactionMetadata.scala       |  16 ++-
 .../transaction/TransactionStateManager.scala   |  98 ++++++++++++++-
 .../main/scala/kafka/server/KafkaConfig.scala   |  16 ++-
 .../TransactionCoordinatorTest.scala            |  11 +-
 .../TransactionStateManagerTest.scala           | 119 ++++++++++++++++++-
 6 files changed, 252 insertions(+), 30 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/kafka/blob/20e20087/core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala b/core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala
index b31c0bc..f182420 100644
--- a/core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala
+++ b/core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala
@@ -45,7 +45,9 @@ object TransactionCoordinator {
       config.transactionTopicSegmentBytes,
       config.transactionsLoadBufferSize,
       config.transactionTopicMinISR,
-      config.transactionTransactionsExpiredTransactionCleanupIntervalMs)
+      config.transactionAbortTimedOutTransactionCleanupIntervalMs,
+      config.transactionRemoveExpiredTransactionalIdCleanupIntervalMs,
+      config.requestTimeoutMs)
 
     val producerIdManager = new ProducerIdManager(config.brokerId, zkUtils)
     // we do not need to turn on reaper thread since no tasks will be expired and there are no completed tasks to be purged
@@ -404,8 +406,8 @@ class TransactionCoordinator(brokerId: Int,
 
   def partitionFor(transactionalId: String): Int = txnManager.partitionFor(transactionalId)
 
-  private def expireTransactions(): Unit = {
-    txnManager.transactionsToExpire().foreach { txnIdAndPidEpoch =>
+  private def abortTimedOutTransactions(): Unit = {
+    txnManager.timedOutTransactions().foreach { txnIdAndPidEpoch =>
       handleEndTransaction(txnIdAndPidEpoch.transactionalId,
         txnIdAndPidEpoch.producerId,
         txnIdAndPidEpoch.producerEpoch,
@@ -426,16 +428,16 @@ class TransactionCoordinator(brokerId: Int,
   /**
    * Startup logic executed at the same time when the server starts up.
    */
-  def startup(enablePidExpiration: Boolean = true) {
+  def startup(enableTransactionalIdExpiration: Boolean = true) {
     info("Starting up.")
     scheduler.startup()
-    scheduler.schedule("transaction-expiration",
-      expireTransactions,
-      TransactionStateManager.DefaultRemoveExpiredTransactionsIntervalMs,
-      TransactionStateManager.DefaultRemoveExpiredTransactionsIntervalMs
+    scheduler.schedule("transaction-abort",
+      abortTimedOutTransactions,
+      TransactionStateManager.DefaultAbortTimedOutTransactionsIntervalMs,
+      TransactionStateManager.DefaultAbortTimedOutTransactionsIntervalMs
     )
-    if (enablePidExpiration)
-      txnManager.enableProducerIdExpiration()
+    if (enableTransactionalIdExpiration)
+      txnManager.enableTransactionalIdExpiration()
     txnMarkerChannelManager.start()
     isActive.set(true)
 

http://git-wip-us.apache.org/repos/asf/kafka/blob/20e20087/core/src/main/scala/kafka/coordinator/transaction/TransactionMetadata.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/kafka/coordinator/transaction/TransactionMetadata.scala b/core/src/main/scala/kafka/coordinator/transaction/TransactionMetadata.scala
index e1abf0e..5956f1d 100644
--- a/core/src/main/scala/kafka/coordinator/transaction/TransactionMetadata.scala
+++ b/core/src/main/scala/kafka/coordinator/transaction/TransactionMetadata.scala
@@ -69,6 +69,11 @@ private[transaction] case object CompleteCommit extends TransactionState { val b
  */
 private[transaction] case object CompleteAbort extends TransactionState { val byte: Byte = 5 }
 
+/**
+  * TransactionalId has expired and is about to be removed from the transaction cache
+  */
+private[transaction] case object Dead extends TransactionState { val byte: Byte = 6 }
+
 private[transaction] object TransactionMetadata {
   def apply(transactionalId: String, producerId: Long, producerEpoch: Short, txnTimeoutMs: Int, timestamp: Long) =
     new TransactionMetadata(transactionalId, producerId, producerEpoch, txnTimeoutMs, Empty,
@@ -87,6 +92,7 @@ private[transaction] object TransactionMetadata {
       case 3 => PrepareAbort
       case 4 => CompleteCommit
       case 5 => CompleteAbort
+      case 6 => Dead
       case unknown => throw new IllegalStateException("Unknown transaction state byte " + unknown + " from the transaction status message")
     }
   }
@@ -100,7 +106,8 @@ private[transaction] object TransactionMetadata {
       PrepareCommit -> Set(Ongoing),
       PrepareAbort -> Set(Ongoing),
       CompleteCommit -> Set(PrepareCommit),
-      CompleteAbort -> Set(PrepareAbort))
+      CompleteAbort -> Set(PrepareAbort),
+      Dead -> Set(Empty, CompleteAbort, CompleteCommit))
 }
 
 // this is a immutable object representing the target transition of the transaction metadata
@@ -141,7 +148,7 @@ private[transaction] class TransactionMetadata(val transactionalId: String,
                                                var state: TransactionState,
                                                val topicPartitions: mutable.Set[TopicPartition],
                                                @volatile var txnStartTimestamp: Long = -1,
-                                               var txnLastUpdateTimestamp: Long) extends Logging {
+                                               @volatile var txnLastUpdateTimestamp: Long) extends Logging {
 
   // pending state is used to indicate the state that this transaction is going to
   // transit to, and for blocking future attempts to transit it again if it is not legal;
@@ -207,6 +214,11 @@ private[transaction] class TransactionMetadata(val transactionalId: String,
     prepareTransitionTo(newState, producerEpoch, txnTimeoutMs, Set.empty[TopicPartition], txnStartTimestamp, updateTimestamp)
   }
 
+
+  def prepareDead : TxnTransitMetadata = {
+    prepareTransitionTo(Dead, producerEpoch, txnTimeoutMs, Set.empty[TopicPartition], txnStartTimestamp, txnLastUpdateTimestamp)
+  }
+
   private def prepareTransitionTo(newState: TransactionState,
                                   newEpoch: Short,
                                   newTxnTimeoutMs: Int,

http://git-wip-us.apache.org/repos/asf/kafka/blob/20e20087/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala b/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala
index e8f8e4d..0d7b5c4 100644
--- a/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala
+++ b/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala
@@ -25,6 +25,8 @@ import java.util.concurrent.locks.ReentrantReadWriteLock
 import kafka.common.KafkaException
 import kafka.log.LogConfig
 import kafka.message.UncompressedCodec
+import kafka.server.Defaults
+import kafka.utils.CoreUtils.inLock
 import kafka.server.ReplicaManager
 import kafka.utils.CoreUtils.{inReadLock, inWriteLock}
 import kafka.utils.{Logging, Pool, Scheduler, ZkUtils}
@@ -43,10 +45,10 @@ import scala.collection.JavaConverters._
 
 object TransactionStateManager {
   // default transaction management config values
-  // TODO: this needs to be replaces by the config values
   val DefaultTransactionsMaxTimeoutMs: Int = TimeUnit.MINUTES.toMillis(15).toInt
   val DefaultTransactionalIdExpirationMs: Int = TimeUnit.DAYS.toMillis(7).toInt
-  val DefaultRemoveExpiredTransactionsIntervalMs: Int = TimeUnit.MINUTES.toMillis(1).toInt
+  val DefaultAbortTimedOutTransactionsIntervalMs: Int = TimeUnit.MINUTES.toMillis(1).toInt
+  val DefaultRemoveExpiredTransactionalIdsIntervalMs: Int = TimeUnit.HOURS.toMillis(1).toInt
 }
 
 /**
@@ -89,7 +91,7 @@ class TransactionStateManager(brokerId: Int,
   // txn timeout value, we do not need to grab the lock on the metadata object upon checking its state
   // since the timestamp is volatile and we will get the lock when actually trying to transit the transaction
   // metadata to abort later.
-  def transactionsToExpire(): Iterable[TransactionalIdAndProducerIdEpoch] = {
+  def timedOutTransactions(): Iterable[TransactionalIdAndProducerIdEpoch] = {
     val now = time.milliseconds()
     inReadLock(stateLock) {
       transactionMetadataCache.filter { case (txnPartitionId, _) =>
@@ -112,8 +114,87 @@ class TransactionStateManager(brokerId: Int,
     }
   }
 
-  def enableProducerIdExpiration() {
-    // TODO: add producer id expiration logic
+
+
+  def enableTransactionalIdExpiration() {
+    scheduler.schedule("transactionalId-expiration", () => {
+      val now = time.milliseconds()
+      inReadLock(stateLock) {
+        val transactionalIdByPartition: Map[Int, mutable.Iterable[TransactionalIdCoordinatorEpochAndMetadata]] =
+          transactionMetadataCache.flatMap { case (partition, entry) =>
+            entry.metadataPerTransactionalId.filter { case (_, txnMetadata) => txnMetadata.state match {
+              case Empty | CompleteCommit | CompleteAbort => true
+              case _ => false
+            }
+            }.filter { case (_, txnMetadata) =>
+              txnMetadata.txnLastUpdateTimestamp <= now - config.transactionalIdExpirationMs
+            }.map { case (transactionalId, txnMetadata) =>
+              val txnMetadataTransition = txnMetadata synchronized {
+                txnMetadata.prepareDead
+              }
+              TransactionalIdCoordinatorEpochAndMetadata(transactionalId, entry.coordinatorEpoch, txnMetadataTransition)
+            }
+          }.groupBy { transactionalIdCoordinatorEpochAndMetadata =>
+            partitionFor(transactionalIdCoordinatorEpochAndMetadata.transactionalId)
+          }
+
+        val recordsPerPartition = transactionalIdByPartition
+          .map { case (partition, transactionalIdCoordinatorEpochAndMetadatas) =>
+            val deletes: Array[SimpleRecord] = transactionalIdCoordinatorEpochAndMetadatas.map { entry =>
+              new SimpleRecord(now, TransactionLog.keyToBytes(entry.transactionalId), null)
+            }.toArray
+            val records = MemoryRecords.withRecords(TransactionLog.EnforcedCompressionType, deletes: _*)
+            val topicPartition = new TopicPartition(Topic.TRANSACTION_STATE_TOPIC_NAME, partition)
+            (topicPartition, records)
+          }
+
+
+        def removeFromCacheCallback(responses: collection.Map[TopicPartition, PartitionResponse]): Unit = {
+          responses.foreach { case (topicPartition, response) =>
+            response.error match {
+              case Errors.NONE =>
+                inReadLock(stateLock) {
+                  val toRemove = transactionalIdByPartition(topicPartition.partition())
+                  transactionMetadataCache.get(topicPartition.partition)
+                    .foreach { txnMetadataCacheEntry =>
+                      toRemove.foreach { idCoordinatorEpochAndMetadata =>
+                        val txnMetadata = txnMetadataCacheEntry.metadataPerTransactionalId.get(idCoordinatorEpochAndMetadata.transactionalId)
+                        txnMetadata synchronized {
+                          if (txnMetadataCacheEntry.coordinatorEpoch == idCoordinatorEpochAndMetadata.coordinatorEpoch
+                            && txnMetadata.pendingState.contains(Dead)
+                            && txnMetadata.producerEpoch == idCoordinatorEpochAndMetadata.transitMetadata.producerEpoch
+                          )
+                            txnMetadataCacheEntry.metadataPerTransactionalId.remove(idCoordinatorEpochAndMetadata.transactionalId)
+                          else {
+                             debug(s"failed to remove expired transactionalId: ${idCoordinatorEpochAndMetadata.transactionalId}" +
+                               s" from cache. pendingState: ${txnMetadata.pendingState} producerEpoch: ${txnMetadata.producerEpoch}" +
+                               s" expected producerEpoch: ${idCoordinatorEpochAndMetadata.transitMetadata.producerEpoch}" +
+                               s" coordinatorEpoch: ${txnMetadataCacheEntry.coordinatorEpoch} expected coordinatorEpoch: " +
+                               s"${idCoordinatorEpochAndMetadata.coordinatorEpoch}")
+                            txnMetadata.pendingState = None
+                          }
+                        }
+                      }
+                    }
+                }
+              case _ =>
+                debug(s"writing transactionalId tombstones for partition: ${topicPartition.partition} failed with error: ${response.error.message()}")
+            }
+          }
+        }
+
+        replicaManager.appendRecords(
+          config.requestTimeoutMs,
+          TransactionLog.EnforcedRequiredAcks,
+          internalTopicsAllowed = true,
+          isFromClient = false,
+          recordsPerPartition,
+          removeFromCacheCallback,
+          None
+        )
+      }
+
+    }, delay = config.removeExpiredTransactionalIdsIntervalMs, period = config.removeExpiredTransactionalIdsIntervalMs)
   }
 
   /**
@@ -524,8 +605,13 @@ private[transaction] case class TransactionConfig(transactionalIdExpirationMs: I
                                                   transactionLogSegmentBytes: Int = TransactionLog.DefaultSegmentBytes,
                                                   transactionLogLoadBufferSize: Int = TransactionLog.DefaultLoadBufferSize,
                                                   transactionLogMinInsyncReplicas: Int = TransactionLog.DefaultMinInSyncReplicas,
-                                                  removeExpiredTransactionsIntervalMs: Int = TransactionStateManager.DefaultRemoveExpiredTransactionsIntervalMs)
+                                                  abortTimedOutTransactionsIntervalMs: Int = TransactionStateManager.DefaultAbortTimedOutTransactionsIntervalMs,
+                                                  removeExpiredTransactionalIdsIntervalMs: Int = TransactionStateManager.DefaultRemoveExpiredTransactionalIdsIntervalMs,
+                                                  requestTimeoutMs: Int = Defaults.RequestTimeoutMs)
 
 case class TransactionalIdAndProducerIdEpoch(transactionalId: String, producerId: Long, producerEpoch: Short)
 
 case class TransactionPartitionAndLeaderEpoch(txnPartitionId: Int, coordinatorEpoch: Int)
+case class TransactionalIdCoordinatorEpochAndMetadata(transactionalId: String,
+                                                      coordinatorEpoch: Int,
+                                                      transitMetadata: TxnTransitMetadata)

http://git-wip-us.apache.org/repos/asf/kafka/blob/20e20087/core/src/main/scala/kafka/server/KafkaConfig.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/kafka/server/KafkaConfig.scala b/core/src/main/scala/kafka/server/KafkaConfig.scala
index 99eddab..de036a7 100755
--- a/core/src/main/scala/kafka/server/KafkaConfig.scala
+++ b/core/src/main/scala/kafka/server/KafkaConfig.scala
@@ -165,7 +165,8 @@ object Defaults {
   val TransactionsTopicReplicationFactor = TransactionLog.DefaultReplicationFactor
   val TransactionsTopicPartitions = TransactionLog.DefaultNumPartitions
   val TransactionsTopicSegmentBytes = TransactionLog.DefaultSegmentBytes
-  val TransactionsExpiredTransactionCleanupIntervalMS = TransactionStateManager.DefaultRemoveExpiredTransactionsIntervalMs
+  val TransactionsAbortTimedOutTransactionsCleanupIntervalMS = TransactionStateManager.DefaultAbortTimedOutTransactionsIntervalMs
+  val TransactionsRemoveExpiredTransactionsCleanupIntervalMS = TransactionStateManager.DefaultRemoveExpiredTransactionalIdsIntervalMs
 
   /** ********* Quota Configuration ***********/
   val ProducerQuotaBytesPerSecondDefault = ClientQuotaManagerConfig.QuotaBytesPerSecondDefault
@@ -350,7 +351,8 @@ object KafkaConfig {
   val TransactionsTopicPartitionsProp = "transaction.state.log.num.partitions"
   val TransactionsTopicSegmentBytesProp = "transaction.state.log.segment.bytes"
   val TransactionsTopicReplicationFactorProp = "transaction.state.log.replication.factor"
-  val TransactionsExpiredTransactionCleanupIntervalMsProp = "transaction.expired.transaction.cleanup.interval.ms"
+  val TransactionsAbortTimedOutTransactionCleanupIntervalMsProp = "transaction.abort.timed.out.transaction.cleanup.interval.ms"
+  val TransactionsRemoveExpiredTransactionalIdCleanupIntervalMsProp = "transaction.remove.expired.transaction.cleanup.interval.ms"
 
   /** ********* Quota Configuration ***********/
   val ProducerQuotaBytesPerSecondDefaultProp = "quota.producer.default"
@@ -597,7 +599,8 @@ object KafkaConfig {
     "Internal topic creation will fail until the cluster size meets this replication factor requirement."
   val TransactionsTopicPartitionsDoc = "The number of partitions for the transaction topic (should not change after deployment)."
   val TransactionsTopicSegmentBytesDoc = "The transaction topic segment bytes should be kept relatively small in order to facilitate faster log compaction and cache loads"
-  val TransactionsExpiredTransactionCleanupIntervalMsDoc = "The interval at which to rollback expired transactions"
+  val TransactionsAbortTimedOutTransactionsIntervalMsDoc = "The interval at which to roll back transactions that have timed out"
+  val TransactionsRemoveExpiredTransactionsIntervalMsDoc = "The interval at which to remove transactions that have expired due to <code>transactional.id.expiration.ms</code> passing"
 
   /** ********* Quota Configuration ***********/
   val ProducerQuotaBytesPerSecondDefaultDoc = "DEPRECATED: Used only when dynamic default quotas are not configured for <user>, <client-id> or <user, client-id> in Zookeeper. " +
@@ -803,7 +806,8 @@ object KafkaConfig {
       .define(TransactionsTopicReplicationFactorProp, SHORT, Defaults.TransactionsTopicReplicationFactor, atLeast(1), HIGH, TransactionsTopicReplicationFactorDoc)
       .define(TransactionsTopicPartitionsProp, INT, Defaults.TransactionsTopicPartitions, atLeast(1), HIGH, TransactionsTopicPartitionsDoc)
       .define(TransactionsTopicSegmentBytesProp, INT, Defaults.TransactionsTopicSegmentBytes, atLeast(1), HIGH, TransactionsTopicSegmentBytesDoc)
-      .define(TransactionsExpiredTransactionCleanupIntervalMsProp, INT, Defaults.TransactionsExpiredTransactionCleanupIntervalMS, atLeast(1), LOW, TransactionsExpiredTransactionCleanupIntervalMsDoc)
+      .define(TransactionsAbortTimedOutTransactionCleanupIntervalMsProp, INT, Defaults.TransactionsAbortTimedOutTransactionsCleanupIntervalMS, atLeast(1), LOW, TransactionsAbortTimedOutTransactionsIntervalMsDoc)
+      .define(TransactionsRemoveExpiredTransactionalIdCleanupIntervalMsProp, INT, Defaults.TransactionsRemoveExpiredTransactionsCleanupIntervalMS, atLeast(1), LOW, TransactionsRemoveExpiredTransactionsIntervalMsDoc)
 
       /** ********* Kafka Metrics Configuration ***********/
       .define(MetricNumSamplesProp, INT, Defaults.MetricNumSamples, atLeast(1), LOW, MetricNumSamplesDoc)
@@ -1013,7 +1017,9 @@ class KafkaConfig(val props: java.util.Map[_, _], doLog: Boolean) extends Abstra
   val transactionTopicReplicationFactor = getShort(KafkaConfig.TransactionsTopicReplicationFactorProp)
   val transactionTopicPartitions = getInt(KafkaConfig.TransactionsTopicPartitionsProp)
   val transactionTopicSegmentBytes = getInt(KafkaConfig.TransactionsTopicSegmentBytesProp)
-  val transactionTransactionsExpiredTransactionCleanupIntervalMs = getInt(KafkaConfig.TransactionsExpiredTransactionCleanupIntervalMsProp)
+  val transactionAbortTimedOutTransactionCleanupIntervalMs = getInt(KafkaConfig.TransactionsAbortTimedOutTransactionCleanupIntervalMsProp)
+  val transactionRemoveExpiredTransactionalIdCleanupIntervalMs = getInt(KafkaConfig.TransactionsRemoveExpiredTransactionalIdCleanupIntervalMsProp)
+
 
   /** ********* Metric Configuration **************/
   val metricNumSamples = getInt(KafkaConfig.MetricNumSamplesProp)

http://git-wip-us.apache.org/repos/asf/kafka/blob/20e20087/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala
index e225588..4d953eb 100644
--- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala
+++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala
@@ -519,14 +519,15 @@ class TransactionCoordinatorTest {
     val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerEpoch, txnTimeoutMs, Ongoing,
       partitions, now, now)
 
-    EasyMock.expect(transactionManager.transactionsToExpire())
+
+    EasyMock.expect(transactionManager.timedOutTransactions())
       .andReturn(List(TransactionalIdAndProducerIdEpoch(transactionalId, producerId, producerEpoch)))
     EasyMock.expect(transactionManager.getAndMaybeAddTransactionState(EasyMock.eq(transactionalId), EasyMock.anyObject()))
       .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))))
       .once()
 
     val expectedTransition = TxnTransitMetadata(producerId, producerEpoch, txnTimeoutMs, PrepareAbort,
-      partitions.toSet, now, now + TransactionStateManager.DefaultRemoveExpiredTransactionsIntervalMs)
+      partitions.toSet, now, now + TransactionStateManager.DefaultAbortTimedOutTransactionsIntervalMs)
 
     EasyMock.expect(transactionManager.appendTransactionToLog(EasyMock.eq(transactionalId),
       EasyMock.eq(coordinatorEpoch),
@@ -540,7 +541,7 @@ class TransactionCoordinatorTest {
     EasyMock.replay(transactionManager, transactionMarkerChannelManager)
 
     coordinator.startup(false)
-    time.sleep(TransactionStateManager.DefaultRemoveExpiredTransactionsIntervalMs)
+    time.sleep(TransactionStateManager.DefaultAbortTimedOutTransactionsIntervalMs)
     scheduler.tick()
     EasyMock.verify(transactionManager)
   }
@@ -551,7 +552,7 @@ class TransactionCoordinatorTest {
       partitions, time.milliseconds(), time.milliseconds())
     metadata.prepareAbortOrCommit(PrepareCommit, time.milliseconds())
 
-    EasyMock.expect(transactionManager.transactionsToExpire())
+    EasyMock.expect(transactionManager.timedOutTransactions())
       .andReturn(List(TransactionalIdAndProducerIdEpoch(transactionalId, producerId, producerEpoch)))
     EasyMock.expect(transactionManager.getAndMaybeAddTransactionState(EasyMock.eq(transactionalId), EasyMock.anyObject[Option[TransactionMetadata]]()))
       .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, metadata))))
@@ -559,7 +560,7 @@ class TransactionCoordinatorTest {
     EasyMock.replay(transactionManager, transactionMarkerChannelManager)
 
     coordinator.startup(false)
-    time.sleep(TransactionStateManager.DefaultRemoveExpiredTransactionsIntervalMs)
+    time.sleep(TransactionStateManager.DefaultAbortTimedOutTransactionsIntervalMs)
     scheduler.tick()
     EasyMock.verify(transactionManager)
   }

http://git-wip-us.apache.org/repos/asf/kafka/blob/20e20087/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala
index 8682026..479f99b 100644
--- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala
+++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala
@@ -316,7 +316,7 @@ class TransactionStateManagerTest {
   }
 
   @Test
-  def shouldOnlyConsiderTransactionsInTheOngoingStateForExpiry(): Unit = {
+  def shouldOnlyConsiderTransactionsInTheOngoingStateToAbort(): Unit = {
     for (partitionId <- 0 until numPartitions) {
       transactionManager.addLoadedTransactionsToCache(partitionId, 0, new Pool[String, TransactionMetadata]())
     }
@@ -329,7 +329,7 @@ class TransactionStateManagerTest {
     transactionManager.getAndMaybeAddTransactionState("complete-abort", Some(transactionMetadata("complete-abort", producerId = 5, state = CompleteAbort)))
 
     time.sleep(2000)
-    val expiring = transactionManager.transactionsToExpire()
+    val expiring = transactionManager.timedOutTransactions()
     assertEquals(List(TransactionalIdAndProducerIdEpoch("ongoing", 0, 0)), expiring)
   }
 
@@ -343,6 +343,121 @@ class TransactionStateManagerTest {
     verifyWritesTxnMarkersInPrepareState(PrepareAbort)
   }
 
+  @Test
+  def shouldRemoveCompleteCommmitExpiredTransactionalIds(): Unit = {
+    setupAndRunTransactionalIdExpiration(Errors.NONE, CompleteCommit)
+    verifyMetadataDoesntExist(transactionalId1)
+    verifyMetadataDoesExist(transactionalId2)
+  }
+
+  @Test
+  def shouldRemoveCompleteAbortExpiredTransactionalIds(): Unit = {
+    setupAndRunTransactionalIdExpiration(Errors.NONE, CompleteAbort)
+    verifyMetadataDoesntExist(transactionalId1)
+    verifyMetadataDoesExist(transactionalId2)
+  }
+
+  @Test
+  def shouldRemoveEmptyExpiredTransactionalIds(): Unit = {
+    setupAndRunTransactionalIdExpiration(Errors.NONE, Empty)
+    verifyMetadataDoesntExist(transactionalId1)
+    verifyMetadataDoesExist(transactionalId2)
+  }
+
+  @Test
+  def shouldNotRemoveExpiredTransactionalIdsIfLogAppendFails(): Unit = {
+    setupAndRunTransactionalIdExpiration(Errors.NOT_ENOUGH_REPLICAS, CompleteAbort)
+    verifyMetadataDoesExist(transactionalId1)
+    verifyMetadataDoesExist(transactionalId2)
+  }
+
+  @Test
+  def shouldNotRemoveOngoingTransactionalIds(): Unit = {
+    setupAndRunTransactionalIdExpiration(Errors.NONE, Ongoing)
+    verifyMetadataDoesExist(transactionalId1)
+    verifyMetadataDoesExist(transactionalId2)
+  }
+
+  @Test
+  def shouldNotRemovePrepareAbortTransactionalIds(): Unit = {
+    setupAndRunTransactionalIdExpiration(Errors.NONE, PrepareAbort)
+    verifyMetadataDoesExist(transactionalId1)
+    verifyMetadataDoesExist(transactionalId2)
+  }
+
+  @Test
+  def shouldNotRemovePrepareCommitTransactionalIds(): Unit = {
+    setupAndRunTransactionalIdExpiration(Errors.NONE, PrepareCommit)
+    verifyMetadataDoesExist(transactionalId1)
+    verifyMetadataDoesExist(transactionalId2)
+  }
+
+  private def verifyMetadataDoesExist(transactionalId: String) = {
+    transactionManager.getAndMaybeAddTransactionState(transactionalId, None) match {
+      case Left(errors) => fail("shouldn't have been any errors")
+      case Right(None) => fail("metadata should have been removed")
+      case Right(Some(metadata)) => // ok
+    }
+  }
+
+  private def verifyMetadataDoesntExist(transactionalId: String) = {
+    transactionManager.getAndMaybeAddTransactionState(transactionalId, None) match {
+      case Left(errors) => fail("shouldn't have been any errors")
+      case Right(Some(metdata)) => fail("metadata should have been removed")
+      case Right(None) => // ok
+    }
+  }
+
+  private def setupAndRunTransactionalIdExpiration(error: Errors, txnState: TransactionState) = {
+    for (partitionId <- 0 until numPartitions) {
+      transactionManager.addLoadedTransactionsToCache(partitionId, 0, new Pool[String, TransactionMetadata]())
+    }
+
+    val capturedArgument: Capture[Map[TopicPartition, PartitionResponse] => Unit] = EasyMock.newCapture()
+
+    val partition = new TopicPartition(TRANSACTION_STATE_TOPIC_NAME, transactionManager.partitionFor(transactionalId1))
+    val recordsByPartition = Map(partition -> MemoryRecords.withRecords(TransactionLog.EnforcedCompressionType,
+      new SimpleRecord(time.milliseconds() + txnConfig.removeExpiredTransactionalIdsIntervalMs, TransactionLog.keyToBytes(transactionalId1), null)))
+
+    txnState match {
+      case Empty | CompleteCommit | CompleteAbort =>
+
+        EasyMock.expect(replicaManager.appendRecords(EasyMock.anyLong(),
+          EasyMock.eq((-1).toShort),
+          EasyMock.eq(true),
+          EasyMock.eq(false),
+          EasyMock.eq(recordsByPartition),
+          EasyMock.capture(capturedArgument),
+          EasyMock.eq(None)
+        )).andAnswer(new IAnswer[Unit] {
+          override def answer(): Unit = {
+            capturedArgument.getValue.apply(
+              Map(partition ->
+                new PartitionResponse(error, 0L, RecordBatch.NO_TIMESTAMP)
+              )
+            )
+          }
+        })
+      case _ => // shouldn't append
+    }
+
+    EasyMock.replay(replicaManager)
+
+    txnMetadata1.txnLastUpdateTimestamp = time.milliseconds() - txnConfig.transactionalIdExpirationMs
+    txnMetadata1.state = txnState
+    transactionManager.getAndMaybeAddTransactionState(transactionalId1, Some(txnMetadata1))
+
+    txnMetadata2.txnLastUpdateTimestamp = time.milliseconds()
+    transactionManager.getAndMaybeAddTransactionState(transactionalId2, Some(txnMetadata2))
+
+    transactionManager.enableTransactionalIdExpiration()
+    time.sleep(txnConfig.removeExpiredTransactionalIdsIntervalMs)
+
+    scheduler.tick()
+
+    EasyMock.verify(replicaManager)
+  }
+
   private def verifyWritesTxnMarkersInPrepareState(state: TransactionState): Unit = {
     txnMetadata1.state = state
     txnMetadata1.addPartitions(Set[TopicPartition](new TopicPartition("topic1", 0),