You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@kafka.apache.org by jg...@apache.org on 2017/05/23 16:57:03 UTC

kafka git commit: MINOR: Log transaction metadata state transitions plus a few cleanups

Repository: kafka
Updated Branches:
  refs/heads/trunk 5a6676bfc -> 70ec4b1d9


MINOR: Log transaction metadata state transitions plus a few cleanups

Author: Jason Gustafson <ja...@confluent.io>

Reviewers: Guozhang Wang <wa...@gmail.com>

Closes #3081 from hachikuji/minor-add-txn-transition-logging


Project: http://git-wip-us.apache.org/repos/asf/kafka/repo
Commit: http://git-wip-us.apache.org/repos/asf/kafka/commit/70ec4b1d
Tree: http://git-wip-us.apache.org/repos/asf/kafka/tree/70ec4b1d
Diff: http://git-wip-us.apache.org/repos/asf/kafka/diff/70ec4b1d

Branch: refs/heads/trunk
Commit: 70ec4b1d927cad7373eda2c54ef44bfc4275832f
Parents: 5a6676b
Author: Jason Gustafson <ja...@confluent.io>
Authored: Tue May 23 09:53:18 2017 -0700
Committer: Jason Gustafson <ja...@confluent.io>
Committed: Tue May 23 09:53:18 2017 -0700

----------------------------------------------------------------------
 .../producer/internals/TransactionManager.java  |   4 +-
 .../transaction/TransactionCoordinator.scala    |  22 +--
 .../transaction/TransactionLog.scala            |  41 ++---
 .../TransactionMarkerChannelManager.scala       |   5 +-
 .../transaction/TransactionMetadata.scala       | 160 +++++++++++--------
 .../transaction/TransactionStateManager.scala   |  41 ++---
 .../TransactionCoordinatorTest.scala            | 119 ++++++++------
 .../transaction/TransactionLogTest.scala        |  42 +++--
 .../TransactionMarkerChannelManagerTest.scala   |  15 +-
 ...tionMarkerRequestCompletionHandlerTest.scala |   3 +-
 .../TransactionStateManagerTest.scala           | 133 +++++++--------
 11 files changed, 306 insertions(+), 279 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/kafka/blob/70ec4b1d/clients/src/main/java/org/apache/kafka/clients/producer/internals/TransactionManager.java
