You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@kafka.apache.org by gu...@apache.org on 2017/10/27 06:26:36 UTC

kafka git commit: KAFKA-6119: Bump epoch when expiring transactions in the TransactionCoordinator

Repository: kafka
Updated Branches:
  refs/heads/trunk 69e8463c0 -> 501a5e262


KAFKA-6119: Bump epoch when expiring transactions in the TransactionCoordinator

A description of the problem is in the JIRA ticket (KAFKA-6119). I have added an integration test which reproduces the original scenario, and also added unit test cases.

Author: Apurva Mehta <ap...@confluent.io>

Reviewers: Jason Gustafson <ja...@confluent.io>, Ted Yu <yu...@gmail.com>, Guozhang Wang <wa...@gmail.com>

Closes #4137 from apurvam/KAFKA-6119-bump-epoch-when-expiring-transactions


Project: http://git-wip-us.apache.org/repos/asf/kafka/repo
Commit: http://git-wip-us.apache.org/repos/asf/kafka/commit/501a5e26
Tree: http://git-wip-us.apache.org/repos/asf/kafka/tree/501a5e26
Diff: http://git-wip-us.apache.org/repos/asf/kafka/diff/501a5e26

Branch: refs/heads/trunk
Commit: 501a5e262702bcc043724cb9e1f536e16a66399e
Parents: 69e8463
Author: Apurva Mehta <ap...@confluent.io>
Authored: Thu Oct 26 23:26:33 2017 -0700
Committer: Guozhang Wang <wa...@gmail.com>
Committed: Thu Oct 26 23:26:33 2017 -0700

----------------------------------------------------------------------
 .../transaction/TransactionCoordinator.scala    | 89 +++++++++++++++-----
 .../transaction/TransactionMetadata.scala       | 28 ++++--
 .../scala/kafka/tools/DumpLogSegments.scala     |  2 +-
 .../kafka/api/TransactionsTest.scala            | 46 +++++++++-
 .../TransactionCoordinatorTest.scala            | 10 ++-
 .../transaction/TransactionMetadataTest.scala   |  8 +-
 .../test/scala/unit/kafka/utils/TestUtils.scala | 14 ++-
 7 files changed, 156 insertions(+), 41 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/kafka/blob/501a5e26/core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala b/core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala
index 0b38dbc..b307a39 100644
--- a/core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala
+++ b/core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala
@@ -60,7 +60,7 @@ object TransactionCoordinator {
     val txnMarkerChannelManager = TransactionMarkerChannelManager(config, metrics, metadataCache, txnStateManager,
       txnMarkerPurgatory, time, logContext)
 
-    new TransactionCoordinator(config.brokerId, scheduler, producerIdManager, txnStateManager, txnMarkerChannelManager,
+    new TransactionCoordinator(config.brokerId, txnConfig, scheduler, producerIdManager, txnStateManager, txnMarkerChannelManager,
       time, logContext)
   }
 
@@ -82,6 +82,7 @@ object TransactionCoordinator {
  * Producers with no specific transactional id may talk to a random broker as their coordinators.
  */
 class TransactionCoordinator(brokerId: Int,
+                             txnConfig: TransactionConfig,
                              scheduler: Scheduler,
                              producerIdManager: ProducerIdManager,
                              txnManager: TransactionStateManager,
@@ -147,7 +148,7 @@ class TransactionCoordinator(brokerId: Int,
           responseCallback(initTransactionError(error))
 
         case Right((coordinatorEpoch, newMetadata)) =>
-          if (newMetadata.txnState == Ongoing) {
+          if (newMetadata.txnState == PrepareEpochFence) {
             // abort the ongoing transaction and then return CONCURRENT_TRANSACTIONS to let client wait and retry
             def sendRetriableErrorCallback(error: Errors): Unit = {
               if (error != Errors.NONE) {
@@ -212,7 +213,8 @@ class TransactionCoordinator(brokerId: Int,
           // particular, if fencing the current producer exhausts the available epochs for the current producerId,
           // then when the client retries, we will generate a new producerId.
           Right(coordinatorEpoch, txnMetadata.prepareFenceProducerEpoch())
-        case Dead =>
+
+        case Dead | PrepareEpochFence =>
           val errorMsg = s"Found transactionalId $transactionalId with state ${txnMetadata.state}. " +
             s"This is illegal as we should never have transitioned to this state."
           fatal(errorMsg)
@@ -307,9 +309,9 @@ class TransactionCoordinator(brokerId: Int,
           txnMetadata.inLock {
             if (txnMetadata.producerId != producerId)
               Left(Errors.INVALID_PRODUCER_ID_MAPPING)
-            else if (txnMetadata.producerEpoch != producerEpoch)
+            else if (producerEpoch < txnMetadata.producerEpoch)
               Left(Errors.INVALID_PRODUCER_EPOCH)
-            else if (txnMetadata.pendingTransitionInProgress)
+            else if (txnMetadata.pendingTransitionInProgress && txnMetadata.pendingState.get != PrepareEpochFence)
               Left(Errors.CONCURRENT_TRANSACTIONS)
             else txnMetadata.state match {
               case Ongoing =>
@@ -317,6 +319,15 @@ class TransactionCoordinator(brokerId: Int,
                   PrepareCommit
                 else
                   PrepareAbort
+
+                if (nextState == PrepareAbort && txnMetadata.pendingState.isDefined
+                  && txnMetadata.pendingState.get == PrepareEpochFence) {
+                  // We should clear the pending state to make way for the transition to PrepareAbort and also bump
+                  // the epoch in the transaction metadata we are about to append.
+                  txnMetadata.pendingState = None
+                  txnMetadata.producerEpoch = producerEpoch
+                }
+
                 Right(coordinatorEpoch, txnMetadata.prepareAbortOrCommit(nextState, time.milliseconds()))
               case CompleteCommit =>
                 if (txnMarkerResult == TransactionResult.COMMIT)
@@ -340,7 +351,7 @@ class TransactionCoordinator(brokerId: Int,
                   logInvalidStateTransitionAndReturnError(transactionalId, txnMetadata.state, txnMarkerResult)
               case Empty =>
                 logInvalidStateTransitionAndReturnError(transactionalId, txnMetadata.state, txnMarkerResult)
-              case Dead =>
+              case Dead | PrepareEpochFence =>
                 val errorMsg = s"Found transactionalId $transactionalId with state ${txnMetadata.state}. " +
                   s"This is illegal as we should never have transitioned to this state."
                 fatal(errorMsg)
@@ -388,7 +399,7 @@ class TransactionCoordinator(brokerId: Int,
                             logInvalidStateTransitionAndReturnError(transactionalId, txnMetadata.state, txnMarkerResult)
                           else
                             Right(txnMetadata, txnMetadata.prepareComplete(time.milliseconds()))
-                        case Dead =>
+                        case Dead | PrepareEpochFence =>
                           val errorMsg = s"Found transactionalId $transactionalId with state ${txnMetadata.state}. " +
                             s"This is illegal as we should never have transitioned to this state."
                           fatal(errorMsg)
@@ -434,20 +445,52 @@ class TransactionCoordinator(brokerId: Int,
 
   private def abortTimedOutTransactions(): Unit = {
     txnManager.timedOutTransactions().foreach { txnIdAndPidEpoch =>
-      handleEndTransaction(txnIdAndPidEpoch.transactionalId,
-        txnIdAndPidEpoch.producerId,
-        txnIdAndPidEpoch.producerEpoch,
-        TransactionResult.ABORT,
-        (error: Errors) => error match {
-          case Errors.NONE =>
-            info(s"Completed rollback ongoing transaction of transactionalId: ${txnIdAndPidEpoch.transactionalId} due to timeout")
-          case Errors.INVALID_PRODUCER_ID_MAPPING |
-               Errors.INVALID_PRODUCER_EPOCH |
-               Errors.CONCURRENT_TRANSACTIONS =>
-            debug(s"Rolling back ongoing transaction of transactionalId: ${txnIdAndPidEpoch.transactionalId} has aborted due to ${error.exceptionName()}")
-          case e =>
-            warn(s"Rolling back ongoing transaction of transactionalId: ${txnIdAndPidEpoch.transactionalId} failed due to ${error.exceptionName()}")
-        })
+      txnManager.getTransactionState(txnIdAndPidEpoch.transactionalId).right.flatMap {
+        case None =>
+          error(s"Could not find transaction metadata when trying to timeout transaction with transactionalId " +
+            s"${txnIdAndPidEpoch.transactionalId}. ProducerId: ${txnIdAndPidEpoch.producerId}. ProducerEpoch: " +
+            s"${txnIdAndPidEpoch.producerEpoch}")
+          Left(Errors.INVALID_TXN_STATE)
+
+        case Some(epochAndTxnMetadata) =>
+          val txnMetadata = epochAndTxnMetadata.transactionMetadata
+          val producerIdHasChanged = txnMetadata.inLock {
+            txnMetadata.producerId != txnIdAndPidEpoch.producerId
+          }
+          if (producerIdHasChanged) {
+            error(s"Found incorrect producerId when expiring transactionalId: ${txnIdAndPidEpoch.transactionalId}. " +
+              s"Expected producerId: ${txnIdAndPidEpoch.producerId}. Found producerId: " +
+              s"${epochAndTxnMetadata.transactionMetadata.producerId}")
+            Left(Errors.INVALID_PRODUCER_ID_MAPPING)
+          } else {
+            val transitMetadata: Either[Errors, TxnTransitMetadata] = txnMetadata.inLock {
+              if (txnMetadata.pendingTransitionInProgress)
+                Left(Errors.CONCURRENT_TRANSACTIONS)
+              else
+                Right(txnMetadata.prepareFenceProducerEpoch())
+            }
+            transitMetadata match {
+              case Right(txnTransitMetadata) =>
+                handleEndTransaction(txnMetadata.transactionalId,
+                  txnTransitMetadata.producerId,
+                  txnTransitMetadata.producerEpoch,
+                  TransactionResult.ABORT,
+                  {
+                    case Errors.NONE =>
+                      info(s"Completed rollback ongoing transaction of transactionalId: ${txnIdAndPidEpoch.transactionalId} due to timeout")
+                    case e @ (Errors.INVALID_PRODUCER_ID_MAPPING |
+                              Errors.INVALID_PRODUCER_EPOCH |
+                              Errors.CONCURRENT_TRANSACTIONS) =>
+                      debug(s"Rolling back ongoing transaction of transactionalId: ${txnIdAndPidEpoch.transactionalId} has aborted due to ${e.exceptionName}")
+                    case e =>
+                      warn(s"Rolling back ongoing transaction of transactionalId: ${txnIdAndPidEpoch.transactionalId} failed due to ${e.exceptionName}")
+                  })
+                Right(txnTransitMetadata)
+              case (error) =>
+                Left(error)
+            }
+         }
+      }
     }
   }
 
@@ -459,8 +502,8 @@ class TransactionCoordinator(brokerId: Int,
     scheduler.startup()
     scheduler.schedule("transaction-abort",
       abortTimedOutTransactions,
-      TransactionStateManager.DefaultAbortTimedOutTransactionsIntervalMs,
-      TransactionStateManager.DefaultAbortTimedOutTransactionsIntervalMs
+      txnConfig.abortTimedOutTransactionsIntervalMs,
+      txnConfig.abortTimedOutTransactionsIntervalMs
     )
     if (enableTransactionalIdExpiration)
       txnManager.enableTransactionalIdExpiration()

http://git-wip-us.apache.org/repos/asf/kafka/blob/501a5e26/core/src/main/scala/kafka/coordinator/transaction/TransactionMetadata.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/kafka/coordinator/transaction/TransactionMetadata.scala b/core/src/main/scala/kafka/coordinator/transaction/TransactionMetadata.scala
index 486a887..ea82fb5 100644
--- a/core/src/main/scala/kafka/coordinator/transaction/TransactionMetadata.scala
+++ b/core/src/main/scala/kafka/coordinator/transaction/TransactionMetadata.scala
@@ -77,6 +77,12 @@ private[transaction] case object CompleteAbort extends TransactionState { val by
   */
 private[transaction] case object Dead extends TransactionState { val byte: Byte = 6 }
 
+/**
+  * Pending-only state: we are in the middle of bumping the epoch and fencing out older producers.
+  */
+
+private[transaction] case object PrepareEpochFence extends TransactionState { val byte: Byte = 7}
+
 private[transaction] object TransactionMetadata {
   def apply(transactionalId: String, producerId: Long, producerEpoch: Short, txnTimeoutMs: Int, timestamp: Long) =
     new TransactionMetadata(transactionalId, producerId, producerEpoch, txnTimeoutMs, Empty,
@@ -96,6 +102,7 @@ private[transaction] object TransactionMetadata {
       case 4 => CompleteCommit
       case 5 => CompleteAbort
       case 6 => Dead
+      case 7 => PrepareEpochFence
       case unknown => throw new IllegalStateException("Unknown transaction state byte " + unknown + " from the transaction status message")
     }
   }
@@ -107,10 +114,12 @@ private[transaction] object TransactionMetadata {
     Map(Empty -> Set(Empty, CompleteCommit, CompleteAbort),
       Ongoing -> Set(Ongoing, Empty, CompleteCommit, CompleteAbort),
       PrepareCommit -> Set(Ongoing),
-      PrepareAbort -> Set(Ongoing),
+      PrepareAbort -> Set(Ongoing, PrepareEpochFence),
       CompleteCommit -> Set(PrepareCommit),
       CompleteAbort -> Set(PrepareAbort),
-      Dead -> Set(Empty, CompleteAbort, CompleteCommit))
+      Dead -> Set(Empty, CompleteAbort, CompleteCommit),
+      PrepareEpochFence -> Set(Ongoing)
+    )
 }
 
 // this is a immutable object representing the target transition of the transaction metadata
@@ -184,11 +193,8 @@ private[transaction] class TransactionMetadata(val transactionalId: String,
     if (producerEpoch == Short.MaxValue)
       throw new IllegalStateException(s"Cannot fence producer with epoch equal to Short.MaxValue since this would overflow")
 
-    // bump up the epoch to let the txn markers be able to override the current producer epoch
-    producerEpoch = (producerEpoch + 1).toShort
-
-    // do not call transitTo as it will set the pending state, a follow-up call to abort the transaction will set its pending state
-    TxnTransitMetadata(producerId, producerEpoch, txnTimeoutMs, state, topicPartitions.toSet, txnStartTimestamp, txnLastUpdateTimestamp)
+    prepareTransitionTo(PrepareEpochFence, producerId, (producerEpoch + 1).toShort, txnTimeoutMs, topicPartitions.toSet,
+      txnStartTimestamp, txnLastUpdateTimestamp)
   }
 
   def prepareIncrementProducerEpoch(newTxnTimeoutMs: Int, updateTimestamp: Long): TxnTransitMetadata = {
@@ -343,6 +349,14 @@ private[transaction] class TransactionMetadata(val transactionalId: String,
             topicPartitions.clear()
           }
 
+        case PrepareEpochFence =>
+          // We should never get here, since once we prepare to fence the epoch, we immediately set the pending state
+          // to PrepareAbort, and then consequently to CompleteAbort after the markers are written. So we should never
+          // ever try to complete a transition to PrepareEpochFence: it is only ever used as a pending state, so the
+          // current state of the transaction can never actually become PrepareEpochFence.
+          throwStateTransitionFailure(transitMetadata)
+
+
         case Dead =>
           // The transactionalId was being expired. The completion of the operation should result in removal of the
           // the metadata from the cache, so we should never realistically transition to the dead state.

http://git-wip-us.apache.org/repos/asf/kafka/blob/501a5e26/core/src/main/scala/kafka/tools/DumpLogSegments.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/kafka/tools/DumpLogSegments.scala b/core/src/main/scala/kafka/tools/DumpLogSegments.scala
index f0ea50c..fe82dc2 100755
--- a/core/src/main/scala/kafka/tools/DumpLogSegments.scala
+++ b/core/src/main/scala/kafka/tools/DumpLogSegments.scala
@@ -390,7 +390,7 @@ object DumpLogSegments {
             " compresscodec: " + batch.compressionType)
 
           if (batch.magic >= RecordBatch.MAGIC_VALUE_V2) {
-            print(" producerId: " + batch.producerId + " sequence: " + record.sequence +
+            print(" producerId: " + batch.producerId + " producerEpoch: " + batch.producerEpoch + " sequence: " + record.sequence +
               " isTransactional: " + batch.isTransactional +
               " headerKeys: " + record.headers.map(_.key).mkString("[", ",", "]"))
           } else {

http://git-wip-us.apache.org/repos/asf/kafka/blob/501a5e26/core/src/test/scala/integration/kafka/api/TransactionsTest.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/integration/kafka/api/TransactionsTest.scala b/core/src/test/scala/integration/kafka/api/TransactionsTest.scala
index 0fa0b87..3eee7f1 100644
--- a/core/src/test/scala/integration/kafka/api/TransactionsTest.scala
+++ b/core/src/test/scala/integration/kafka/api/TransactionsTest.scala
@@ -19,7 +19,7 @@ package kafka.api
 
 import java.lang.{Long => JLong}
 import java.util.Properties
-import java.util.concurrent.TimeUnit
+import java.util.concurrent.{ExecutionException, TimeUnit}
 
 import kafka.integration.KafkaServerTestHarness
 import kafka.server.KafkaConfig
@@ -456,6 +456,44 @@ class TransactionsTest extends KafkaServerTestHarness {
   }
 
   @Test
+  def testFencingOnTransactionExpiration(): Unit = {
+    val producer = createTransactionalProducer("expiringProducer", transactionTimeoutMs = 100)
+
+    producer.initTransactions()
+    producer.beginTransaction()
+
+    // The first message and hence the first AddPartitions request should be successfully sent.
+    val firstMessageResult = producer.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic1, "1", "1", willBeCommitted = false)).get()
+    assertTrue(firstMessageResult.hasOffset)
+
+    // Wait for the expiration cycle to kick in.
+    Thread.sleep(600)
+
+    try {
+      // Now that the transaction has expired, the second send should fail with a ProducerFencedException.
+      producer.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic2, "2", "2", willBeCommitted = false)).get()
+      fail("should have raised a ProducerFencedException since the transaction has expired")
+    } catch {
+      case _: ProducerFencedException =>
+      case e: ExecutionException =>
+      assertTrue(e.getCause.isInstanceOf[ProducerFencedException])
+    }
+
+    // Verify that the first message was aborted and the second one was never written at all.
+    val nonTransactionalConsumer = nonTransactionalConsumers(0)
+    nonTransactionalConsumer.subscribe(List(topic1).asJava)
+    val records = TestUtils.consumeRemainingRecords(nonTransactionalConsumer, 1000)
+    assertEquals(1, records.size)
+    assertEquals("1", TestUtils.recordValueAsString(records.head))
+
+    val transactionalConsumer = transactionalConsumers.head
+    transactionalConsumer.subscribe(List(topic1).asJava)
+
+    val transactionalRecords = TestUtils.consumeRemainingRecords(transactionalConsumer, 1000)
+    assertTrue(transactionalRecords.isEmpty)
+  }
+
+  @Test
   def testMultipleMarkersOneLeader(): Unit = {
     val firstProducer = transactionalProducers.head
     val consumer = transactionalConsumers.head
@@ -515,6 +553,7 @@ class TransactionsTest extends KafkaServerTestHarness {
     serverProps.put(KafkaConfig.UncleanLeaderElectionEnableProp, false.toString)
     serverProps.put(KafkaConfig.AutoLeaderRebalanceEnableProp, false.toString)
     serverProps.put(KafkaConfig.GroupInitialRebalanceDelayMsProp, "0")
+    serverProps.put(KafkaConfig.TransactionsAbortTimedOutTransactionCleanupIntervalMsProp, "200")
     serverProps
   }
 
@@ -539,8 +578,9 @@ class TransactionsTest extends KafkaServerTestHarness {
     consumer
   }
 
-  private def createTransactionalProducer(transactionalId: String): KafkaProducer[Array[Byte], Array[Byte]] = {
-    val producer = TestUtils.createTransactionalProducer(transactionalId, servers)
+  private def createTransactionalProducer(transactionalId: String, transactionTimeoutMs: Long = 60000): KafkaProducer[Array[Byte], Array[Byte]] = {
+    val producer = TestUtils.createTransactionalProducer(transactionalId, servers,
+      transactionTimeoutMs = transactionTimeoutMs)
     transactionalProducers += producer
     producer
   }

http://git-wip-us.apache.org/repos/asf/kafka/blob/501a5e26/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala
index dae52c8..75f06d5 100644
--- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala
+++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala
@@ -48,6 +48,7 @@ class TransactionCoordinatorTest {
   private val scheduler = new MockScheduler(time)
 
   val coordinator = new TransactionCoordinator(brokerId,
+    new TransactionConfig(),
     scheduler,
     pidManager,
     transactionManager,
@@ -598,7 +599,7 @@ class TransactionCoordinatorTest {
   }
 
   @Test
-  def shouldAbortExpiredTransactionsInOngoingState(): Unit = {
+  def shouldAbortExpiredTransactionsInOngoingStateAndBumpEpoch(): Unit = {
     val now = time.milliseconds()
     val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerEpoch, txnTimeoutMs, Ongoing,
       partitions, now, now)
@@ -608,9 +609,10 @@ class TransactionCoordinatorTest {
       .andReturn(List(TransactionalIdAndProducerIdEpoch(transactionalId, producerId, producerEpoch)))
     EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId)))
       .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))))
-      .once()
+      .times(2)
 
-    val expectedTransition = TxnTransitMetadata(producerId, producerEpoch, txnTimeoutMs, PrepareAbort,
+    val bumpedEpoch = (producerEpoch + 1).toShort
+    val expectedTransition = TxnTransitMetadata(producerId, bumpedEpoch, txnTimeoutMs, PrepareAbort,
       partitions.toSet, now, now + TransactionStateManager.DefaultAbortTimedOutTransactionsIntervalMs)
 
     EasyMock.expect(transactionManager.appendTransactionToLog(EasyMock.eq(transactionalId),
@@ -640,7 +642,7 @@ class TransactionCoordinatorTest {
     EasyMock.expect(transactionManager.timedOutTransactions())
       .andReturn(List(TransactionalIdAndProducerIdEpoch(transactionalId, producerId, producerEpoch)))
     EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId)))
-      .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, metadata))))
+      .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, metadata)))).once()
 
     EasyMock.replay(transactionManager, transactionMarkerChannelManager)
 

http://git-wip-us.apache.org/repos/asf/kafka/blob/501a5e26/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMetadataTest.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMetadataTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMetadataTest.scala
index 4f2fe5f..3a4390c 100644
--- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMetadataTest.scala
+++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMetadataTest.scala
@@ -104,8 +104,12 @@ class TransactionMetadataTest {
       txnLastUpdateTimestamp = time.milliseconds())
     assertTrue(txnMetadata.isProducerEpochExhausted)
 
-    txnMetadata.prepareFenceProducerEpoch()
-    assertEquals(Short.MaxValue, txnMetadata.producerEpoch)
+    val fencingTransitMetadata = txnMetadata.prepareFenceProducerEpoch()
+    assertEquals(Short.MaxValue, fencingTransitMetadata.producerEpoch)
+    assertEquals(Some(PrepareEpochFence), txnMetadata.pendingState)
+
+    // We should reset the pending state to make way for the abort transition.
+    txnMetadata.pendingState = None
 
     val transitMetadata = txnMetadata.prepareAbortOrCommit(PrepareAbort, time.milliseconds())
     txnMetadata.completeTransitionTo(transitMetadata)

http://git-wip-us.apache.org/repos/asf/kafka/blob/501a5e26/core/src/test/scala/unit/kafka/utils/TestUtils.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/unit/kafka/utils/TestUtils.scala b/core/src/test/scala/unit/kafka/utils/TestUtils.scala
index b8d0afb..974a493 100755
--- a/core/src/test/scala/unit/kafka/utils/TestUtils.scala
+++ b/core/src/test/scala/unit/kafka/utils/TestUtils.scala
@@ -1382,12 +1382,24 @@ object TestUtils extends Logging {
     records
   }
 
-  def createTransactionalProducer(transactionalId: String, servers: Seq[KafkaServer], batchSize: Int = 16384) = {
+  def consumeRemainingRecords[K, V](consumer: KafkaConsumer[K, V], timeout: Long): Seq[ConsumerRecord[K, V]] = {
+    val startTime = System.currentTimeMillis()
+    val records = new ArrayBuffer[ConsumerRecord[K, V]]()
+    waitUntilTrue(() => {
+      records ++= consumer.poll(50).asScala
+      System.currentTimeMillis() - startTime > timeout
+    }, s"The timeout $timeout was greater than the maximum wait time.")
+    records
+  }
+
+  def createTransactionalProducer(transactionalId: String, servers: Seq[KafkaServer], batchSize: Int = 16384,
+                                  transactionTimeoutMs: Long = 60000) = {
     val props = new Properties()
     props.put(ProducerConfig.TRANSACTIONAL_ID_CONFIG, transactionalId)
     props.put(ProducerConfig.MAX_IN_FLIGHT_REQUESTS_PER_CONNECTION, "5")
     props.put(ProducerConfig.ENABLE_IDEMPOTENCE_CONFIG, "true")
     props.put(ProducerConfig.BATCH_SIZE_CONFIG, batchSize.toString)
+    props.put(ProducerConfig.TRANSACTION_TIMEOUT_CONFIG, transactionTimeoutMs.toString)
     TestUtils.createNewProducer(TestUtils.getBrokerListStrFromServers(servers), retries = Integer.MAX_VALUE, acks = -1, props = Some(props))
   }