----------------------------------------------------------------------
diff --git a/clients/src/main/java/org/apache/kafka/clients/producer/internals/TransactionManager.java b/clients/src/main/java/org/apache/kafka/clients/producer/internals/TransactionManager.java
index c6787f2..d84a88e 100644
--- a/clients/src/main/java/org/apache/kafka/clients/producer/internals/TransactionManager.java
+++ b/clients/src/main/java/org/apache/kafka/clients/producer/internals/TransactionManager.java
@@ -407,10 +407,10 @@ public class TransactionManager {
     }
 
     private synchronized void transitionTo(State target, Exception error) {
-        if (target == State.ERROR && error != null)
-            lastError = error;
         if (currentState.isTransitionValid(currentState, target)) {
             currentState = target;
+            if (target == State.ERROR && error != null)
+                lastError = error;
         } else {
             throw new KafkaException("Invalid transition attempted from state " + currentState.name() +
                     " to state " + target.name());

http://git-wip-us.apache.org/repos/asf/kafka/blob/70ec4b1d/core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala b/core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala
index 491e16a..b58c710 100644
--- a/core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala
+++ b/core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala
@@ -59,7 +59,7 @@ object TransactionCoordinator {
     InitProducerIdResult(RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_EPOCH, error)
   }
 
-  private def initTransactionMetadata(txnMetadata: TransactionMetadataTransition): InitProducerIdResult = {
+  private def initTransactionMetadata(txnMetadata: TxnTransitMetadata): InitProducerIdResult = {
     InitProducerIdResult(txnMetadata.producerId, txnMetadata.producerEpoch, Errors.NONE)
   }
 }
@@ -113,18 +113,20 @@ class TransactionCoordinator(brokerId: Int,
       responseCallback(initTransactionError(Errors.INVALID_TRANSACTION_TIMEOUT))
     } else {
       // only try to get a new producerId and update the cache if the transactional id is unknown
-      val result: Either[InitProducerIdResult, (Int, TransactionMetadataTransition)] = txnManager.getTransactionState(transactionalId) match {
+      val result: Either[InitProducerIdResult, (Int, TxnTransitMetadata)] = txnManager.getTransactionState(transactionalId) match {
         case None =>
           val producerId = producerIdManager.generateProducerId()
           val now = time.milliseconds()
-          val createdMetadata = new TransactionMetadata(producerId = producerId,
+          val createdMetadata = new TransactionMetadata(
+            transactionalId = transactionalId,
+            producerId = producerId,
             producerEpoch = 0,
             txnTimeoutMs = transactionTimeoutMs,
             state = Empty,
             topicPartitions = collection.mutable.Set.empty[TopicPartition],
             txnLastUpdateTimestamp = now)
 
-          val epochAndMetadata = txnManager.addTransaction(transactionalId, createdMetadata)
+          val epochAndMetadata = txnManager.addTransaction(createdMetadata)
           val coordinatorEpoch = epochAndMetadata.coordinatorEpoch
           val txnMetadata = epochAndMetadata.transactionMetadata
 
@@ -135,7 +137,7 @@ class TransactionCoordinator(brokerId: Int,
             if (!txnMetadata.eq(createdMetadata)) {
               initProducerIdWithExistingMetadata(transactionalId, transactionTimeoutMs, coordinatorEpoch, txnMetadata)
             } else {
-              Right(coordinatorEpoch, txnMetadata.prepareNewPid(time.milliseconds()))
+              Right(coordinatorEpoch, txnMetadata.prepareNewProducerId(time.milliseconds()))
             }
           }
 
@@ -185,7 +187,7 @@ class TransactionCoordinator(brokerId: Int,
   private def initProducerIdWithExistingMetadata(transactionalId: String,
                                                  transactionTimeoutMs: Int,
                                                  coordinatorEpoch: Int,
-                                                 txnMetadata: TransactionMetadata): Either[InitProducerIdResult, (Int, TransactionMetadataTransition)] = {
+                                                 txnMetadata: TransactionMetadata): Either[InitProducerIdResult, (Int, TxnTransitMetadata)] = {
     if (txnMetadata.pendingTransitionInProgress) {
       // return a retriable exception to let the client backoff and retry
       Left(initTransactionError(Errors.CONCURRENT_TRANSACTIONS))
@@ -229,7 +231,7 @@ class TransactionCoordinator(brokerId: Int,
     } else {
       // try to update the transaction metadata and append the updated metadata to txn log;
       // if there is no such metadata treat it as invalid producerId mapping error.
-      val result: Either[Errors, (Int, TransactionMetadataTransition)] = txnManager.getTransactionState(transactionalId) match {
+      val result: Either[Errors, (Int, TxnTransitMetadata)] = txnManager.getTransactionState(transactionalId) match {
         case None =>
           Left(Errors.INVALID_PRODUCER_ID_MAPPING)
 
@@ -293,7 +295,7 @@ class TransactionCoordinator(brokerId: Int,
     if (error != Errors.NONE)
       responseCallback(error)
     else {
-      val preAppendResult: Either[Errors, (Int, TransactionMetadataTransition)] = txnManager.getTransactionState(transactionalId) match {
+      val preAppendResult: Either[Errors, (Int, TxnTransitMetadata)] = txnManager.getTransactionState(transactionalId) match {
         case None =>
           Left(Errors.INVALID_PRODUCER_ID_MAPPING)
 
@@ -348,7 +350,7 @@ class TransactionCoordinator(brokerId: Int,
         case Right((coordinatorEpoch, newMetadata)) =>
           def sendTxnMarkersCallback(error: Errors): Unit = {
             if (error == Errors.NONE) {
-              val preSendResult: Either[Errors, (TransactionMetadata, TransactionMetadataTransition)] = txnManager.getTransactionState(transactionalId) match {
+              val preSendResult: Either[Errors, (TransactionMetadata, TxnTransitMetadata)] = txnManager.getTransactionState(transactionalId) match {
                 case Some(epochAndMetadata) =>
                   if (epochAndMetadata.coordinatorEpoch == coordinatorEpoch) {
 
@@ -447,7 +449,7 @@ class TransactionCoordinator(brokerId: Int,
       TransactionStateManager.DefaultRemoveExpiredTransactionsIntervalMs
     )
     if (enablePidExpiration)
-      txnManager.enablePidExpiration()
+      txnManager.enableProducerIdExpiration()
     txnMarkerChannelManager.start()
     isActive.set(true)
 

http://git-wip-us.apache.org/repos/asf/kafka/blob/70ec4b1d/core/src/main/scala/kafka/coordinator/transaction/TransactionLog.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/kafka/coordinator/transaction/TransactionLog.scala b/core/src/main/scala/kafka/coordinator/transaction/TransactionLog.scala
index d0c9e87..efd315b 100644
--- a/core/src/main/scala/kafka/coordinator/transaction/TransactionLog.scala
+++ b/core/src/main/scala/kafka/coordinator/transaction/TransactionLog.scala
@@ -146,7 +146,7 @@ object TransactionLog {
     *
     * @return value payload bytes
     */
-  private[coordinator] def valueToBytes(txnMetadata: TransactionMetadataTransition): Array[Byte] = {
+  private[coordinator] def valueToBytes(txnMetadata: TxnTransitMetadata): Array[Byte] = {
     import ValueSchema._
     val value = new Struct(Current)
     value.set(ProducerIdField, txnMetadata.producerId)
@@ -168,10 +168,8 @@ object TransactionLog {
       val partitionArray = topicAndPartitions.map { case(topic, partitions) =>
         val topicPartitionsStruct = value.instance(TxnPartitionsField)
         val partitionIds: Array[Integer] = partitions.map(topicPartition => Integer.valueOf(topicPartition.partition())).toArray
-
         topicPartitionsStruct.set(PartitionsTopicField, topic)
         topicPartitionsStruct.set(PartitionIdsField, partitionIds)
-
         topicPartitionsStruct
       }
       value.set(TxnPartitionsField, partitionArray.toArray)
@@ -188,14 +186,13 @@ object TransactionLog {
     *
     * @return the key
     */
-  def readMessageKey(buffer: ByteBuffer): BaseKey = {
+  def readTxnRecordKey(buffer: ByteBuffer): TxnKey = {
     val version = buffer.getShort
     val keySchema = schemaForKey(version)
     val key = keySchema.read(buffer)
 
     if (version == KeySchema.CURRENT_VERSION) {
       val transactionalId = key.getString(KeySchema.TXN_ID_FIELD)
-
       TxnKey(version, transactionalId)
     } else {
       throw new IllegalStateException(s"Unknown version $version from the transaction log message")
@@ -207,7 +204,7 @@ object TransactionLog {
     *
     * @return a transaction metadata object from the message
     */
-  def readMessageValue(buffer: ByteBuffer): TransactionMetadata = {
+  def readTxnRecordValue(transactionalId: String, buffer: ByteBuffer): TransactionMetadata = {
     if (buffer == null) { // tombstone
       null
     } else {
@@ -226,7 +223,8 @@ object TransactionLog {
         val entryTimestamp = value.getLong(TxnEntryTimestampField)
         val startTimestamp = value.getLong(TxnStartTimestampField)
 
-        val transactionMetadata = new TransactionMetadata(producerId, epoch, timeout, state, mutable.Set.empty[TopicPartition],startTimestamp, entryTimestamp)
+        val transactionMetadata = new TransactionMetadata(transactionalId, producerId, epoch, timeout, state,
+          mutable.Set.empty[TopicPartition],startTimestamp, entryTimestamp)
 
         if (!state.equals(Empty)) {
           val topicPartitionArray = value.getArray(TxnPartitionsField)
@@ -255,28 +253,21 @@ object TransactionLog {
   // Formatter for use with tools to read transaction log messages
   class TransactionLogMessageFormatter extends MessageFormatter {
     def writeTo(consumerRecord: ConsumerRecord[Array[Byte], Array[Byte]], output: PrintStream) {
-      Option(consumerRecord.key).map(key => readMessageKey(ByteBuffer.wrap(key))).foreach {
-        case txnKey: TxnKey =>
-          val transactionalId = txnKey.transactionalId
-          val value = consumerRecord.value
-          val producerIdMetadata =
-            if (value == null) "NULL"
-            else readMessageValue(ByteBuffer.wrap(value))
-          output.write(transactionalId.getBytes(StandardCharsets.UTF_8))
-          output.write("::".getBytes(StandardCharsets.UTF_8))
-          output.write(producerIdMetadata.toString.getBytes(StandardCharsets.UTF_8))
-          output.write("\n".getBytes(StandardCharsets.UTF_8))
-        case _ => // no-op
+      Option(consumerRecord.key).map(key => readTxnRecordKey(ByteBuffer.wrap(key))).foreach { txnKey =>
+        val transactionalId = txnKey.transactionalId
+        val value = consumerRecord.value
+        val producerIdMetadata =
+          if (value == null) "NULL"
+          else readTxnRecordValue(transactionalId, ByteBuffer.wrap(value))
+        output.write(transactionalId.getBytes(StandardCharsets.UTF_8))
+        output.write("::".getBytes(StandardCharsets.UTF_8))
+        output.write(producerIdMetadata.toString.getBytes(StandardCharsets.UTF_8))
+        output.write("\n".getBytes(StandardCharsets.UTF_8))
       }
     }
   }
 }
 
-sealed trait BaseKey {
-  def version: Short
-  def transactionalId: Any
-}
-
-case class TxnKey(version: Short, transactionalId: String) extends BaseKey {
+case class TxnKey(version: Short, transactionalId: String) {
   override def toString: String = transactionalId.toString
 }

http://git-wip-us.apache.org/repos/asf/kafka/blob/70ec4b1d/core/src/main/scala/kafka/coordinator/transaction/TransactionMarkerChannelManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/kafka/coordinator/transaction/TransactionMarkerChannelManager.scala b/core/src/main/scala/kafka/coordinator/transaction/TransactionMarkerChannelManager.scala
index 7c42574..b25e82d 100644
--- a/core/src/main/scala/kafka/coordinator/transaction/TransactionMarkerChannelManager.scala
+++ b/core/src/main/scala/kafka/coordinator/transaction/TransactionMarkerChannelManager.scala
@@ -85,9 +85,6 @@ object TransactionMarkerChannelManager {
       time)
   }
 
-  private[transaction] def requestGenerator(transactionMarkerChannelManager: TransactionMarkerChannelManager): () => Iterable[RequestAndCompletionHandler] = {
-    () => transactionMarkerChannelManager.drainQueuedTransactionMarkers()
-  }
 }
 
 class TxnMarkerQueue(@volatile private var destination: Node) {
@@ -191,7 +188,7 @@ class TransactionMarkerChannelManager(config: KafkaConfig,
                           coordinatorEpoch: Int,
                           txnResult: TransactionResult,
                           txnMetadata: TransactionMetadata,
-                          newMetadata: TransactionMetadataTransition): Unit = {
+                          newMetadata: TxnTransitMetadata): Unit = {
 
     def appendToLogCallback(error: Errors): Unit = {
       error match {

http://git-wip-us.apache.org/repos/asf/kafka/blob/70ec4b1d/core/src/main/scala/kafka/coordinator/transaction/TransactionMetadata.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/kafka/coordinator/transaction/TransactionMetadata.scala b/core/src/main/scala/kafka/coordinator/transaction/TransactionMetadata.scala
index 6e29308..fef395a 100644
--- a/core/src/main/scala/kafka/coordinator/transaction/TransactionMetadata.scala
+++ b/core/src/main/scala/kafka/coordinator/transaction/TransactionMetadata.scala
@@ -16,7 +16,7 @@
  */
 package kafka.coordinator.transaction
 
-import kafka.utils.nonthreadsafe
+import kafka.utils.{Logging, nonthreadsafe}
 import org.apache.kafka.common.TopicPartition
 
 import scala.collection.{immutable, mutable}
@@ -70,9 +70,14 @@ private[transaction] case object CompleteCommit extends TransactionState { val b
 private[transaction] case object CompleteAbort extends TransactionState { val byte: Byte = 5 }
 
 private[transaction] object TransactionMetadata {
-  def apply(producerId: Long, epoch: Short, txnTimeoutMs: Int, timestamp: Long) = new TransactionMetadata(producerId, epoch, txnTimeoutMs, Empty, collection.mutable.Set.empty[TopicPartition], timestamp, timestamp)
+  def apply(transactionalId: String, producerId: Long, producerEpoch: Short, txnTimeoutMs: Int, timestamp: Long) =
+    new TransactionMetadata(transactionalId, producerId, producerEpoch, txnTimeoutMs, Empty,
+      collection.mutable.Set.empty[TopicPartition], timestamp, timestamp)
 
-  def apply(producerId: Long, epoch: Short, txnTimeoutMs: Int, state: TransactionState, timestamp: Long) = new TransactionMetadata(producerId, epoch, txnTimeoutMs, state, collection.mutable.Set.empty[TopicPartition], timestamp, timestamp)
+  def apply(transactionalId: String, producerId: Long, producerEpoch: Short, txnTimeoutMs: Int,
+            state: TransactionState, timestamp: Long) =
+    new TransactionMetadata(transactionalId, producerId, producerEpoch, txnTimeoutMs, state,
+      collection.mutable.Set.empty[TopicPartition], timestamp, timestamp)
 
   def byteToState(byte: Byte): TransactionState = {
     byte match {
@@ -86,7 +91,8 @@ private[transaction] object TransactionMetadata {
     }
   }
 
-  def isValidTransition(oldState: TransactionState, newState: TransactionState): Boolean = TransactionMetadata.validPreviousStates(newState).contains(oldState)
+  def isValidTransition(oldState: TransactionState, newState: TransactionState): Boolean =
+    TransactionMetadata.validPreviousStates(newState).contains(oldState)
 
   private val validPreviousStates: Map[TransactionState, Set[TransactionState]] =
     Map(Empty -> Set(Empty, CompleteCommit, CompleteAbort),
@@ -98,13 +104,24 @@ private[transaction] object TransactionMetadata {
 }
 
 // this is a immutable object representing the target transition of the transaction metadata
-private[transaction] case class TransactionMetadataTransition(producerId: Long,
-                                                              producerEpoch: Short,
-                                                              txnTimeoutMs: Int,
-                                                              txnState: TransactionState,
-                                                              topicPartitions: immutable.Set[TopicPartition],
-                                                              txnStartTimestamp: Long,
-                                                              txnLastUpdateTimestamp: Long)
+private[transaction] case class TxnTransitMetadata(producerId: Long,
+                                                   producerEpoch: Short,
+                                                   txnTimeoutMs: Int,
+                                                   txnState: TransactionState,
+                                                   topicPartitions: immutable.Set[TopicPartition],
+                                                   txnStartTimestamp: Long,
+                                                   txnLastUpdateTimestamp: Long) {
+  override def toString: String = {
+    "TxnTransitMetadata(" +
+      s"producerId=$producerId, " +
+      s"producerEpoch=$producerEpoch, " +
+      s"txnTimeoutMs=$txnTimeoutMs, " +
+      s"txnState=$txnState, " +
+      s"topicPartitions=$topicPartitions, " +
+      s"txnStartTimestamp=$txnStartTimestamp, " +
+      s"txnLastUpdateTimestamp=$txnLastUpdateTimestamp)"
+  }
+}
 
 /**
   *
@@ -117,13 +134,14 @@ private[transaction] case class TransactionMetadataTransition(producerId: Long,
   * @param txnLastUpdateTimestamp   updated when any operation updates the TransactionMetadata. To be used for expiration
   */
 @nonthreadsafe
-private[transaction] class TransactionMetadata(val producerId: Long,
+private[transaction] class TransactionMetadata(val transactionalId: String,
+                                               val producerId: Long,
                                                var producerEpoch: Short,
                                                var txnTimeoutMs: Int,
                                                var state: TransactionState,
                                                val topicPartitions: mutable.Set[TopicPartition],
                                                var txnStartTimestamp: Long = -1,
-                                               var txnLastUpdateTimestamp: Long) {
+                                               var txnLastUpdateTimestamp: Long) extends Logging {
 
   // pending state is used to indicate the state that this transaction is going to
   // transit to, and for blocking future attempts to transit it again if it is not legal;
@@ -143,82 +161,76 @@ private[transaction] class TransactionMetadata(val producerId: Long,
   }
 
   // this is visible for test only
-  def prepareNoTransit(): TransactionMetadataTransition = {
+  def prepareNoTransit(): TxnTransitMetadata = {
     // do not call transitTo as it will set the pending state, a follow-up call to abort the transaction will set its pending state
-    TransactionMetadataTransition(producerId, producerEpoch, txnTimeoutMs, state, topicPartitions.toSet, txnStartTimestamp, txnLastUpdateTimestamp)
+    TxnTransitMetadata(producerId, producerEpoch, txnTimeoutMs, state, topicPartitions.toSet, txnStartTimestamp, txnLastUpdateTimestamp)
   }
 
-  def prepareFenceProducerEpoch(): TransactionMetadataTransition = {
+  def prepareFenceProducerEpoch(): TxnTransitMetadata = {
     // bump up the epoch to let the txn markers be able to override the current producer epoch
     producerEpoch = (producerEpoch + 1).toShort
 
     // do not call transitTo as it will set the pending state, a follow-up call to abort the transaction will set its pending state
-    TransactionMetadataTransition(producerId, producerEpoch, txnTimeoutMs, state, topicPartitions.toSet, txnStartTimestamp, txnLastUpdateTimestamp)
+    TxnTransitMetadata(producerId, producerEpoch, txnTimeoutMs, state, topicPartitions.toSet, txnStartTimestamp, txnLastUpdateTimestamp)
   }
 
   def prepareIncrementProducerEpoch(newTxnTimeoutMs: Int,
-                                    updateTimestamp: Long): TransactionMetadataTransition = {
+                                    updateTimestamp: Long): TxnTransitMetadata = {
 
-    prepareTransitionTo(Empty, (producerEpoch + 1).toShort, newTxnTimeoutMs, immutable.Set.empty[TopicPartition], -1, updateTimestamp)
+    prepareTransitionTo(Empty, (producerEpoch + 1).toShort, newTxnTimeoutMs, immutable.Set.empty[TopicPartition],
+      -1, updateTimestamp)
   }
 
-  def prepareNewPid(updateTimestamp: Long): TransactionMetadataTransition = {
-
+  def prepareNewProducerId(updateTimestamp: Long): TxnTransitMetadata = {
     prepareTransitionTo(Empty, producerEpoch, txnTimeoutMs, immutable.Set.empty[TopicPartition], -1, updateTimestamp)
   }
 
   def prepareAddPartitions(addedTopicPartitions: immutable.Set[TopicPartition],
-                           updateTimestamp: Long): TransactionMetadataTransition = {
+                           updateTimestamp: Long): TxnTransitMetadata = {
 
     if (state == Empty || state == CompleteCommit || state == CompleteAbort) {
-      prepareTransitionTo(Ongoing, producerEpoch, txnTimeoutMs, (topicPartitions ++ addedTopicPartitions).toSet, updateTimestamp, updateTimestamp)
+      prepareTransitionTo(Ongoing, producerEpoch, txnTimeoutMs, (topicPartitions ++ addedTopicPartitions).toSet,
+        updateTimestamp, updateTimestamp)
     } else {
-      prepareTransitionTo(Ongoing, producerEpoch, txnTimeoutMs, (topicPartitions ++ addedTopicPartitions).toSet, txnStartTimestamp, updateTimestamp)
+      prepareTransitionTo(Ongoing, producerEpoch, txnTimeoutMs, (topicPartitions ++ addedTopicPartitions).toSet,
+        txnStartTimestamp, updateTimestamp)
     }
   }
 
   def prepareAbortOrCommit(newState: TransactionState,
-                           updateTimestamp: Long): TransactionMetadataTransition = {
-
+                           updateTimestamp: Long): TxnTransitMetadata = {
     prepareTransitionTo(newState, producerEpoch, txnTimeoutMs, topicPartitions.toSet, txnStartTimestamp, updateTimestamp)
   }
 
-  def prepareComplete(updateTimestamp: Long): TransactionMetadataTransition = {
+  def prepareComplete(updateTimestamp: Long): TxnTransitMetadata = {
     val newState = if (state == PrepareCommit) CompleteCommit else CompleteAbort
     prepareTransitionTo(newState, producerEpoch, txnTimeoutMs, Set.empty[TopicPartition], txnStartTimestamp, updateTimestamp)
   }
 
-  // visible for testing only
-  def copy(): TransactionMetadata = {
-    val cloned = new TransactionMetadata(producerId, producerEpoch, txnTimeoutMs, state,
-      mutable.Set.empty ++ topicPartitions.toSet, txnStartTimestamp, txnLastUpdateTimestamp)
-    cloned.pendingState = pendingState
-
-    cloned
-  }
-
   private def prepareTransitionTo(newState: TransactionState,
                                   newEpoch: Short,
                                   newTxnTimeoutMs: Int,
                                   newTopicPartitions: immutable.Set[TopicPartition],
                                   newTxnStartTimestamp: Long,
-                                  updateTimestamp: Long): TransactionMetadataTransition = {
+                                  updateTimestamp: Long): TxnTransitMetadata = {
     if (pendingState.isDefined)
       throw new IllegalStateException(s"Preparing transaction state transition to $newState " +
         s"while it already a pending state ${pendingState.get}")
 
     // check that the new state transition is valid and update the pending state if necessary
     if (TransactionMetadata.validPreviousStates(newState).contains(state)) {
+      val transitMetadata = TxnTransitMetadata(producerId, newEpoch, newTxnTimeoutMs, newState,
+        newTopicPartitions, newTxnStartTimestamp, updateTimestamp)
+      debug(s"TransactionalId $transactionalId prepare transition from $state to $transitMetadata")
       pendingState = Some(newState)
-
-      TransactionMetadataTransition(producerId, newEpoch, newTxnTimeoutMs, newState, newTopicPartitions, newTxnStartTimestamp, updateTimestamp)
+      transitMetadata
     } else {
       throw new IllegalStateException(s"Preparing transaction state transition to $newState failed since the target state" +
         s" $newState is not a valid previous state of the current state $state")
     }
   }
 
-  def completeTransitionTo(newMetadata: TransactionMetadataTransition): Unit = {
+  def completeTransitionTo(transitMetadata: TxnTransitMetadata): Unit = {
     // metadata transition is valid only if all the following conditions are met:
     //
     // 1. the new state is already indicated in the pending state.
@@ -233,59 +245,60 @@ private[transaction] class TransactionMetadata(val producerId: Long,
 
     val toState = pendingState.getOrElse(throw new IllegalStateException("Completing transaction state transition while it does not have a pending state"))
 
-    if (toState != newMetadata.txnState ||
-      producerId != newMetadata.producerId ||
-      txnLastUpdateTimestamp > newMetadata.txnLastUpdateTimestamp) {
+    if (toState != transitMetadata.txnState ||
+      producerId != transitMetadata.producerId ||
+      txnLastUpdateTimestamp > transitMetadata.txnLastUpdateTimestamp) {
 
       throw new IllegalStateException("Completing transaction state transition failed due to unexpected metadata state")
     } else {
-      val updated = toState match {
+      toState match {
         case Empty => // from initPid
-          if (producerEpoch > newMetadata.producerEpoch ||
-            producerEpoch < newMetadata.producerEpoch - 1 ||
-            newMetadata.topicPartitions.nonEmpty ||
-            newMetadata.txnStartTimestamp != -1) {
+          if (producerEpoch > transitMetadata.producerEpoch ||
+            producerEpoch < transitMetadata.producerEpoch - 1 ||
+            transitMetadata.topicPartitions.nonEmpty ||
+            transitMetadata.txnStartTimestamp != -1) {
 
             throw new IllegalStateException("Completing transaction state transition failed due to unexpected metadata")
           } else {
-            txnTimeoutMs = newMetadata.txnTimeoutMs
-            producerEpoch = newMetadata.producerEpoch
+            txnTimeoutMs = transitMetadata.txnTimeoutMs
+            producerEpoch = transitMetadata.producerEpoch
           }
 
         case Ongoing => // from addPartitions
-          if (producerEpoch != newMetadata.producerEpoch ||
-            !topicPartitions.subsetOf(newMetadata.topicPartitions) ||
-            txnTimeoutMs != newMetadata.txnTimeoutMs ||
-            txnStartTimestamp > newMetadata.txnStartTimestamp) {
+          if (producerEpoch != transitMetadata.producerEpoch ||
+            !topicPartitions.subsetOf(transitMetadata.topicPartitions) ||
+            txnTimeoutMs != transitMetadata.txnTimeoutMs ||
+            txnStartTimestamp > transitMetadata.txnStartTimestamp) {
 
             throw new IllegalStateException("Completing transaction state transition failed due to unexpected metadata")
           } else {
-            txnStartTimestamp = newMetadata.txnStartTimestamp
-            addPartitions(newMetadata.topicPartitions)
+            txnStartTimestamp = transitMetadata.txnStartTimestamp
+            addPartitions(transitMetadata.topicPartitions)
           }
 
         case PrepareAbort | PrepareCommit => // from endTxn
-          if (producerEpoch != newMetadata.producerEpoch ||
-            !topicPartitions.toSet.equals(newMetadata.topicPartitions) ||
-            txnTimeoutMs != newMetadata.txnTimeoutMs ||
-            txnStartTimestamp != newMetadata.txnStartTimestamp) {
+          if (producerEpoch != transitMetadata.producerEpoch ||
+            !topicPartitions.toSet.equals(transitMetadata.topicPartitions) ||
+            txnTimeoutMs != transitMetadata.txnTimeoutMs ||
+            txnStartTimestamp != transitMetadata.txnStartTimestamp) {
 
             throw new IllegalStateException("Completing transaction state transition failed due to unexpected metadata")
           }
 
         case CompleteAbort | CompleteCommit => // from write markers
-          if (producerEpoch != newMetadata.producerEpoch ||
-            txnTimeoutMs != newMetadata.txnTimeoutMs ||
-            newMetadata.txnStartTimestamp == -1) {
+          if (producerEpoch != transitMetadata.producerEpoch ||
+            txnTimeoutMs != transitMetadata.txnTimeoutMs ||
+            transitMetadata.txnStartTimestamp == -1) {
 
             throw new IllegalStateException("Completing transaction state transition failed due to unexpected metadata")
           } else {
-            txnStartTimestamp = newMetadata.txnStartTimestamp
+            txnStartTimestamp = transitMetadata.txnStartTimestamp
             topicPartitions.clear()
           }
       }
 
-      txnLastUpdateTimestamp = newMetadata.txnLastUpdateTimestamp
+      debug(s"TransactionalId $transactionalId complete transition from $state to $transitMetadata")
+      txnLastUpdateTimestamp = transitMetadata.txnLastUpdateTimestamp
       pendingState = None
       state = toState
     }
@@ -293,10 +306,22 @@ private[transaction] class TransactionMetadata(val producerId: Long,
 
   def pendingTransitionInProgress: Boolean = pendingState.isDefined
 
-  override def toString = s"TransactionMetadata($pendingState, $producerId, $producerEpoch, $txnTimeoutMs, $state, $topicPartitions, $txnStartTimestamp, $txnLastUpdateTimestamp)"
+  override def toString = {
+    "TransactionMetadata(" +
+      s"transactionalId=$transactionalId, " +
+      s"producerId=$producerId, " +
+      s"producerEpoch=$producerEpoch, " +
+      s"txnTimeoutMs=$txnTimeoutMs, " +
+      s"state=$state, " +
+      s"pendingState=$pendingState, " +
+      s"topicPartitions=$topicPartitions, " +
+      s"txnStartTimestamp=$txnStartTimestamp, " +
+      s"txnLastUpdateTimestamp=$txnLastUpdateTimestamp)"
+  }
 
   override def equals(that: Any): Boolean = that match {
     case other: TransactionMetadata =>
+      transactionalId == other.transactionalId &&
       producerId == other.producerId &&
       producerEpoch == other.producerEpoch &&
       txnTimeoutMs == other.txnTimeoutMs &&
@@ -308,7 +333,8 @@ private[transaction] class TransactionMetadata(val producerId: Long,
   }
 
   override def hashCode(): Int = {
-    val fields = Seq(producerId, producerEpoch, txnTimeoutMs, state, topicPartitions, txnStartTimestamp, txnLastUpdateTimestamp)
+    val fields = Seq(transactionalId, producerId, producerEpoch, txnTimeoutMs, state, topicPartitions,
+      txnStartTimestamp, txnLastUpdateTimestamp)
     fields.map(_.hashCode()).foldLeft(0)((a, b) => 31 * a + b)
   }
 }

http://git-wip-us.apache.org/repos/asf/kafka/blob/70ec4b1d/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala b/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala
index 0952b5d..7d1a571 100644
--- a/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala
+++ b/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala
@@ -65,7 +65,7 @@ class TransactionStateManager(brokerId: Int,
 
   this.logIdent = "[Transaction Log Manager " + brokerId + "]: "
 
-  type SendTxnMarkersCallback = (String, Int, TransactionResult, TransactionMetadata, TransactionMetadataTransition) => Unit
+  type SendTxnMarkersCallback = (String, Int, TransactionResult, TransactionMetadata, TxnTransitMetadata) => Unit
 
   /** shutting down flag */
   private val shuttingDown = new AtomicBoolean(false)
@@ -108,7 +108,7 @@ class TransactionStateManager(brokerId: Int,
     }
   }
 
-  def enablePidExpiration() {
+  def enableProducerIdExpiration() {
     // TODO: add producer id expiration logic
   }
 
@@ -140,12 +140,13 @@ class TransactionStateManager(brokerId: Int,
    * Add a new transaction metadata, or retrieve the metadata if it already exists with the associated transactional id
    * along with the current coordinator epoch for that belonging transaction topic partition
    */
-  def addTransaction(transactionalId: String, txnMetadata: TransactionMetadata): CoordinatorEpochAndTxnMetadata = {
-    val partitionId = partitionFor(transactionalId)
+  def addTransaction(txnMetadata: TransactionMetadata): CoordinatorEpochAndTxnMetadata = {
+    val partitionId = partitionFor(txnMetadata.transactionalId)
 
     transactionMetadataCache.get(partitionId) match {
       case Some(txnMetadataCacheEntry) =>
-        val currentTxnMetadata = txnMetadataCacheEntry.metadataPerTransactionalId.putIfNotExists(transactionalId, txnMetadata)
+        val currentTxnMetadata = txnMetadataCacheEntry.metadataPerTransactionalId.putIfNotExists(
+          txnMetadata.transactionalId, txnMetadata)
         if (currentTxnMetadata != null) {
           CoordinatorEpochAndTxnMetadata(txnMetadataCacheEntry.coordinatorEpoch, currentTxnMetadata)
         } else {
@@ -242,23 +243,15 @@ class TransactionStateManager(brokerId: Int,
             memRecords.batches.asScala.foreach { batch =>
               for (record <- batch.asScala) {
                 require(record.hasKey, "Transaction state log's key should not be null")
-                TransactionLog.readMessageKey(record.key) match {
-
-                  case txnKey: TxnKey =>
-                    // load transaction metadata along with transaction state
-                    val transactionalId: String = txnKey.transactionalId
-                    if (!record.hasValue) {
-                      loadedTransactions.remove(transactionalId)
-                    } else {
-                      val txnMetadata = TransactionLog.readMessageValue(record.value)
-                      loadedTransactions.put(transactionalId, txnMetadata)
-                    }
-
-                  case unknownKey =>
-                    // TODO: Metrics
-                    throw new IllegalStateException(s"Unexpected message key $unknownKey while loading offsets and group metadata")
+                val txnKey = TransactionLog.readTxnRecordKey(record.key)
+                // load transaction metadata along with transaction state
+                val transactionalId = txnKey.transactionalId
+                if (!record.hasValue) {
+                  loadedTransactions.remove(transactionalId)
+                } else {
+                  val txnMetadata = TransactionLog.readTxnRecordValue(transactionalId, record.value)
+                  loadedTransactions.put(transactionalId, txnMetadata)
                 }
-
                 currOffset = batch.nextOffset
               }
             }
@@ -339,7 +332,7 @@ class TransactionStateManager(brokerId: Int,
       }
     }
 
-    scheduler.schedule(s"load-txns-for-partition-$topicPartition", loadTransactions _)
+    scheduler.schedule(s"load-txns-for-partition-$topicPartition", loadTransactions)
   }
 
   /**
@@ -374,7 +367,7 @@ class TransactionStateManager(brokerId: Int,
       }
     }
 
-    scheduler.schedule(s"remove-txns-for-partition-$topicPartition", removeTransactions _)
+    scheduler.schedule(s"remove-txns-for-partition-$topicPartition", removeTransactions)
   }
 
   private def validateTransactionTopicPartitionCountIsStable(): Unit = {
@@ -386,7 +379,7 @@ class TransactionStateManager(brokerId: Int,
   // TODO: check broker message format and error if < V2
   def appendTransactionToLog(transactionalId: String,
                              coordinatorEpoch: Int,
-                             newMetadata: TransactionMetadataTransition,
+                             newMetadata: TxnTransitMetadata,
                              responseCallback: Errors => Unit): Unit = {
 
     // generate the message for this transaction metadata

http://git-wip-us.apache.org/repos/asf/kafka/blob/70ec4b1d/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala
index 43ad7a7..3c94916 100644
--- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala
+++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala
@@ -42,8 +42,8 @@ class TransactionCoordinatorTest {
   val brokerId = 0
   val coordinatorEpoch = 0
   private val transactionalId = "known"
-  private val pid = 10
-  private val epoch:Short = 1
+  private val producerId = 10
+  private val producerEpoch:Short = 1
   private val txnTimeoutMs = 1
 
   private val txnMarkerPurgatory = new DelayedOperationPurgatory[DelayedTxnMarker]("test", new MockTimer, reaperEnabled = false)
@@ -122,7 +122,7 @@ class TransactionCoordinatorTest {
       })
       .once()
 
-    EasyMock.expect(transactionManager.addTransaction(EasyMock.eq(transactionalId), EasyMock.capture(capturedTxn)))
+    EasyMock.expect(transactionManager.addTransaction(EasyMock.capture(capturedTxn)))
       .andAnswer(new IAnswer[CoordinatorEpochAndTxnMetadata] {
         override def answer(): CoordinatorEpochAndTxnMetadata = {
           CoordinatorEpochAndTxnMetadata(coordinatorEpoch, capturedTxn.getValue)
@@ -133,7 +133,7 @@ class TransactionCoordinatorTest {
     EasyMock.expect(transactionManager.appendTransactionToLog(
       EasyMock.eq(transactionalId),
       EasyMock.eq(coordinatorEpoch),
-      EasyMock.anyObject().asInstanceOf[TransactionMetadataTransition],
+      EasyMock.anyObject().asInstanceOf[TxnTransitMetadata],
       EasyMock.capture(capturedErrorsCallback)))
       .andAnswer(new IAnswer[Unit] {
         override def answer(): Unit = {
@@ -213,7 +213,8 @@ class TransactionCoordinatorTest {
     EasyMock.expect(transactionManager.isCoordinatorFor(transactionalId))
       .andReturn(true)
     EasyMock.expect(transactionManager.getTransactionState(transactionalId))
-      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(0, 0, 0, state, mutable.Set.empty, 0, 0))))
+      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId,
+        0, 0, 0, state, mutable.Set.empty, 0, 0))))
 
     EasyMock.replay(transactionManager)
 
@@ -226,7 +227,8 @@ class TransactionCoordinatorTest {
     EasyMock.expect(transactionManager.isCoordinatorFor(transactionalId))
       .andReturn(true)
     EasyMock.expect(transactionManager.getTransactionState(transactionalId))
-      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(0, 10, 0, PrepareCommit, mutable.Set.empty, 0, 0))))
+      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId,
+        0, 10, 0, PrepareCommit, mutable.Set.empty, 0, 0))))
 
     EasyMock.replay(transactionManager)
 
@@ -255,7 +257,8 @@ class TransactionCoordinatorTest {
   }
 
   def validateSuccessfulAddPartitions(previousState: TransactionState): Unit = {
-    val txnMetadata = new TransactionMetadata(pid, epoch, txnTimeoutMs, previousState, mutable.Set.empty, time.milliseconds(), time.milliseconds())
+    val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerEpoch, txnTimeoutMs, previousState,
+      mutable.Set.empty, time.milliseconds(), time.milliseconds())
 
     EasyMock.expect(transactionManager.isCoordinatorFor(transactionalId))
       .andReturn(true)
@@ -265,13 +268,13 @@ class TransactionCoordinatorTest {
     EasyMock.expect(transactionManager.appendTransactionToLog(
       EasyMock.eq(transactionalId),
       EasyMock.eq(coordinatorEpoch),
-      EasyMock.anyObject().asInstanceOf[TransactionMetadataTransition],
+      EasyMock.anyObject().asInstanceOf[TxnTransitMetadata],
       EasyMock.capture(capturedErrorsCallback)
     ))
 
     EasyMock.replay(transactionManager)
 
-    coordinator.handleAddPartitionsToTransaction(transactionalId, pid, epoch, partitions, errorsCallback)
+    coordinator.handleAddPartitionsToTransaction(transactionalId, producerId, producerEpoch, partitions, errorsCallback)
 
     EasyMock.verify(transactionManager)
   }
@@ -281,7 +284,8 @@ class TransactionCoordinatorTest {
     EasyMock.expect(transactionManager.isCoordinatorFor(transactionalId))
       .andReturn(true)
     EasyMock.expect(transactionManager.getTransactionState(transactionalId))
-      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(0, 0, 0, Empty, partitions, 0, 0))))
+      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, 0, 0,
+        0, Empty, partitions, 0, 0))))
 
     EasyMock.replay(transactionManager)
 
@@ -308,7 +312,8 @@ class TransactionCoordinatorTest {
     EasyMock.expect(transactionManager.isCoordinatorFor(transactionalId))
       .andReturn(true)
     EasyMock.expect(transactionManager.getTransactionState(transactionalId))
-      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(10, 0, 0, Ongoing, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds()))))
+      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, 10, 0,
+        0, Ongoing, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds()))))
     EasyMock.replay(transactionManager)
 
     coordinator.handleEndTransaction(transactionalId, 0, 0, TransactionResult.COMMIT, errorsCallback)
@@ -321,10 +326,11 @@ class TransactionCoordinatorTest {
     EasyMock.expect(transactionManager.isCoordinatorFor(transactionalId))
       .andReturn(true)
     EasyMock.expect(transactionManager.getTransactionState(transactionalId))
-      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(pid, 1, 1, Ongoing, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds()))))
+      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId,
+        producerId, 1, 1, Ongoing, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds()))))
     EasyMock.replay(transactionManager)
 
-    coordinator.handleEndTransaction(transactionalId, pid, 0, TransactionResult.COMMIT, errorsCallback)
+    coordinator.handleEndTransaction(transactionalId, producerId, 0, TransactionResult.COMMIT, errorsCallback)
     assertEquals(Errors.INVALID_PRODUCER_EPOCH, error)
     EasyMock.verify(transactionManager)
   }
@@ -334,10 +340,11 @@ class TransactionCoordinatorTest {
     EasyMock.expect(transactionManager.isCoordinatorFor(transactionalId))
       .andReturn(true)
     EasyMock.expect(transactionManager.getTransactionState(transactionalId))
-      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(pid, 1, 1, CompleteCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds()))))
+      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId,
+        producerId, 1, 1, CompleteCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds()))))
     EasyMock.replay(transactionManager)
 
-    coordinator.handleEndTransaction(transactionalId, pid, 1, TransactionResult.COMMIT, errorsCallback)
+    coordinator.handleEndTransaction(transactionalId, producerId, 1, TransactionResult.COMMIT, errorsCallback)
     assertEquals(Errors.NONE, error)
     EasyMock.verify(transactionManager)
   }
@@ -347,10 +354,11 @@ class TransactionCoordinatorTest {
     EasyMock.expect(transactionManager.isCoordinatorFor(transactionalId))
       .andReturn(true)
     EasyMock.expect(transactionManager.getTransactionState(transactionalId))
-      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(pid, 1, 1, CompleteAbort, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds()))))
+      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId,
+        producerId, 1, 1, CompleteAbort, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds()))))
     EasyMock.replay(transactionManager)
 
-    coordinator.handleEndTransaction(transactionalId, pid, 1, TransactionResult.ABORT, errorsCallback)
+    coordinator.handleEndTransaction(transactionalId, producerId, 1, TransactionResult.ABORT, errorsCallback)
     assertEquals(Errors.NONE, error)
     EasyMock.verify(transactionManager)
   }
@@ -360,10 +368,11 @@ class TransactionCoordinatorTest {
     EasyMock.expect(transactionManager.isCoordinatorFor(transactionalId))
       .andReturn(true)
     EasyMock.expect(transactionManager.getTransactionState(transactionalId))
-      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(pid, 1, 1, CompleteAbort, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds()))))
+      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId,
+        producerId, 1, 1, CompleteAbort, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds()))))
     EasyMock.replay(transactionManager)
 
-    coordinator.handleEndTransaction(transactionalId, pid, 1, TransactionResult.COMMIT, errorsCallback)
+    coordinator.handleEndTransaction(transactionalId, producerId, 1, TransactionResult.COMMIT, errorsCallback)
     assertEquals(Errors.INVALID_TXN_STATE, error)
     EasyMock.verify(transactionManager)
   }
@@ -373,10 +382,11 @@ class TransactionCoordinatorTest {
     EasyMock.expect(transactionManager.isCoordinatorFor(transactionalId))
       .andReturn(true)
     EasyMock.expect(transactionManager.getTransactionState(transactionalId))
-      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(pid, 1, 1, CompleteCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds()))))
+      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId,
+        producerId, 1, 1, CompleteCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds()))))
     EasyMock.replay(transactionManager)
 
-    coordinator.handleEndTransaction(transactionalId, pid, 1, TransactionResult.ABORT, errorsCallback)
+    coordinator.handleEndTransaction(transactionalId, producerId, 1, TransactionResult.ABORT, errorsCallback)
     assertEquals(Errors.INVALID_TXN_STATE, error)
     EasyMock.verify(transactionManager)
   }
@@ -386,10 +396,11 @@ class TransactionCoordinatorTest {
     EasyMock.expect(transactionManager.isCoordinatorFor(transactionalId))
       .andReturn(true)
     EasyMock.expect(transactionManager.getTransactionState(transactionalId))
-      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(pid, 1, 1, PrepareCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds()))))
+      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId,
+        producerId, 1, 1, PrepareCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds()))))
     EasyMock.replay(transactionManager)
 
-    coordinator.handleEndTransaction(transactionalId, pid, 1, TransactionResult.COMMIT, errorsCallback)
+    coordinator.handleEndTransaction(transactionalId, producerId, 1, TransactionResult.COMMIT, errorsCallback)
     assertEquals(Errors.CONCURRENT_TRANSACTIONS, error)
     EasyMock.verify(transactionManager)
   }
@@ -399,10 +410,11 @@ class TransactionCoordinatorTest {
     EasyMock.expect(transactionManager.isCoordinatorFor(transactionalId))
       .andReturn(true)
     EasyMock.expect(transactionManager.getTransactionState(transactionalId))
-      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(pid, 1, 1, PrepareAbort, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds()))))
+      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId,
+        producerId, 1, 1, PrepareAbort, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds()))))
     EasyMock.replay(transactionManager)
 
-    coordinator.handleEndTransaction(transactionalId, pid, 1, TransactionResult.COMMIT, errorsCallback)
+    coordinator.handleEndTransaction(transactionalId, producerId, 1, TransactionResult.COMMIT, errorsCallback)
     assertEquals(Errors.INVALID_TXN_STATE, error)
     EasyMock.verify(transactionManager)
   }
@@ -414,7 +426,7 @@ class TransactionCoordinatorTest {
 
     EasyMock.replay(transactionManager)
 
-    coordinator.handleEndTransaction(transactionalId, pid, epoch, TransactionResult.COMMIT, errorsCallback)
+    coordinator.handleEndTransaction(transactionalId, producerId, producerEpoch, TransactionResult.COMMIT, errorsCallback)
 
     EasyMock.verify(transactionManager)
   }
@@ -425,7 +437,7 @@ class TransactionCoordinatorTest {
 
     EasyMock.replay(transactionManager)
 
-    coordinator.handleEndTransaction(transactionalId, pid, epoch, TransactionResult.ABORT, errorsCallback)
+    coordinator.handleEndTransaction(transactionalId, producerId, producerEpoch, TransactionResult.ABORT, errorsCallback)
     EasyMock.verify(transactionManager)
   }
 
@@ -487,7 +499,8 @@ class TransactionCoordinatorTest {
 
   @Test
   def shouldAbortTransactionOnHandleInitPidWhenExistingTransactionInOngoingState(): Unit = {
-    val txnMetadata = new TransactionMetadata(pid, epoch, txnTimeoutMs, Ongoing, partitions, 0, 0)
+    val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerEpoch, txnTimeoutMs, Ongoing,
+      partitions, 0, 0)
 
     EasyMock.expect(transactionManager.isCoordinatorFor(transactionalId))
       .andReturn(true)
@@ -499,7 +512,8 @@ class TransactionCoordinatorTest {
       .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))
       .anyTimes()
 
-    val originalMetadata = new TransactionMetadata(pid, (epoch + 1).toShort, txnTimeoutMs, Ongoing, partitions, 0, 0)
+    val originalMetadata = new TransactionMetadata(transactionalId, producerId, (producerEpoch + 1).toShort,
+      txnTimeoutMs, Ongoing, partitions, 0, 0)
     EasyMock.expect(transactionManager.appendTransactionToLog(
       EasyMock.eq(transactionalId),
       EasyMock.eq(coordinatorEpoch),
@@ -532,21 +546,24 @@ class TransactionCoordinatorTest {
 
   @Test
   def shouldAbortExpiredTransactionsInOngoingState(): Unit = {
-    val txnMetadata = new TransactionMetadata(pid, epoch, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds())
+    val now = time.milliseconds()
+    val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerEpoch, txnTimeoutMs, Ongoing,
+      partitions, now, now)
 
     EasyMock.expect(transactionManager.transactionsToExpire())
-      .andReturn(List(TransactionalIdAndProducerIdEpoch(transactionalId, pid, epoch)))
+      .andReturn(List(TransactionalIdAndProducerIdEpoch(transactionalId, producerId, producerEpoch)))
     EasyMock.expect(transactionManager.isCoordinatorFor(transactionalId))
       .andReturn(true)
     EasyMock.expect(transactionManager.getTransactionState(transactionalId))
       .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))
       .once()
 
-    val newMetadata = txnMetadata.copy().prepareAbortOrCommit(PrepareAbort, time.milliseconds() + TransactionStateManager.DefaultRemoveExpiredTransactionsIntervalMs)
+    val expectedTransition = TxnTransitMetadata(producerId, producerEpoch, txnTimeoutMs, PrepareAbort,
+      partitions.toSet, now, now + TransactionStateManager.DefaultRemoveExpiredTransactionsIntervalMs)
 
     EasyMock.expect(transactionManager.appendTransactionToLog(EasyMock.eq(transactionalId),
       EasyMock.eq(coordinatorEpoch),
-      EasyMock.eq(newMetadata),
+      EasyMock.eq(expectedTransition),
       EasyMock.capture(capturedErrorsCallback)))
       .andAnswer(new IAnswer[Unit] {
         override def answer(): Unit = {}
@@ -563,11 +580,12 @@ class TransactionCoordinatorTest {
 
   @Test
   def shouldNotAbortExpiredTransactionsThatHaveAPendingStateTransition(): Unit = {
-    val metadata = new TransactionMetadata(pid, epoch, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds())
+    val metadata = new TransactionMetadata(transactionalId, producerId, producerEpoch, txnTimeoutMs, Ongoing,
+      partitions, time.milliseconds(), time.milliseconds())
     metadata.prepareAbortOrCommit(PrepareCommit, time.milliseconds())
 
     EasyMock.expect(transactionManager.transactionsToExpire())
-      .andReturn(List(TransactionalIdAndProducerIdEpoch(transactionalId, pid, epoch)))
+      .andReturn(List(TransactionalIdAndProducerIdEpoch(transactionalId, producerId, producerEpoch)))
 
     EasyMock.replay(transactionManager, transactionMarkerChannelManager)
 
@@ -583,7 +601,8 @@ class TransactionCoordinatorTest {
     EasyMock.expect(transactionManager.validateTransactionTimeoutMs(EasyMock.anyInt()))
       .andReturn(true).anyTimes()
 
-    val metadata = new TransactionMetadata(0, 0, 0, state, mutable.Set[TopicPartition](new TopicPartition("topic", 1)), 0, 0)
+    val metadata = new TransactionMetadata(transactionalId, 0, 0, 0, state,
+      mutable.Set[TopicPartition](new TopicPartition("topic", 1)), 0, 0)
     EasyMock.expect(transactionManager.getTransactionState(transactionalId))
       .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, metadata))).anyTimes()
 
@@ -600,11 +619,12 @@ class TransactionCoordinatorTest {
     EasyMock.expect(transactionManager.validateTransactionTimeoutMs(EasyMock.anyInt()))
       .andReturn(true)
 
-    val metadata = new TransactionMetadata(pid, epoch, txnTimeoutMs, state, mutable.Set.empty[TopicPartition], time.milliseconds(), time.milliseconds())
+    val metadata = new TransactionMetadata(transactionalId, producerId, producerEpoch, txnTimeoutMs, state,
+      mutable.Set.empty[TopicPartition], time.milliseconds(), time.milliseconds())
     EasyMock.expect(transactionManager.getTransactionState(transactionalId))
       .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, metadata)))
 
-    val capturedNewMetadata: Capture[TransactionMetadataTransition] = EasyMock.newCapture()
+    val capturedNewMetadata: Capture[TxnTransitMetadata] = EasyMock.newCapture()
     EasyMock.expect(transactionManager.appendTransactionToLog(
       EasyMock.eq(transactionalId),
       EasyMock.eq(coordinatorEpoch),
@@ -622,16 +642,20 @@ class TransactionCoordinatorTest {
     val newTxnTimeoutMs = 10
     coordinator.handleInitProducerId(transactionalId, newTxnTimeoutMs, initProducerIdMockCallback)
 
-    assertEquals(InitProducerIdResult(pid, (epoch + 1).toShort, Errors.NONE), result)
+    assertEquals(InitProducerIdResult(producerId, (producerEpoch + 1).toShort, Errors.NONE), result)
     assertEquals(newTxnTimeoutMs, metadata.txnTimeoutMs)
     assertEquals(time.milliseconds(), metadata.txnLastUpdateTimestamp)
-    assertEquals((epoch + 1).toShort, metadata.producerEpoch)
-    assertEquals(pid, metadata.producerId)
+    assertEquals((producerEpoch + 1).toShort, metadata.producerEpoch)
+    assertEquals(producerId, metadata.producerId)
   }
 
   private def mockPrepare(transactionState: TransactionState, runCallback: Boolean = false): TransactionMetadata = {
     val now = time.milliseconds()
-    val originalMetadata = new TransactionMetadata(pid, epoch, txnTimeoutMs, Ongoing, partitions, now, now)
+    val originalMetadata = new TransactionMetadata(transactionalId, producerId, producerEpoch, txnTimeoutMs,
+      Ongoing, partitions, now, now)
+
+    val transition = TxnTransitMetadata(producerId, producerEpoch, txnTimeoutMs, transactionState,
+      partitions.toSet, now, now)
 
     EasyMock.expect(transactionManager.isCoordinatorFor(transactionalId))
       .andReturn(true)
@@ -642,7 +666,7 @@ class TransactionCoordinatorTest {
     EasyMock.expect(transactionManager.appendTransactionToLog(
       EasyMock.eq(transactionalId),
       EasyMock.eq(coordinatorEpoch),
-      EasyMock.eq(originalMetadata.copy().prepareAbortOrCommit(transactionState, now)),
+      EasyMock.eq(transition),
       EasyMock.capture(capturedErrorsCallback)))
       .andAnswer(new IAnswer[Unit] {
         override def answer(): Unit = {
@@ -651,7 +675,8 @@ class TransactionCoordinatorTest {
         }
       }).once()
 
-    new TransactionMetadata(pid, epoch, txnTimeoutMs, transactionState, partitions, time.milliseconds(), time.milliseconds())
+    new TransactionMetadata(transactionalId, producerId, producerEpoch, txnTimeoutMs, transactionState, partitions,
+      time.milliseconds(), time.milliseconds())
   }
 
   private def mockComplete(transactionState: TransactionState, appendError: Errors = Errors.NONE): TransactionMetadata = {
@@ -663,7 +688,7 @@ class TransactionCoordinatorTest {
     else
       (CompleteCommit, TransactionResult.COMMIT)
 
-    val completedMetadata = new TransactionMetadata(pid, epoch, txnTimeoutMs, finalState,
+    val completedMetadata = new TransactionMetadata(transactionalId, producerId, producerEpoch, txnTimeoutMs, finalState,
       collection.mutable.Set.empty[TopicPartition],
       prepareMetadata.txnStartTimestamp,
       prepareMetadata.txnLastUpdateTimestamp)
@@ -672,8 +697,8 @@ class TransactionCoordinatorTest {
       .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, prepareMetadata)))
       .once()
 
-    val newMetadata = TransactionMetadataTransition(producerId = pid,
-      producerEpoch = epoch,
+    val newMetadata = TxnTransitMetadata(producerId = producerId,
+      producerEpoch = producerEpoch,
       txnTimeoutMs = txnTimeoutMs,
       txnState = finalState,
       topicPartitions = Set.empty[TopicPartition],

http://git-wip-us.apache.org/repos/asf/kafka/blob/70ec4b1d/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionLogTest.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionLogTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionLogTest.scala
index fe750b8..c0edec7 100644
--- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionLogTest.scala
+++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionLogTest.scala
@@ -28,7 +28,7 @@ import scala.collection.JavaConverters._
 
 class TransactionLogTest extends JUnitSuite {
 
-  val epoch: Short = 0
+  val producerEpoch: Short = 0
   val transactionTimeoutMs: Int = 1000
 
   val topicPartitions: Set[TopicPartition] = Set[TopicPartition](new TopicPartition("topic1", 0),
@@ -39,7 +39,10 @@ class TransactionLogTest extends JUnitSuite {
 
   @Test
   def shouldThrowExceptionWriteInvalidTxn() {
-    val txnMetadata = TransactionMetadata(0L, epoch, transactionTimeoutMs, 0)
+    val transactionalId = "transactionalId"
+    val producerId = 23423L
+
+    val txnMetadata = TransactionMetadata(transactionalId, producerId, producerEpoch, transactionTimeoutMs, 0)
     txnMetadata.addPartitions(topicPartitions)
 
     intercept[IllegalStateException] {
@@ -64,8 +67,9 @@ class TransactionLogTest extends JUnitSuite {
       5L -> CompleteAbort)
 
     // generate transaction log messages
-    val txnRecords = pidMappings.map { case (transactionalId, pid) =>
-      val txnMetadata = TransactionMetadata(pid, epoch, transactionTimeoutMs, transactionStates(pid), 0)
+    val txnRecords = pidMappings.map { case (transactionalId, producerId) =>
+      val txnMetadata = TransactionMetadata(transactionalId, producerId, producerEpoch, transactionTimeoutMs,
+        transactionStates(producerId), 0)
 
       if (!txnMetadata.state.equals(Empty))
         txnMetadata.addPartitions(topicPartitions)
@@ -80,27 +84,21 @@ class TransactionLogTest extends JUnitSuite {
 
     var count = 0
     for (record <- records.records.asScala) {
-      val key = TransactionLog.readMessageKey(record.key())
-
-      key match {
-        case pidKey: TxnKey =>
-          val transactionalId = pidKey.transactionalId
-          val txnMetadata = TransactionLog.readMessageValue(record.value())
-
-          assertEquals(pidMappings(transactionalId), txnMetadata.producerId)
-          assertEquals(epoch, txnMetadata.producerEpoch)
-          assertEquals(transactionTimeoutMs, txnMetadata.txnTimeoutMs)
-          assertEquals(transactionStates(txnMetadata.producerId), txnMetadata.state)
+      val txnKey = TransactionLog.readTxnRecordKey(record.key)
+      val transactionalId = txnKey.transactionalId
+      val txnMetadata = TransactionLog.readTxnRecordValue(transactionalId, record.value)
 
-          if (txnMetadata.state.equals(Empty))
-            assertEquals(Set.empty[TopicPartition], txnMetadata.topicPartitions)
-          else
-            assertEquals(topicPartitions, txnMetadata.topicPartitions)
+      assertEquals(pidMappings(transactionalId), txnMetadata.producerId)
+      assertEquals(producerEpoch, txnMetadata.producerEpoch)
+      assertEquals(transactionTimeoutMs, txnMetadata.txnTimeoutMs)
+      assertEquals(transactionStates(txnMetadata.producerId), txnMetadata.state)
 
-          count = count + 1
+      if (txnMetadata.state.equals(Empty))
+        assertEquals(Set.empty[TopicPartition], txnMetadata.topicPartitions)
+      else
+        assertEquals(topicPartitions, txnMetadata.topicPartitions)
 
-        case _ => fail(s"Unexpected transaction topic message key $key")
-      }
+      count = count + 1
     }
 
     assertEquals(pidMappings.size, count)

http://git-wip-us.apache.org/repos/asf/kafka/blob/70ec4b1d/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerChannelManagerTest.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerChannelManagerTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerChannelManagerTest.scala
index d02e072..9835db7 100644
--- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerChannelManagerTest.scala
+++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerChannelManagerTest.scala
@@ -102,7 +102,8 @@ class TransactionMarkerChannelManagerTest {
 
     EasyMock.replay(metadataCache)
 
-    val txnMetadata = new TransactionMetadata(producerId1, producerEpoch, txnTimeoutMs, PrepareCommit, mutable.Set[TopicPartition](partition1, partition2), 0L, 0L)
+    val txnMetadata = new TransactionMetadata(transactionalId1, producerId1, producerEpoch, txnTimeoutMs,
+      PrepareCommit, mutable.Set[TopicPartition](partition1, partition2), 0L, 0L)
     channelManager.addTxnMarkersToSend(transactionalId1, coordinatorEpoch, txnResult, txnMetadata, txnMetadata.prepareComplete(time.milliseconds()))
 
     assertEquals(1 * 2, txnMarkerPurgatory.watched)
@@ -137,8 +138,10 @@ class TransactionMarkerChannelManagerTest {
 
     EasyMock.replay(metadataCache)
 
-    val txnMetadata1 = new TransactionMetadata(producerId1, producerEpoch, txnTimeoutMs, PrepareCommit, mutable.Set[TopicPartition](partition1), 0L, 0L)
-    val txnMetadata2 = new TransactionMetadata(producerId2, producerEpoch, txnTimeoutMs, PrepareCommit, mutable.Set[TopicPartition](partition1), 0L, 0L)
+    val txnMetadata1 = new TransactionMetadata(transactionalId1, producerId1, producerEpoch, txnTimeoutMs,
+      PrepareCommit, mutable.Set[TopicPartition](partition1), 0L, 0L)
+    val txnMetadata2 = new TransactionMetadata(transactionalId2, producerId2, producerEpoch, txnTimeoutMs,
+      PrepareCommit, mutable.Set[TopicPartition](partition1), 0L, 0L)
     channelManager.addTxnMarkersToSend(transactionalId1, coordinatorEpoch, txnResult, txnMetadata1, txnMetadata1.prepareComplete(time.milliseconds()))
     channelManager.addTxnMarkersToSend(transactionalId2, coordinatorEpoch, txnResult, txnMetadata2, txnMetadata2.prepareComplete(time.milliseconds()))
 
@@ -191,10 +194,12 @@ class TransactionMarkerChannelManagerTest {
 
     EasyMock.replay(metadataCache)
 
-    val txnMetadata1 = new TransactionMetadata(producerId1, producerEpoch, txnTimeoutMs, PrepareCommit, mutable.Set[TopicPartition](partition1), 0L, 0L)
+    val txnMetadata1 = new TransactionMetadata(transactionalId1, producerId1, producerEpoch, txnTimeoutMs,
+      PrepareCommit, mutable.Set[TopicPartition](partition1), 0L, 0L)
     channelManager.addTxnMarkersToSend(transactionalId1, coordinatorEpoch, txnResult, txnMetadata1, txnMetadata1.prepareComplete(time.milliseconds()))
 
-    val txnMetadata2 = new TransactionMetadata(producerId2, producerEpoch, txnTimeoutMs, PrepareCommit, mutable.Set[TopicPartition](partition1), 0L, 0L)
+    val txnMetadata2 = new TransactionMetadata(transactionalId2, producerId2, producerEpoch, txnTimeoutMs,
+      PrepareCommit, mutable.Set[TopicPartition](partition1), 0L, 0L)
     channelManager.addTxnMarkersToSend(transactionalId2, coordinatorEpoch, txnResult, txnMetadata2, txnMetadata2.prepareComplete(time.milliseconds()))
 
     assertEquals(2 * 2, txnMarkerPurgatory.watched)

http://git-wip-us.apache.org/repos/asf/kafka/blob/70ec4b1d/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerRequestCompletionHandlerTest.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerRequestCompletionHandlerTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerRequestCompletionHandlerTest.scala
index 082d441..c1123aa 100644
--- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerRequestCompletionHandlerTest.scala
+++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerRequestCompletionHandlerTest.scala
@@ -44,7 +44,8 @@ class TransactionMarkerRequestCompletionHandlerTest {
     Utils.mkList(
       TxnIdAndMarkerEntry(transactionalId, new WriteTxnMarkersRequest.TxnMarkerEntry(producerId, producerEpoch, coordinatorEpoch, txnResult, Utils.mkList(topicPartition))))
 
-  private val txnMetadata = new TransactionMetadata(producerId, producerEpoch, txnTimeoutMs, PrepareCommit, mutable.Set[TopicPartition](topicPartition), 0L, 0L)
+  private val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerEpoch, txnTimeoutMs,
+    PrepareCommit, mutable.Set[TopicPartition](topicPartition), 0L, 0L)
 
   private val markerChannelManager = EasyMock.createNiceMock(classOf[TransactionMarkerChannelManager])
 

http://git-wip-us.apache.org/repos/asf/kafka/blob/70ec4b1d/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala
index fb443ad..bbf2f38 100644
--- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala
+++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala
@@ -62,21 +62,21 @@ class TransactionStateManagerTest {
   val txnConfig = TransactionConfig()
   val transactionManager: TransactionStateManager = new TransactionStateManager(0, zkUtils, scheduler, replicaManager, txnConfig, time)
 
-  val txnId1: String = "one"
-  val txnId2: String = "two"
-  val txnMessageKeyBytes1: Array[Byte] = TransactionLog.keyToBytes(txnId1)
-  val txnMessageKeyBytes2: Array[Byte] = TransactionLog.keyToBytes(txnId2)
-  val pidMappings: Map[String, Long] = Map[String, Long](txnId1 -> 1L, txnId2 -> 2L)
-  var txnMetadata1: TransactionMetadata = TransactionMetadata(pidMappings(txnId1), 1, transactionTimeoutMs, 0)
-  var txnMetadata2: TransactionMetadata = TransactionMetadata(pidMappings(txnId2), 1, transactionTimeoutMs, 0)
+  val transactionalId1: String = "one"
+  val transactionalId2: String = "two"
+  val txnMessageKeyBytes1: Array[Byte] = TransactionLog.keyToBytes(transactionalId1)
+  val txnMessageKeyBytes2: Array[Byte] = TransactionLog.keyToBytes(transactionalId2)
+  val producerIds: Map[String, Long] = Map[String, Long](transactionalId1 -> 1L, transactionalId2 -> 2L)
+  var txnMetadata1: TransactionMetadata = transactionMetadata(transactionalId1, producerIds(transactionalId1))
+  var txnMetadata2: TransactionMetadata = transactionMetadata(transactionalId2, producerIds(transactionalId2))
 
   var expectedError: Errors = Errors.NONE
 
   @Before
   def setUp() {
     // make sure the transactional id hashes to the assigning partition id
-    assertEquals(partitionId, transactionManager.partitionFor(txnId1))
-    assertEquals(partitionId, transactionManager.partitionFor(txnId2))
+    assertEquals(partitionId, transactionManager.partitionFor(transactionalId1))
+    assertEquals(partitionId, transactionManager.partitionFor(transactionalId2))
   }
 
   @After
@@ -98,10 +98,10 @@ class TransactionStateManagerTest {
   def testAddGetPids() {
     transactionManager.addLoadedTransactionsToCache(partitionId, coordinatorEpoch, new Pool[String, TransactionMetadata]())
 
-    assertEquals(None, transactionManager.getTransactionState(txnId1))
-    assertEquals(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1), transactionManager.addTransaction(txnId1, txnMetadata1))
-    assertEquals(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1)), transactionManager.getTransactionState(txnId1))
-    assertEquals(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1), transactionManager.addTransaction(txnId1, txnMetadata2))
+    assertEquals(None, transactionManager.getTransactionState(transactionalId1))
+    assertEquals(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1), transactionManager.addTransaction(txnMetadata1))
+    assertEquals(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1)), transactionManager.getTransactionState(transactionalId1))
+    assertEquals(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata2), transactionManager.addTransaction(txnMetadata2))
   }
 
   @Test
@@ -157,35 +157,35 @@ class TransactionStateManagerTest {
     prepareTxnLog(topicPartition, startOffset, records)
 
     // this partition should not be part of the owned partitions
-    assertFalse(transactionManager.isCoordinatorFor(txnId1))
-    assertFalse(transactionManager.isCoordinatorFor(txnId2))
+    assertFalse(transactionManager.isCoordinatorFor(transactionalId1))
+    assertFalse(transactionManager.isCoordinatorFor(transactionalId2))
 
     transactionManager.loadTransactionsForTxnTopicPartition(partitionId, 0, (_, _, _, _, _) => ())
 
     // let the time advance to trigger the background thread loading
     scheduler.tick()
 
-    val cachedPidMetadata1 = transactionManager.getTransactionState(txnId1).getOrElse(fail(txnId1 + "'s transaction state was not loaded into the cache"))
-    val cachedPidMetadata2 = transactionManager.getTransactionState(txnId2).getOrElse(fail(txnId2 + "'s transaction state was not loaded into the cache"))
+    val cachedPidMetadata1 = transactionManager.getTransactionState(transactionalId1).getOrElse(fail(transactionalId1 + "'s transaction state was not loaded into the cache"))
+    val cachedPidMetadata2 = transactionManager.getTransactionState(transactionalId2).getOrElse(fail(transactionalId2 + "'s transaction state was not loaded into the cache"))
 
     // they should be equal to the latest status of the transaction
     assertEquals(txnMetadata1, cachedPidMetadata1.transactionMetadata)
     assertEquals(txnMetadata2, cachedPidMetadata2.transactionMetadata)
 
     // this partition should now be part of the owned partitions
-    assertTrue(transactionManager.isCoordinatorFor(txnId1))
-    assertTrue(transactionManager.isCoordinatorFor(txnId2))
+    assertTrue(transactionManager.isCoordinatorFor(transactionalId1))
+    assertTrue(transactionManager.isCoordinatorFor(transactionalId2))
 
     transactionManager.removeTransactionsForTxnTopicPartition(partitionId, coordinatorEpoch)
 
     // let the time advance to trigger the background thread removing
     scheduler.tick()
 
-    assertFalse(transactionManager.isCoordinatorFor(txnId1))
-    assertFalse(transactionManager.isCoordinatorFor(txnId2))
+    assertFalse(transactionManager.isCoordinatorFor(transactionalId1))
+    assertFalse(transactionManager.isCoordinatorFor(transactionalId2))
 
-    assertEquals(None, transactionManager.getTransactionState(txnId1))
-    assertEquals(None, transactionManager.getTransactionState(txnId2))
+    assertEquals(None, transactionManager.getTransactionState(transactionalId1))
+    assertEquals(None, transactionManager.getTransactionState(transactionalId2))
   }
 
   @Test
@@ -193,7 +193,7 @@ class TransactionStateManagerTest {
     transactionManager.addLoadedTransactionsToCache(partitionId, coordinatorEpoch, new Pool[String, TransactionMetadata]())
 
     // first insert the initial transaction metadata
-    transactionManager.addTransaction(txnId1, txnMetadata1)
+    transactionManager.addTransaction(txnMetadata1)
 
     prepareForTxnMessageAppend(Errors.NONE)
     expectedError = Errors.NONE
@@ -203,9 +203,9 @@ class TransactionStateManagerTest {
       new TopicPartition("topic1", 1)), time.milliseconds())
 
     // append the new metadata into log
-    transactionManager.appendTransactionToLog(txnId1, coordinatorEpoch, newMetadata, assertCallback)
+    transactionManager.appendTransactionToLog(transactionalId1, coordinatorEpoch, newMetadata, assertCallback)
 
-    assertEquals(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1)), transactionManager.getTransactionState(txnId1))
+    assertEquals(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1)), transactionManager.getTransactionState(transactionalId1))
 
     // append to log again with expected failures
     val failedMetadata = txnMetadata1.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic2", 0)), time.milliseconds())
@@ -214,38 +214,38 @@ class TransactionStateManagerTest {
     expectedError = Errors.COORDINATOR_NOT_AVAILABLE
 
     prepareForTxnMessageAppend(Errors.UNKNOWN_TOPIC_OR_PARTITION)
-    transactionManager.appendTransactionToLog(txnId1, coordinatorEpoch = 10, failedMetadata, assertCallback)
-    assertEquals(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1)), transactionManager.getTransactionState(txnId1))
+    transactionManager.appendTransactionToLog(transactionalId1, coordinatorEpoch = 10, failedMetadata, assertCallback)
+    assertEquals(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1)), transactionManager.getTransactionState(transactionalId1))
 
     prepareForTxnMessageAppend(Errors.NOT_ENOUGH_REPLICAS)
-    transactionManager.appendTransactionToLog(txnId1, coordinatorEpoch = 10, failedMetadata, assertCallback)
-    assertEquals(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1)), transactionManager.getTransactionState(txnId1))
+    transactionManager.appendTransactionToLog(transactionalId1, coordinatorEpoch = 10, failedMetadata, assertCallback)
+    assertEquals(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1)), transactionManager.getTransactionState(transactionalId1))
 
     prepareForTxnMessageAppend(Errors.NOT_ENOUGH_REPLICAS_AFTER_APPEND)
-    transactionManager.appendTransactionToLog(txnId1, coordinatorEpoch = 10, failedMetadata, assertCallback)
-    assertEquals(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1)), transactionManager.getTransactionState(txnId1))
+    transactionManager.appendTransactionToLog(transactionalId1, coordinatorEpoch = 10, failedMetadata, assertCallback)
+    assertEquals(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1)), transactionManager.getTransactionState(transactionalId1))
 
     prepareForTxnMessageAppend(Errors.REQUEST_TIMED_OUT)
-    transactionManager.appendTransactionToLog(txnId1, coordinatorEpoch = 10, failedMetadata, assertCallback)
-    assertEquals(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1)), transactionManager.getTransactionState(txnId1))
+    transactionManager.appendTransactionToLog(transactionalId1, coordinatorEpoch = 10, failedMetadata, assertCallback)
+    assertEquals(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1)), transactionManager.getTransactionState(transactionalId1))
 
     // test NOT_COORDINATOR cases
     expectedError = Errors.NOT_COORDINATOR
 
     prepareForTxnMessageAppend(Errors.NOT_LEADER_FOR_PARTITION)
-    transactionManager.appendTransactionToLog(txnId1, coordinatorEpoch = 10, failedMetadata, assertCallback)
-    assertEquals(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1)), transactionManager.getTransactionState(txnId1))
+    transactionManager.appendTransactionToLog(transactionalId1, coordinatorEpoch = 10, failedMetadata, assertCallback)
+    assertEquals(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1)), transactionManager.getTransactionState(transactionalId1))
 
     // test UNKNOWN cases
     expectedError = Errors.UNKNOWN
 
     prepareForTxnMessageAppend(Errors.MESSAGE_TOO_LARGE)
-    transactionManager.appendTransactionToLog(txnId1, coordinatorEpoch = 10, failedMetadata, assertCallback)
-    assertEquals(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1)), transactionManager.getTransactionState(txnId1))
+    transactionManager.appendTransactionToLog(transactionalId1, coordinatorEpoch = 10, failedMetadata, assertCallback)
+    assertEquals(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1)), transactionManager.getTransactionState(transactionalId1))
 
     prepareForTxnMessageAppend(Errors.RECORD_LIST_TOO_LARGE)
-    transactionManager.appendTransactionToLog(txnId1, coordinatorEpoch = 10, failedMetadata, assertCallback)
-    assertEquals(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1)), transactionManager.getTransactionState(txnId1))
+    transactionManager.appendTransactionToLog(transactionalId1, coordinatorEpoch = 10, failedMetadata, assertCallback)
+    assertEquals(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1)), transactionManager.getTransactionState(transactionalId1))
   }
 
   @Test
@@ -253,7 +253,7 @@ class TransactionStateManagerTest {
     transactionManager.addLoadedTransactionsToCache(partitionId, 0, new Pool[String, TransactionMetadata]())
 
     // first insert the initial transaction metadata
-    transactionManager.addTransaction(txnId1, txnMetadata1)
+    transactionManager.addTransaction(txnMetadata1)
 
     prepareForTxnMessageAppend(Errors.NONE)
     expectedError = Errors.NOT_COORDINATOR
@@ -265,13 +265,13 @@ class TransactionStateManagerTest {
     txnMetadata1.producerEpoch = (txnMetadata1.producerEpoch + 1).toShort
 
     // append the new metadata into log
-    transactionManager.appendTransactionToLog(txnId1, coordinatorEpoch = 10, newMetadata, assertCallback)
+    transactionManager.appendTransactionToLog(transactionalId1, coordinatorEpoch = 10, newMetadata, assertCallback)
   }
 
   @Test(expected = classOf[IllegalStateException])
   def testAppendTransactionToLogWhilePendingStateChanged() = {
     // first insert the initial transaction metadata
-    transactionManager.addTransaction(txnId1, txnMetadata1)
+    transactionManager.addTransaction(txnMetadata1)
 
     prepareForTxnMessageAppend(Errors.NONE)
     expectedError = Errors.INVALID_PRODUCER_EPOCH
@@ -283,12 +283,12 @@ class TransactionStateManagerTest {
     txnMetadata1.pendingState = None
 
     // append the new metadata into log
-    transactionManager.appendTransactionToLog(txnId1, coordinatorEpoch = 10, newMetadata, assertCallback)
+    transactionManager.appendTransactionToLog(transactionalId1, coordinatorEpoch = 10, newMetadata, assertCallback)
   }
 
   @Test
   def shouldReturnNoneIfTransactionIdPartitionNotOwned(): Unit = {
-    assertEquals(None, transactionManager.getTransactionState(txnId1))
+    assertEquals(None, transactionManager.getTransactionState(transactionalId1))
   }
 
   @Test
@@ -297,34 +297,16 @@ class TransactionStateManagerTest {
       transactionManager.addLoadedTransactionsToCache(partitionId, 0, new Pool[String, TransactionMetadata]())
     }
 
-    txnMetadata1.state = Ongoing
-    txnMetadata1.txnStartTimestamp = time.milliseconds()
-    transactionManager.addTransaction(txnId1, txnMetadata1)
-    transactionManager.addTransaction(txnId2, txnMetadata2)
-
-    val ongoingButNotExpiring = txnMetadata1.copy()
-    ongoingButNotExpiring.txnTimeoutMs = 10000
-    transactionManager.addTransaction("not-expiring", ongoingButNotExpiring)
-
-    val prepareCommit = txnMetadata1.copy()
-    prepareCommit.state = PrepareCommit
-    transactionManager.addTransaction("pc", prepareCommit)
-
-    val prepareAbort = txnMetadata1.copy()
-    prepareAbort.state = PrepareAbort
-    transactionManager.addTransaction("pa", prepareAbort)
-
-    val committed = txnMetadata1.copy()
-    committed.state = CompleteCommit
-    transactionManager.addTransaction("cc", committed)
-
-    val aborted = txnMetadata1.copy()
-    aborted.state = CompleteAbort
-    transactionManager.addTransaction("ca", aborted)
+    transactionManager.addTransaction(transactionMetadata("ongoing", producerId = 0, state = Ongoing))
+    transactionManager.addTransaction(transactionMetadata("not-expiring", producerId = 1, state = Ongoing, txnTimeout = 10000))
+    transactionManager.addTransaction(transactionMetadata("prepare-commit", producerId = 2, state = PrepareCommit))
+    transactionManager.addTransaction(transactionMetadata("prepare-abort", producerId = 3, state = PrepareAbort))
+    transactionManager.addTransaction(transactionMetadata("complete-commit", producerId = 4, state = CompleteCommit))
+    transactionManager.addTransaction(transactionMetadata("complete-abort", producerId = 5, state = CompleteAbort))
 
     time.sleep(2000)
     val expiring = transactionManager.transactionsToExpire()
-    assertEquals(List(TransactionalIdAndProducerIdEpoch(txnId1, txnMetadata1.producerId, txnMetadata1.producerEpoch)), expiring)
+    assertEquals(List(TransactionalIdAndProducerIdEpoch("ongoing", 0, 0)), expiring)
   }
 
   @Test
@@ -353,20 +335,27 @@ class TransactionStateManagerTest {
                            coordinatorEpoch: Int,
                            command: TransactionResult,
                            metadata: TransactionMetadata,
-                           newMetadata: TransactionMetadataTransition): Unit = {
+                           newMetadata: TxnTransitMetadata): Unit = {
       txnId = transactionalId
     }
 
     transactionManager.loadTransactionsForTxnTopicPartition(partitionId, 0, rememberTxnMarkers)
     scheduler.tick()
 
-    assertEquals(txnId1, txnId)
+    assertEquals(transactionalId1, txnId)
   }
 
   private def assertCallback(error: Errors): Unit = {
     assertEquals(expectedError, error)
   }
 
+  private def transactionMetadata(transactionalId: String,
+                                  producerId: Long,
+                                  state: TransactionState = Empty,
+                                  txnTimeout: Int = transactionTimeoutMs): TransactionMetadata = {
+    TransactionMetadata(transactionalId, producerId, 0.toShort, txnTimeout, state, time.milliseconds())
+  }
+
   private def prepareTxnLog(topicPartition: TopicPartition,
                             startOffset: Long,
                             records: MemoryRecords): Unit = {