Posted to commits@kafka.apache.org by gu...@apache.org on 2017/05/12 22:01:05 UTC

[1/3] kafka git commit: KAFKA-5130: Refactor transaction coordinator's in-memory cache; plus fixes on transaction metadata synchronization

Repository: kafka
Updated Branches:
  refs/heads/trunk 7baa58d79 -> 794e6dbd1


http://git-wip-us.apache.org/repos/asf/kafka/blob/794e6dbd/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerChannelManagerTest.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerChannelManagerTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerChannelManagerTest.scala
index 29240a6..d02e072 100644
--- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerChannelManagerTest.scala
+++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerChannelManagerTest.scala
@@ -16,17 +16,12 @@
  */
 package kafka.coordinator.transaction
 
-import kafka.api.{LeaderAndIsr, PartitionStateInfo}
-import kafka.common.{BrokerEndPointNotAvailableException, InterBrokerSendThread}
-import kafka.controller.LeaderIsrAndControllerEpoch
 import kafka.server.{DelayedOperationPurgatory, KafkaConfig, MetadataCache}
-import kafka.utils.{MockTime, TestUtils}
 import kafka.utils.timer.MockTimer
+import kafka.utils.TestUtils
 import org.apache.kafka.clients.NetworkClient
-import org.apache.kafka.common.network.ListenerName
-import org.apache.kafka.common.protocol.{Errors, SecurityProtocol}
 import org.apache.kafka.common.requests.{TransactionResult, WriteTxnMarkersRequest}
-import org.apache.kafka.common.utils.Utils
+import org.apache.kafka.common.utils.{MockTime, Utils}
 import org.apache.kafka.common.{Node, TopicPartition}
 import org.easymock.EasyMock
 import org.junit.Assert._
@@ -36,241 +31,182 @@ import scala.collection.mutable
 
 class TransactionMarkerChannelManagerTest {
   private val metadataCache = EasyMock.createNiceMock(classOf[MetadataCache])
-  private val interBrokerSendThread = EasyMock.createNiceMock(classOf[InterBrokerSendThread])
   private val networkClient = EasyMock.createNiceMock(classOf[NetworkClient])
-  private val channel = new TransactionMarkerChannel(ListenerName.forSecurityProtocol(SecurityProtocol.PLAINTEXT),
-    metadataCache,
-    networkClient,
-    new MockTime())
-  private val purgatory = new DelayedOperationPurgatory[DelayedTxnMarker]("txn-purgatory-name",
-    new MockTimer,
-    reaperEnabled = false)
-  private val requestGenerator = TransactionMarkerChannelManager.requestGenerator(channel, purgatory)
+  private val txnStateManager = EasyMock.createNiceMock(classOf[TransactionStateManager])
+
   private val partition1 = new TopicPartition("topic1", 0)
   private val partition2 = new TopicPartition("topic1", 1)
   private val broker1 = new Node(1, "host", 10)
   private val broker2 = new Node(2, "otherhost", 10)
-  private val metadataPartition = 0
+
+  private val transactionalId1 = "txnId1"
+  private val transactionalId2 = "txnId2"
+  private val transactionalId3 = "txnId3"
+  private val producerId1 = 0.asInstanceOf[Long]
+  private val producerId2 = 1.asInstanceOf[Long]
+  private val producerId3 = 1.asInstanceOf[Long]
+  private val producerEpoch = 0.asInstanceOf[Short]
+  private val txnTopicPartition1 = 0
+  private val txnTopicPartition2 = 1
+  private val coordinatorEpoch = 0
+  private val txnTimeoutMs = 0
+  private val txnResult = TransactionResult.COMMIT
+
+  private val txnMarkerPurgatory = new DelayedOperationPurgatory[DelayedTxnMarker]("txn-purgatory-name",
+    new MockTimer,
+    reaperEnabled = false)
+  private val time = new MockTime
+
   private val channelManager = new TransactionMarkerChannelManager(
     KafkaConfig.fromProps(TestUtils.createBrokerConfig(1, "localhost:2181")),
     metadataCache,
-    purgatory,
-    interBrokerSendThread,
-    channel)
+    networkClient,
+    txnStateManager,
+    txnMarkerPurgatory,
+    time)
+
+  private val senderThread = channelManager.senderThread
+
+  private def mockCache(): Unit = {
+    EasyMock.expect(txnStateManager.partitionFor(transactionalId1))
+      .andReturn(txnTopicPartition1)
+      .anyTimes()
+    EasyMock.expect(txnStateManager.partitionFor(transactionalId2))
+      .andReturn(txnTopicPartition2)
+      .anyTimes()
+    EasyMock.replay(txnStateManager)
+  }
 
   @Test
   def shouldGenerateEmptyMapWhenNoRequestsOutstanding(): Unit = {
-    assertTrue(requestGenerator().isEmpty)
+    assertTrue(senderThread.generateRequests().isEmpty)
   }
 
   @Test
-  def shouldGenerateRequestPerBroker(): Unit ={
-    EasyMock.expect(metadataCache.getPartitionInfo(partition1.topic(), partition1.partition()))
-      .andReturn(Some(PartitionStateInfo(LeaderIsrAndControllerEpoch(LeaderAndIsr(1, 0, List.empty, 0), 0), Set.empty)))
+  def shouldGenerateRequestPerBroker(): Unit = {
+    mockCache()
 
-    EasyMock.expect(metadataCache.getPartitionInfo(partition2.topic(), partition2.partition()))
-      .andReturn(Some(PartitionStateInfo(LeaderIsrAndControllerEpoch(LeaderAndIsr(2, 0, List.empty, 0), 0), Set.empty)))
+    EasyMock.expect(metadataCache.getPartitionLeaderEndpoint(
+      EasyMock.eq(partition1.topic),
+      EasyMock.eq(partition1.partition),
+      EasyMock.anyObject()))
+      .andReturn(Some(broker1))
+      .anyTimes()
 
-    EasyMock.expect(metadataCache.getAliveEndpoint(EasyMock.eq(1), EasyMock.anyObject())).andReturn(Some(broker1))
-    EasyMock.expect(metadataCache.getAliveEndpoint(EasyMock.eq(2), EasyMock.anyObject())).andReturn(Some(broker2))
+    EasyMock.expect(metadataCache.getPartitionLeaderEndpoint(
+      EasyMock.eq(partition2.topic),
+      EasyMock.eq(partition2.partition),
+      EasyMock.anyObject()))
+      .andReturn(Some(broker2))
+      .anyTimes()
 
     EasyMock.replay(metadataCache)
 
-    channel.addRequestToSend(0, 0, 0, TransactionResult.COMMIT, 0, Set[TopicPartition](partition1, partition2))
+    val txnMetadata = new TransactionMetadata(producerId1, producerEpoch, txnTimeoutMs, PrepareCommit, mutable.Set[TopicPartition](partition1, partition2), 0L, 0L)
+    channelManager.addTxnMarkersToSend(transactionalId1, coordinatorEpoch, txnResult, txnMetadata, txnMetadata.prepareComplete(time.milliseconds()))
 
+    assertEquals(1 * 2, txnMarkerPurgatory.watched)
+    assertEquals(1, channelManager.queueForBroker(broker1.id).get.totalNumMarkers())
+    assertEquals(1, channelManager.queueForBroker(broker2.id).get.totalNumMarkers())
+    assertEquals(1, channelManager.queueForBroker(broker1.id).get.totalNumMarkers(txnTopicPartition1))
+    assertEquals(1, channelManager.queueForBroker(broker2.id).get.totalNumMarkers(txnTopicPartition1))
 
     val expectedBroker1Request = new WriteTxnMarkersRequest.Builder(
-      Utils.mkList(new WriteTxnMarkersRequest.TxnMarkerEntry(0, 0, 0, TransactionResult.COMMIT, Utils.mkList(partition1)))).build()
+      Utils.mkList(new WriteTxnMarkersRequest.TxnMarkerEntry(producerId1, producerEpoch, coordinatorEpoch, TransactionResult.COMMIT, Utils.mkList(partition1)))).build()
     val expectedBroker2Request = new WriteTxnMarkersRequest.Builder(
-      Utils.mkList(new WriteTxnMarkersRequest.TxnMarkerEntry(0, 0, 0, TransactionResult.COMMIT, Utils.mkList(partition2)))).build()
+      Utils.mkList(new WriteTxnMarkersRequest.TxnMarkerEntry(producerId1, producerEpoch, coordinatorEpoch, TransactionResult.COMMIT, Utils.mkList(partition2)))).build()
 
-    val requests: Map[Node, WriteTxnMarkersRequest] = requestGenerator().map{ result =>
-      (result.destination, result.request.asInstanceOf[WriteTxnMarkersRequest.Builder].build())
+    val requests: Map[Node, WriteTxnMarkersRequest] = senderThread.generateRequests().map { handler =>
+      (handler.destination, handler.request.asInstanceOf[WriteTxnMarkersRequest.Builder].build())
     }.toMap
 
-    val broker1Request = requests(broker1)
-    val broker2Request = requests(broker2)
-
-    assertEquals(2, requests.size)
-    assertEquals(expectedBroker1Request, broker1Request)
-    assertEquals(expectedBroker2Request, broker2Request)
-
+    assertEquals(Map(broker1 -> expectedBroker1Request, broker2 -> expectedBroker2Request), requests)
+    assertTrue(senderThread.generateRequests().isEmpty)
   }
 
   @Test
-  def shouldGenerateRequestPerPartitionPerBroker(): Unit ={
-    EasyMock.expect(metadataCache.getPartitionInfo(partition1.topic(), partition1.partition()))
-      .andReturn(Some(PartitionStateInfo(LeaderIsrAndControllerEpoch(LeaderAndIsr(1, 0, List.empty, 0), 0), Set.empty)))
+  def shouldGenerateRequestPerPartitionPerBroker(): Unit = {
+    mockCache()
 
+    EasyMock.expect(metadataCache.getPartitionLeaderEndpoint(
+      EasyMock.eq(partition1.topic),
+      EasyMock.eq(partition1.partition),
+      EasyMock.anyObject()))
+      .andReturn(Some(broker1))
+      .anyTimes()
 
-    EasyMock.expect(metadataCache.getPartitionInfo(partition2.topic(), partition2.partition()))
-      .andReturn(Some(PartitionStateInfo(LeaderIsrAndControllerEpoch(LeaderAndIsr(1, 0, List.empty, 0), 0), Set.empty)))
-    EasyMock.expect(metadataCache.getAliveEndpoint(EasyMock.eq(1), EasyMock.anyObject())).andReturn(Some(broker1)).anyTimes()
     EasyMock.replay(metadataCache)
 
-    channel.addRequestToSend(0, 0, 0, TransactionResult.COMMIT, 0, Set[TopicPartition](partition1))
-    channel.addRequestToSend(1, 0, 0, TransactionResult.COMMIT, 0, Set[TopicPartition](partition2))
+    val txnMetadata1 = new TransactionMetadata(producerId1, producerEpoch, txnTimeoutMs, PrepareCommit, mutable.Set[TopicPartition](partition1), 0L, 0L)
+    val txnMetadata2 = new TransactionMetadata(producerId2, producerEpoch, txnTimeoutMs, PrepareCommit, mutable.Set[TopicPartition](partition1), 0L, 0L)
+    channelManager.addTxnMarkersToSend(transactionalId1, coordinatorEpoch, txnResult, txnMetadata1, txnMetadata1.prepareComplete(time.milliseconds()))
+    channelManager.addTxnMarkersToSend(transactionalId2, coordinatorEpoch, txnResult, txnMetadata2, txnMetadata2.prepareComplete(time.milliseconds()))
 
-    val expectedPartition1Request = new WriteTxnMarkersRequest.Builder(
-      Utils.mkList(new WriteTxnMarkersRequest.TxnMarkerEntry(0, 0, 0, TransactionResult.COMMIT, Utils.mkList(partition1)))).build()
-    val expectedPartition2Request = new WriteTxnMarkersRequest.Builder(
-      Utils.mkList(new WriteTxnMarkersRequest.TxnMarkerEntry(0, 0, 0, TransactionResult.COMMIT, Utils.mkList(partition2)))).build()
+    assertEquals(2 * 2, txnMarkerPurgatory.watched)
+    assertEquals(2, channelManager.queueForBroker(broker1.id).get.totalNumMarkers())
+    assertEquals(1, channelManager.queueForBroker(broker1.id).get.totalNumMarkers(txnTopicPartition1))
+    assertEquals(1, channelManager.queueForBroker(broker1.id).get.totalNumMarkers(txnTopicPartition2))
 
-    val requests = requestGenerator().map { result =>
-      val markersRequest = result.request.asInstanceOf[WriteTxnMarkersRequest.Builder].build()
-      (result.destination, markersRequest)
-    }.toList
-
-    assertEquals(List((broker1, expectedPartition1Request), (broker1, expectedPartition2Request)), requests)
-  }
-
-  @Test
-  def shouldDrainBrokerQueueWhenGeneratingRequests(): Unit = {
-    EasyMock.expect(metadataCache.getPartitionInfo(partition1.topic(), partition1.partition()))
-      .andReturn(Some(PartitionStateInfo(LeaderIsrAndControllerEpoch(LeaderAndIsr(1, 0, List.empty, 0), 0), Set.empty)))
-    EasyMock.expect(metadataCache.getAliveEndpoint(EasyMock.eq(1), EasyMock.anyObject())).andReturn(Some(broker1))
-    EasyMock.replay(metadataCache)
+    val expectedBroker1Request = new WriteTxnMarkersRequest.Builder(
+      Utils.mkList(new WriteTxnMarkersRequest.TxnMarkerEntry(producerId1, producerEpoch, coordinatorEpoch, TransactionResult.COMMIT, Utils.mkList(partition1)),
+        new WriteTxnMarkersRequest.TxnMarkerEntry(producerId2, producerEpoch, coordinatorEpoch, TransactionResult.COMMIT, Utils.mkList(partition1)))).build()
 
-    channel.addRequestToSend(0, 0, 0, TransactionResult.COMMIT, 0, Set[TopicPartition](partition1))
+    val requests: Map[Node, WriteTxnMarkersRequest] = senderThread.generateRequests().map { handler =>
+      (handler.destination, handler.request.asInstanceOf[WriteTxnMarkersRequest.Builder].build())
+    }.toMap
 
-    val result = requestGenerator()
-    assertTrue(result.nonEmpty)
-    val result2 = requestGenerator()
-    assertTrue(result2.isEmpty)
+    assertEquals(Map(broker1 -> expectedBroker1Request), requests)
+    assertTrue(senderThread.generateRequests().isEmpty)
   }
 
   @Test
   def shouldRetryGettingLeaderWhenNotFound(): Unit = {
-    EasyMock.expect(metadataCache.getPartitionInfo(partition1.topic(), partition1.partition()))
-      .andReturn(None)
-      .andReturn(None)
-      .andReturn(Some(PartitionStateInfo(LeaderIsrAndControllerEpoch(LeaderAndIsr(1, 0, List.empty, 0), 0), Set.empty)))
+    mockCache()
 
-    EasyMock.expect(metadataCache.getAliveEndpoint(EasyMock.eq(1), EasyMock.anyObject())).andReturn(Some(broker1))
-    EasyMock.replay(metadataCache)
+    EasyMock.expect(metadataCache.getPartitionLeaderEndpoint(
+      EasyMock.eq(partition1.topic),
+      EasyMock.eq(partition1.partition),
+      EasyMock.anyObject())
+    ).andReturn(None)
+     .andReturn(None)
+     .andReturn(Some(broker1))
 
-    channel.addRequestToSend(0, 0, 0, TransactionResult.COMMIT, 0, Set[TopicPartition](partition1))
-
-    EasyMock.verify(metadataCache)
-  }
-
-  @Test
-  def shouldRetryGettingLeaderWhenBrokerEndPointNotAvailableException(): Unit = {
-    EasyMock.expect(metadataCache.getPartitionInfo(partition1.topic(), partition1.partition()))
-      .andReturn(Some(PartitionStateInfo(LeaderIsrAndControllerEpoch(LeaderAndIsr(1, 0, List.empty, 0), 0), Set.empty)))
-      .times(2)
-    EasyMock.expect(metadataCache.getAliveEndpoint(EasyMock.eq(1), EasyMock.anyObject()))
-      .andThrow(new BrokerEndPointNotAvailableException())
-      .andReturn(Some(broker1))
     EasyMock.replay(metadataCache)
 
-    channel.addRequestToSend(0, 0, 0, TransactionResult.COMMIT, 0, Set[TopicPartition](partition1))
+    channelManager.addTxnMarkersToBrokerQueue(transactionalId1, producerId1, producerEpoch, TransactionResult.COMMIT, coordinatorEpoch, Set[TopicPartition](partition1))
 
     EasyMock.verify(metadataCache)
   }
 
   @Test
-  def shouldRetryGettingLeaderWhenLeaderDoesntExist(): Unit = {
-    EasyMock.expect(metadataCache.getPartitionInfo(partition1.topic(), partition1.partition()))
-      .andReturn(Some(PartitionStateInfo(LeaderIsrAndControllerEpoch(LeaderAndIsr(1, 0, List.empty, 0), 0), Set.empty)))
-      .times(2)
+  def shouldRemoveMarkersForTxnPartitionWhenPartitionEmigrated(): Unit = {
+    mockCache()
 
-    EasyMock.expect(metadataCache.getAliveEndpoint(EasyMock.eq(1), EasyMock.anyObject()))
-      .andReturn(None)
+    EasyMock.expect(metadataCache.getPartitionLeaderEndpoint(
+      EasyMock.eq(partition1.topic),
+      EasyMock.eq(partition1.partition),
+      EasyMock.anyObject()))
       .andReturn(Some(broker1))
+      .anyTimes()
 
     EasyMock.replay(metadataCache)
 
-    channel.addRequestToSend(0, 0, 0, TransactionResult.COMMIT, 0, Set[TopicPartition](partition1))
+    val txnMetadata1 = new TransactionMetadata(producerId1, producerEpoch, txnTimeoutMs, PrepareCommit, mutable.Set[TopicPartition](partition1), 0L, 0L)
+    channelManager.addTxnMarkersToSend(transactionalId1, coordinatorEpoch, txnResult, txnMetadata1, txnMetadata1.prepareComplete(time.milliseconds()))
 
-    EasyMock.verify(metadataCache)
-  }
-
-  @Test
-  def shouldAddPendingTxnRequest(): Unit = {
-    val metadata = new TransactionMetadata(1, 0, 0, PrepareCommit, mutable.Set[TopicPartition](partition1, partition2), 0, 0L)
-    EasyMock.expect(metadataCache.getPartitionInfo(partition1.topic(), partition1.partition()))
-      .andReturn(Some(PartitionStateInfo(LeaderIsrAndControllerEpoch(LeaderAndIsr(1, 0, List.empty, 0), 0), Set.empty)))
-    EasyMock.expect(metadataCache.getPartitionInfo(partition2.topic(), partition2.partition()))
-      .andReturn(Some(PartitionStateInfo(LeaderIsrAndControllerEpoch(LeaderAndIsr(2, 0, List.empty, 0), 0), Set.empty)))
-    EasyMock.expect(metadataCache.getAliveEndpoint(EasyMock.eq(1), EasyMock.anyObject())).andReturn(Some(broker1))
-    EasyMock.expect(metadataCache.getAliveEndpoint(EasyMock.eq(2), EasyMock.anyObject())).andReturn(Some(broker2))
-
-    EasyMock.replay(metadataCache)
-
-    channelManager.addTxnMarkerRequest(metadataPartition, metadata, 0, completionCallback)
+    val txnMetadata2 = new TransactionMetadata(producerId2, producerEpoch, txnTimeoutMs, PrepareCommit, mutable.Set[TopicPartition](partition1), 0L, 0L)
+    channelManager.addTxnMarkersToSend(transactionalId2, coordinatorEpoch, txnResult, txnMetadata2, txnMetadata2.prepareComplete(time.milliseconds()))
 
-    assertEquals(Some(metadata), channel.pendingTxnMetadata(metadataPartition, 1))
-
-  }
+    assertEquals(2 * 2, txnMarkerPurgatory.watched)
+    assertEquals(2, channelManager.queueForBroker(broker1.id).get.totalNumMarkers())
+    assertEquals(1, channelManager.queueForBroker(broker1.id).get.totalNumMarkers(txnTopicPartition1))
+    assertEquals(1, channelManager.queueForBroker(broker1.id).get.totalNumMarkers(txnTopicPartition2))
 
-  @Test
-  def shouldAddRequestToBrokerQueue(): Unit = {
-    val metadata = new TransactionMetadata(1, 0, 0, PrepareCommit, mutable.Set[TopicPartition](partition1), 0, 0L)
-
-    EasyMock.expect(metadataCache.getPartitionInfo(partition1.topic(), partition1.partition()))
-      .andReturn(Some(PartitionStateInfo(LeaderIsrAndControllerEpoch(LeaderAndIsr(1, 0, List.empty, 0), 0), Set.empty)))
-    EasyMock.expect(metadataCache.getAliveEndpoint(EasyMock.eq(1), EasyMock.anyObject())).andReturn(Some(broker1))
-    EasyMock.replay(metadataCache)
-
-    channelManager.addTxnMarkerRequest(metadataPartition, metadata, 0, completionCallback)
-    assertEquals(1, requestGenerator().size)
-  }
-
-  @Test
-  def shouldAddDelayedTxnMarkerToPurgatory(): Unit = {
-    val metadata = new TransactionMetadata(1, 0, 0, PrepareCommit, mutable.Set[TopicPartition](partition1), 0, 0L)
-    EasyMock.expect(metadataCache.getPartitionInfo(partition1.topic(), partition1.partition()))
-      .andReturn(Some(PartitionStateInfo(LeaderIsrAndControllerEpoch(LeaderAndIsr(1, 0, List.empty, 0), 0), Set.empty)))
-    EasyMock.expect(metadataCache.getAliveEndpoint(EasyMock.eq(1), EasyMock.anyObject())).andReturn(Some(broker1))
-
-    EasyMock.replay(metadataCache)
-
-    channelManager.addTxnMarkerRequest(metadataPartition, metadata, 0, completionCallback)
-    assertEquals(1,purgatory.watched)
-  }
-
-  @Test
-  def shouldStartInterBrokerThreadOnStartup(): Unit = {
-    EasyMock.expect(interBrokerSendThread.start())
-    EasyMock.replay(interBrokerSendThread)
-    channelManager.start()
-    EasyMock.verify(interBrokerSendThread)
-  }
-
-
-  @Test
-  def shouldStopInterBrokerThreadOnShutdown(): Unit = {
-    EasyMock.expect(interBrokerSendThread.shutdown())
-    EasyMock.replay(interBrokerSendThread)
-    channelManager.shutdown()
-    EasyMock.verify(interBrokerSendThread)
-  }
-
-  @Test
-  def shouldClearPurgatoryForPartitionWhenPartitionEmigrated(): Unit = {
-    val metadata1 = new TransactionMetadata(0, 0, 0, PrepareCommit, mutable.Set[TopicPartition](partition1), 0, 0)
-    purgatory.tryCompleteElseWatch(new DelayedTxnMarker(metadata1, (error:Errors) => {}),Seq(0L))
-    channel.maybeAddPendingRequest(0, metadata1)
-
-    val metadata2 = new TransactionMetadata(1, 0, 0, PrepareCommit, mutable.Set[TopicPartition](partition1), 0, 0)
-    purgatory.tryCompleteElseWatch(new DelayedTxnMarker(metadata2, (error:Errors) => {}),Seq(1L))
-    channel.maybeAddPendingRequest(0, metadata2)
-
-    val metadata3 = new TransactionMetadata(2, 0, 0, PrepareCommit, mutable.Set[TopicPartition](partition1), 0, 0)
-    purgatory.tryCompleteElseWatch(new DelayedTxnMarker(metadata3, (error:Errors) => {}),Seq(2L))
-    channel.maybeAddPendingRequest(1, metadata3)
-
-    channelManager.removeStateForPartition(0)
-
-    assertEquals(1, purgatory.watched)
-    // should not complete as they've been removed
-    purgatory.checkAndComplete(0L)
-    purgatory.checkAndComplete(1L)
-    
-    assertEquals(1, purgatory.watched)
-  }
+    channelManager.removeMarkersForTxnTopicPartition(txnTopicPartition1)
 
-  def completionCallback(errors: Errors): Unit = {
+    assertEquals(1 * 2, txnMarkerPurgatory.watched)
+    assertEquals(1, channelManager.queueForBroker(broker1.id).get.totalNumMarkers())
+    assertEquals(0, channelManager.queueForBroker(broker1.id).get.totalNumMarkers(txnTopicPartition1))
+    assertEquals(1, channelManager.queueForBroker(broker1.id).get.totalNumMarkers(txnTopicPartition2))
   }
 }

http://git-wip-us.apache.org/repos/asf/kafka/blob/794e6dbd/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerChannelTest.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerChannelTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerChannelTest.scala
deleted file mode 100644
index 89a7606..0000000
--- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerChannelTest.scala
+++ /dev/null
@@ -1,179 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package kafka.coordinator.transaction
-
-import kafka.api.{LeaderAndIsr, PartitionStateInfo}
-import kafka.controller.LeaderIsrAndControllerEpoch
-import kafka.server.{DelayedOperationPurgatory, MetadataCache}
-import kafka.utils.MockTime
-import kafka.utils.timer.MockTimer
-import org.apache.kafka.clients.NetworkClient
-import org.apache.kafka.common.network.ListenerName
-import org.apache.kafka.common.protocol.{Errors, SecurityProtocol}
-import org.apache.kafka.common.requests.{TransactionResult, WriteTxnMarkersRequest}
-import org.apache.kafka.common.utils.Utils
-import org.apache.kafka.common.{Node, TopicPartition}
-import org.easymock.EasyMock
-import org.junit.Assert._
-import org.junit.Test
-
-import scala.collection.mutable
-
-class TransactionMarkerChannelTest {
-
-  private val metadataCache = EasyMock.createNiceMock(classOf[MetadataCache])
-  private val networkClient = EasyMock.createNiceMock(classOf[NetworkClient])
-  private val purgatory = new DelayedOperationPurgatory[DelayedTxnMarker]("name", new MockTimer, reaperEnabled = false)
-  private val listenerName = ListenerName.forSecurityProtocol(SecurityProtocol.PLAINTEXT)
-  private val channel = new TransactionMarkerChannel(listenerName, metadataCache, networkClient, new MockTime())
-  private val partition1 = new TopicPartition("topic1", 0)
-
-
-  @Test
-  def shouldAddEmptyBrokerQueueWhenAddingNewBroker(): Unit = {
-    channel.addOrUpdateBroker(new Node(1, "host", 10))
-    channel.addOrUpdateBroker(new Node(2, "host", 10))
-    assertEquals(0, channel.queueForBroker(1).get.eachMetadataPartition{case(partition:Int, _) => partition}.size)
-    assertEquals(0, channel.queueForBroker(2).get.eachMetadataPartition{case(partition:Int, _) => partition}.size)
-  }
-
-  @Test
-  def shouldUpdateDestinationBrokerNodeWhenUpdatingBroker(): Unit = {
-    val newDestination = new Node(1, "otherhost", 100)
-
-    EasyMock.expect(metadataCache.getPartitionInfo(partition1.topic(), partition1.partition()))
-      .andReturn(Some(PartitionStateInfo(LeaderIsrAndControllerEpoch(LeaderAndIsr(1, 0, List.empty, 0), 0), Set.empty)))
-
-    // getAliveEndpoint returns an updated node
-    EasyMock.expect(metadataCache.getAliveEndpoint(1, listenerName)).andReturn(Some(newDestination))
-    EasyMock.replay(metadataCache)
-
-    channel.addOrUpdateBroker(new Node(1, "host", 10))
-    channel.addRequestToSend(0, 0, 0, TransactionResult.COMMIT, 0, Set[TopicPartition](partition1))
-
-    val brokerRequestQueue = channel.queueForBroker(1).get
-    assertEquals(newDestination, brokerRequestQueue.node)
-    assertEquals(1, brokerRequestQueue.totalQueuedRequests())
-  }
-
-
-  @Test
-  def shouldQueueRequestsByBrokerId(): Unit = {
-    channel.addOrUpdateBroker(new Node(1, "host", 10))
-    channel.addOrUpdateBroker(new Node(2, "otherhost", 10))
-    channel.addRequestForBroker(1, 0, new WriteTxnMarkersRequest.TxnMarkerEntry(0, 0, 0, TransactionResult.COMMIT, Utils.mkList()))
-    channel.addRequestForBroker(1, 0, new WriteTxnMarkersRequest.TxnMarkerEntry(0, 0, 0, TransactionResult.COMMIT, Utils.mkList()))
-    channel.addRequestForBroker(2, 0, new WriteTxnMarkersRequest.TxnMarkerEntry(0, 0, 0, TransactionResult.COMMIT, Utils.mkList()))
-
-    assertEquals(2, channel.queueForBroker(1).get.totalQueuedRequests())
-    assertEquals(1, channel.queueForBroker(2).get.totalQueuedRequests())
-  }
-
-  @Test
-  def shouldNotAddPendingTxnIfOneAlreadyExistsForPid(): Unit = {
-    channel.maybeAddPendingRequest(0, new TransactionMetadata(0, 0, 0, PrepareCommit, mutable.Set.empty, 0, 0))
-    assertFalse(channel.maybeAddPendingRequest(0, new TransactionMetadata(0, 0, 0, PrepareCommit, mutable.Set.empty, 0, 0)))
-  }
-
-  @Test
-  def shouldAddRequestsToCorrectBrokerQueues(): Unit = {
-    val partition2 = new TopicPartition("topic1", 1)
-
-    EasyMock.expect(metadataCache.getPartitionInfo(partition1.topic(), partition1.partition()))
-      .andReturn(Some(PartitionStateInfo(LeaderIsrAndControllerEpoch(LeaderAndIsr(1, 0, List.empty, 0), 0), Set.empty)))
-
-    EasyMock.expect(metadataCache.getPartitionInfo(partition2.topic(), partition2.partition()))
-      .andReturn(Some(PartitionStateInfo(LeaderIsrAndControllerEpoch(LeaderAndIsr(2, 0, List.empty, 0), 0), Set.empty)))
-
-    EasyMock.expect(metadataCache.getAliveEndpoint(1, listenerName)).andReturn(Some(new Node(1, "host", 10)))
-    EasyMock.expect(metadataCache.getAliveEndpoint(2, listenerName)).andReturn(Some(new Node(2, "otherhost", 10)))
-
-    EasyMock.replay(metadataCache)
-    channel.addRequestToSend(0, 0, 0, TransactionResult.COMMIT, 0, Set[TopicPartition](partition1, partition2))
-
-    assertEquals(1, channel.queueForBroker(1).get.totalQueuedRequests())
-    assertEquals(1, channel.queueForBroker(2).get.totalQueuedRequests())
-  }
-  @Test
-  def shouldWakeupNetworkClientWhenRequestsQueued(): Unit = {
-    EasyMock.expect(metadataCache.getPartitionInfo(partition1.topic(), partition1.partition()))
-      .andReturn(Some(PartitionStateInfo(LeaderIsrAndControllerEpoch(LeaderAndIsr(1, 0, List.empty, 0), 0), Set.empty)))
-    EasyMock.expect(metadataCache.getAliveEndpoint(1, listenerName)).andReturn(Some(new Node(1, "host", 10)))
-
-    EasyMock.expect(networkClient.wakeup())
-
-    EasyMock.replay(metadataCache, networkClient)
-    channel.addRequestToSend(0, 0, 0, TransactionResult.COMMIT, 0, Set[TopicPartition](partition1))
-
-    EasyMock.verify(networkClient)
-  }
-
-  @Test
-  def shouldAddNewBrokerQueueIfDoesntAlreadyExistWhenAddingRequest(): Unit = {
-    EasyMock.expect(metadataCache.getPartitionInfo(partition1.topic(), partition1.partition()))
-      .andReturn(Some(PartitionStateInfo(LeaderIsrAndControllerEpoch(LeaderAndIsr(1, 0, List.empty, 0), 0), Set.empty)))
-    EasyMock.expect(metadataCache.getAliveEndpoint(1, listenerName)).andReturn(Some(new Node(1, "host", 10)))
-
-    EasyMock.replay(metadataCache)
-    channel.addRequestToSend(0, 0, 0, TransactionResult.COMMIT, 0, Set[TopicPartition](partition1))
-
-    assertEquals(1, channel.queueForBroker(1).get.totalQueuedRequests())
-    EasyMock.verify(metadataCache)
-  }
-
-  @Test
-  def shouldGetPendingTxnMetadataByPid(): Unit = {
-    val metadataPartition = 0
-    val transaction = new TransactionMetadata(1, 0, 0, PrepareCommit, mutable.Set.empty, 0, 0)
-    channel.maybeAddPendingRequest(metadataPartition, transaction)
-    channel.maybeAddPendingRequest(metadataPartition, new TransactionMetadata(2, 0, 0, PrepareCommit, mutable.Set.empty, 0, 0))
-    assertEquals(Some(transaction), channel.pendingTxnMetadata(metadataPartition, 1))
-  }
-
-  @Test
-  def shouldRemovePendingRequestsForPartitionWhenPartitionEmigrated(): Unit = {
-    channel.maybeAddPendingRequest(0, new TransactionMetadata(0, 0, 0, PrepareCommit, mutable.Set.empty, 0, 0))
-    channel.maybeAddPendingRequest(0, new TransactionMetadata(1, 0, 0, PrepareCommit, mutable.Set.empty, 0, 0))
-    val metadata = new TransactionMetadata(2, 0, 0, PrepareCommit, mutable.Set.empty, 0, 0)
-    channel.maybeAddPendingRequest(1, metadata)
-
-    channel.removeStateForPartition(0)
-
-    assertEquals(None, channel.pendingTxnMetadata(0, 0))
-    assertEquals(None, channel.pendingTxnMetadata(0, 1))
-    assertEquals(Some(metadata), channel.pendingTxnMetadata(1, 2))
-  }
-
-  @Test
-  def shouldRemoveBrokerRequestsForPartitionWhenPartitionEmigrated(): Unit = {
-    channel.addOrUpdateBroker(new Node(1, "host", 10))
-    channel.addRequestForBroker(1, 0, new WriteTxnMarkersRequest.TxnMarkerEntry(0, 0, 0, TransactionResult.COMMIT, Utils.mkList()))
-    channel.addRequestForBroker(1, 1, new WriteTxnMarkersRequest.TxnMarkerEntry(0, 0, 0, TransactionResult.COMMIT, Utils.mkList()))
-    channel.addRequestForBroker(1, 1, new WriteTxnMarkersRequest.TxnMarkerEntry(0, 0, 0, TransactionResult.COMMIT, Utils.mkList()))
-
-    channel.removeStateForPartition(1)
-
-
-    val result = channel.queueForBroker(1).get.eachMetadataPartition{case (partition:Int, _) => partition}.toList
-    assertEquals(List(0), result)
-  }
-
-
-
-  def errorCallback(error: Errors): Unit = {}
-}

http://git-wip-us.apache.org/repos/asf/kafka/blob/794e6dbd/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerRequestCompletionHandlerTest.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerRequestCompletionHandlerTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerRequestCompletionHandlerTest.scala
index 096b826..082d441 100644
--- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerRequestCompletionHandlerTest.scala
+++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerRequestCompletionHandlerTest.scala
@@ -18,14 +18,12 @@ package kafka.coordinator.transaction
 
 import java.{lang, util}
 
-import kafka.server.DelayedOperationPurgatory
-import kafka.utils.timer.MockTimer
 import org.apache.kafka.clients.ClientResponse
 import org.apache.kafka.common.TopicPartition
 import org.apache.kafka.common.protocol.Errors
 import org.apache.kafka.common.requests.{RequestHeader, TransactionResult, WriteTxnMarkersRequest, WriteTxnMarkersResponse}
 import org.apache.kafka.common.utils.Utils
-import org.easymock.EasyMock
+import org.easymock.{EasyMock, IAnswer}
 import org.junit.Assert._
 import org.junit.Test
 
@@ -33,71 +31,129 @@ import scala.collection.mutable
 
 class TransactionMarkerRequestCompletionHandlerTest {
 
-  private val markerChannel = EasyMock.createNiceMock(classOf[TransactionMarkerChannel])
-  private val purgatory = new DelayedOperationPurgatory[DelayedTxnMarker]("txn-purgatory-name", new MockTimer, reaperEnabled = false)
-  private val topic1 = new TopicPartition("topic1", 0)
-  private val txnMarkers =
+  private val brokerId = 0
+  private val txnTopicPartition = 0
+  private val transactionalId = "txnId1"
+  private val producerId = 0.asInstanceOf[Long]
+  private val producerEpoch = 0.asInstanceOf[Short]
+  private val txnTimeoutMs = 0
+  private val coordinatorEpoch = 0
+  private val txnResult = TransactionResult.COMMIT
+  private val topicPartition = new TopicPartition("topic1", 0)
+  private val txnIdAndMarkers =
     Utils.mkList(
-      new WriteTxnMarkersRequest.TxnMarkerEntry(0, 0, 0, TransactionResult.COMMIT, Utils.mkList(topic1)))
+      TxnIdAndMarkerEntry(transactionalId, new WriteTxnMarkersRequest.TxnMarkerEntry(producerId, producerEpoch, coordinatorEpoch, txnResult, Utils.mkList(topicPartition))))
 
-  private val handler = new TransactionMarkerRequestCompletionHandler(markerChannel, purgatory, 0, txnMarkers, 0)
+  private val txnMetadata = new TransactionMetadata(producerId, producerEpoch, txnTimeoutMs, PrepareCommit, mutable.Set[TopicPartition](topicPartition), 0L, 0L)
+
+  private val markerChannelManager = EasyMock.createNiceMock(classOf[TransactionMarkerChannelManager])
+
+  private val txnStateManager = EasyMock.createNiceMock(classOf[TransactionStateManager])
+
+  private val handler = new TransactionMarkerRequestCompletionHandler(brokerId, txnStateManager, markerChannelManager, txnIdAndMarkers)
+
+  private def mockCache(): Unit = {
+    EasyMock.expect(txnStateManager.partitionFor(transactionalId))
+      .andReturn(txnTopicPartition)
+      .anyTimes()
+    EasyMock.expect(txnStateManager.getTransactionState(transactionalId))
+      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))
+      .anyTimes()
+    EasyMock.replay(txnStateManager)
+  }
 
   @Test
   def shouldReEnqueuePartitionsWhenBrokerDisconnected(): Unit = {
-    EasyMock.expect(markerChannel.addRequestToSend(0, 0, 0, TransactionResult.COMMIT, 0, Set[TopicPartition](topic1)))
-    EasyMock.replay(markerChannel)
+    mockCache()
+
+    EasyMock.expect(markerChannelManager.addTxnMarkersToBrokerQueue(transactionalId,
+      producerId, producerEpoch, txnResult, coordinatorEpoch, Set[TopicPartition](topicPartition)))
+    EasyMock.replay(markerChannelManager)
 
     handler.onComplete(new ClientResponse(new RequestHeader(0, 0, "client", 1), null, null, 0, 0, true, null, null))
 
-    EasyMock.verify(markerChannel)
+    EasyMock.verify(markerChannelManager)
   }
 
   @Test
-  def shouldThrowIllegalStateExceptionIfErrorsNullForPid(): Unit = {
-    val response = new WriteTxnMarkersResponse(new java.util.HashMap[java.lang.Long, java.util.Map[TopicPartition, Errors]]())
+  def shouldThrowIllegalStateExceptionIfErrorCodeNotAvailableForPid(): Unit = {
+    mockCache()
+    EasyMock.replay(markerChannelManager)
 
-    EasyMock.replay(markerChannel)
+    val response = new WriteTxnMarkersResponse(new java.util.HashMap[java.lang.Long, java.util.Map[TopicPartition, Errors]]())
 
     try {
       handler.onComplete(new ClientResponse(new RequestHeader(0, 0, "client", 1), null, null, 0, 0, false, null, response))
       fail("should have thrown illegal argument exception")
     } catch {
-      case ise: IllegalStateException => // ok
+      case _: IllegalStateException => // ok
     }
   }
 
   @Test
-  def shouldRemoveCompletedPartitionsFromMetadataWhenNoErrors(): Unit = {
-    val response = new WriteTxnMarkersResponse(createPidErrorMap(Errors.NONE))
+  def shouldCompleteDelayedOperationWhenNoErrors(): Unit = {
+    mockCache()
 
-    val metadata = new TransactionMetadata(0, 0, 0, PrepareCommit, mutable.Set[TopicPartition](topic1), 0, 0)
-    EasyMock.expect(markerChannel.pendingTxnMetadata(0, 0))
-      .andReturn(Some(metadata))
-    EasyMock.replay(markerChannel)
+    verifyCompleteDelayedOperationOnError(Errors.NONE)
+  }
 
-    handler.onComplete(new ClientResponse(new RequestHeader(0, 0, "client", 1), null, null, 0, 0, false, null, response))
+  @Test
+  def shouldCompleteDelayedOperationWhenNoMetadata(): Unit = {
+    EasyMock.expect(txnStateManager.getTransactionState(transactionalId))
+      .andReturn(None)
+      .anyTimes()
+    EasyMock.replay(txnStateManager)
 
-    assertTrue(metadata.topicPartitions.isEmpty)
+    verifyRemoveDelayedOperationOnError(Errors.NONE)
   }
 
   @Test
-  def shouldTryCompleteDelayedTxnOperation(): Unit = {
-    val response = new WriteTxnMarkersResponse(createPidErrorMap(Errors.NONE))
+  def shouldCompleteDelayedOperationWhenCoordinatorEpochChanged(): Unit = {
+    EasyMock.expect(txnStateManager.getTransactionState(transactionalId))
+      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch+1, txnMetadata)))
+      .anyTimes()
+    EasyMock.replay(txnStateManager)
 
-    val metadata = new TransactionMetadata(0, 0, 0, PrepareCommit, mutable.Set[TopicPartition](topic1), 0, 0)
-    var completed = false
+    verifyRemoveDelayedOperationOnError(Errors.NONE)
+  }
 
-    purgatory.tryCompleteElseWatch(new DelayedTxnMarker(metadata, (errors:Errors) => {
-      completed = true
-    }), Seq(0L))
+  @Test
+  def shouldCompleteDelayedOperationWhenInvalidProducerEpoch(): Unit = {
+    mockCache()
 
-    EasyMock.expect(markerChannel.pendingTxnMetadata(0, 0))
-      .andReturn(Some(metadata))
+    verifyRemoveDelayedOperationOnError(Errors.INVALID_PRODUCER_EPOCH)
+  }
 
-    EasyMock.replay(markerChannel)
+  @Test
+  def shouldCompleteDelayedOperationWhenCoordinatorEpochFenced(): Unit = {
+    mockCache()
 
-    handler.onComplete(new ClientResponse(new RequestHeader(0, 0, "client", 1), null, null, 0, 0, false, null, response))
-    assertTrue(completed)
+    verifyRemoveDelayedOperationOnError(Errors.TRANSACTION_COORDINATOR_FENCED)
+  }
+
+  @Test
+  def shouldThrowIllegalStateExceptionWhenUnknownError(): Unit = {
+    verifyThrowIllegalStateExceptionOnError(Errors.UNKNOWN)
+  }
+
+  @Test
+  def shouldThrowIllegalStateExceptionWhenCorruptMessageError(): Unit = {
+    verifyThrowIllegalStateExceptionOnError(Errors.CORRUPT_MESSAGE)
+  }
+
+  @Test
+  def shouldThrowIllegalStateExceptionWhenMessageTooLargeError(): Unit = {
+    verifyThrowIllegalStateExceptionOnError(Errors.MESSAGE_TOO_LARGE)
+  }
+
+  @Test
+  def shouldThrowIllegalStateExceptionWhenRecordListTooLargeError(): Unit = {
+    verifyThrowIllegalStateExceptionOnError(Errors.RECORD_LIST_TOO_LARGE)
+  }
+
+  @Test
+  def shouldThrowIllegalStateExceptionWhenInvalidRequiredAcksError(): Unit = {
+    verifyThrowIllegalStateExceptionOnError(Errors.INVALID_REQUIRED_ACKS)
   }
 
   @Test
@@ -120,40 +176,75 @@ class TransactionMarkerRequestCompletionHandlerTest {
     verifyRetriesPartitionOnError(Errors.NOT_ENOUGH_REPLICAS_AFTER_APPEND)
   }
 
-  @Test
-  def shouldThrowIllegalStateExceptionWhenErrorNotHandled(): Unit = {
-    val response = new WriteTxnMarkersResponse(createPidErrorMap(Errors.UNKNOWN))
-    val metadata = new TransactionMetadata(0, 0, 0, PrepareCommit, mutable.Set[TopicPartition](topic1), 0, 0)
-    EasyMock.replay(markerChannel)
+  private def verifyRetriesPartitionOnError(error: Errors) = {
+    mockCache()
+
+    EasyMock.expect(markerChannelManager.addTxnMarkersToBrokerQueue(transactionalId,
+      producerId, producerEpoch, txnResult, coordinatorEpoch, Set[TopicPartition](topicPartition)))
+    EasyMock.replay(markerChannelManager)
+
+    val response = new WriteTxnMarkersResponse(createPidErrorMap(error))
+    handler.onComplete(new ClientResponse(new RequestHeader(0, 0, "client", 1), null, null, 0, 0, false, null, response))
+
+    assertEquals(txnMetadata.topicPartitions, mutable.Set[TopicPartition](topicPartition))
+    EasyMock.verify(markerChannelManager)
+  }
 
+  private def verifyThrowIllegalStateExceptionOnError(error: Errors) = {
+    mockCache()
+
+    val response = new WriteTxnMarkersResponse(createPidErrorMap(error))
     try {
       handler.onComplete(new ClientResponse(new RequestHeader(0, 0, "client", 1), null, null, 0, 0, false, null, response))
       fail("should have thrown illegal state exception")
     } catch {
-      case ise: IllegalStateException => // ol
+      case _: IllegalStateException => // ok
     }
+  }
+
+  private def verifyCompleteDelayedOperationOnError(error: Errors): Unit = {
 
+    var completed = false
+    EasyMock.expect(markerChannelManager.completeSendMarkersForTxnId(transactionalId))
+      .andAnswer(new IAnswer[Unit] {
+        override def answer(): Unit = {
+          completed = true
+        }
+      })
+      .once()
+    EasyMock.replay(markerChannelManager)
+
+    val response = new WriteTxnMarkersResponse(createPidErrorMap(error))
+    handler.onComplete(new ClientResponse(new RequestHeader(0, 0, "client", 1), null, null, 0, 0, false, null, response))
+
+    assertTrue(txnMetadata.topicPartitions.isEmpty)
+    assertTrue(completed)
   }
 
-  private def verifyRetriesPartitionOnError(errors: Errors) = {
-    val response = new WriteTxnMarkersResponse(createPidErrorMap(Errors.UNKNOWN_TOPIC_OR_PARTITION))
-    val metadata = new TransactionMetadata(0, 0, 0, PrepareCommit, mutable.Set[TopicPartition](topic1), 0, 0)
+  private def verifyRemoveDelayedOperationOnError(error: Errors): Unit = {
 
-    EasyMock.expect(markerChannel.addRequestToSend(0, 0, 0, TransactionResult.COMMIT, 0, Set[TopicPartition](topic1)))
-    EasyMock.replay(markerChannel)
+    var removed = false
+    EasyMock.expect(markerChannelManager.removeMarkersForTxnId(transactionalId))
+      .andAnswer(new IAnswer[Unit] {
+        override def answer(): Unit = {
+          removed = true
+        }
+      })
+      .once()
+    EasyMock.replay(markerChannelManager)
 
+    val response = new WriteTxnMarkersResponse(createPidErrorMap(error))
     handler.onComplete(new ClientResponse(new RequestHeader(0, 0, "client", 1), null, null, 0, 0, false, null, response))
 
-    assertEquals(metadata.topicPartitions, mutable.Set[TopicPartition](topic1))
-    EasyMock.verify(markerChannel)
+    assertTrue(removed)
   }
 
+
   private def createPidErrorMap(errors: Errors) = {
     val pidMap = new java.util.HashMap[lang.Long, util.Map[TopicPartition, Errors]]()
     val errorsMap = new util.HashMap[TopicPartition, Errors]()
-    errorsMap.put(topic1, errors)
-    pidMap.put(0L, errorsMap)
+    errorsMap.put(topicPartition, errors)
+    pidMap.put(producerId, errorsMap)
     pidMap
   }
-
 }

http://git-wip-us.apache.org/repos/asf/kafka/blob/794e6dbd/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala
index 2a14898..0250f60 100644
--- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala
+++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala
@@ -22,13 +22,14 @@ import kafka.common.Topic
 import kafka.common.Topic.TransactionStateTopicName
 import kafka.log.Log
 import kafka.server.{FetchDataInfo, LogOffsetMetadata, ReplicaManager}
-import kafka.utils.{MockScheduler, ZkUtils}
+import kafka.utils.{MockScheduler, Pool, ZkUtils}
 import kafka.utils.TestUtils.fail
 import org.apache.kafka.common.TopicPartition
 import org.apache.kafka.common.protocol.Errors
 import org.apache.kafka.common.record._
 import org.apache.kafka.common.requests.IsolationLevel
 import org.apache.kafka.common.requests.ProduceResponse.PartitionResponse
+import org.apache.kafka.common.requests.TransactionResult
 import org.apache.kafka.common.utils.MockTime
 import org.junit.Assert.{assertEquals, assertFalse, assertTrue}
 import org.junit.{After, Before, Test}
@@ -44,6 +45,7 @@ class TransactionStateManagerTest {
   val numPartitions = 2
   val transactionTimeoutMs: Int = 1000
   val topicPartition = new TopicPartition(TransactionStateTopicName, partitionId)
+  val coordinatorEpoch = 10
 
   val txnRecords: mutable.ArrayBuffer[SimpleRecord] = mutable.ArrayBuffer[SimpleRecord]()
 
@@ -95,10 +97,12 @@ class TransactionStateManagerTest {
 
   @Test
   def testAddGetPids() {
+    transactionManager.addLoadedTransactionsToCache(partitionId, coordinatorEpoch, new Pool[String, TransactionMetadata]())
+
     assertEquals(None, transactionManager.getTransactionState(txnId1))
-    assertEquals(txnMetadata1, transactionManager.addTransaction(txnId1, txnMetadata1))
-    assertEquals(Some(txnMetadata1), transactionManager.getTransactionState(txnId1))
-    assertEquals(txnMetadata1, transactionManager.addTransaction(txnId1, txnMetadata2))
+    assertEquals(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1), transactionManager.addTransaction(txnId1, txnMetadata1))
+    assertEquals(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1)), transactionManager.getTransactionState(txnId1))
+    assertEquals(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1), transactionManager.addTransaction(txnId1, txnMetadata2))
   }
 
   @Test
@@ -110,19 +114,19 @@ class TransactionStateManagerTest {
     txnMetadata1.addPartitions(Set[TopicPartition](new TopicPartition("topic1", 0),
       new TopicPartition("topic1", 1)))
 
-    txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1))
+    txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit()))
 
     // pid1's transaction adds three more partitions
     txnMetadata1.addPartitions(Set[TopicPartition](new TopicPartition("topic2", 0),
       new TopicPartition("topic2", 1),
       new TopicPartition("topic2", 2)))
 
-    txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1))
+    txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit()))
 
     // pid1's transaction is preparing to commit
     txnMetadata1.state = PrepareCommit
 
-    txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1))
+    txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit()))
 
     // pid2's transaction started with three partitions
     txnMetadata2.state = Ongoing
@@ -130,23 +134,23 @@ class TransactionStateManagerTest {
       new TopicPartition("topic3", 1),
       new TopicPartition("topic3", 2)))
 
-    txnRecords += new SimpleRecord(txnMessageKeyBytes2, TransactionLog.valueToBytes(txnMetadata2))
+    txnRecords += new SimpleRecord(txnMessageKeyBytes2, TransactionLog.valueToBytes(txnMetadata2.prepareNoTransit()))
 
     // pid2's transaction is preparing to abort
     txnMetadata2.state = PrepareAbort
 
-    txnRecords += new SimpleRecord(txnMessageKeyBytes2, TransactionLog.valueToBytes(txnMetadata2))
+    txnRecords += new SimpleRecord(txnMessageKeyBytes2, TransactionLog.valueToBytes(txnMetadata2.prepareNoTransit()))
 
     // pid2's transaction has aborted
     txnMetadata2.state = CompleteAbort
 
-    txnRecords += new SimpleRecord(txnMessageKeyBytes2, TransactionLog.valueToBytes(txnMetadata2))
+    txnRecords += new SimpleRecord(txnMessageKeyBytes2, TransactionLog.valueToBytes(txnMetadata2.prepareNoTransit()))
 
     // pid2's epoch has advanced, with no ongoing transaction yet
     txnMetadata2.state = Empty
     txnMetadata2.topicPartitions.clear()
 
-    txnRecords += new SimpleRecord(txnMessageKeyBytes2, TransactionLog.valueToBytes(txnMetadata2))
+    txnRecords += new SimpleRecord(txnMessageKeyBytes2, TransactionLog.valueToBytes(txnMetadata2.prepareNoTransit()))
 
     val startOffset = 15L   // it should work for any start offset
     val records = MemoryRecords.withRecords(startOffset, CompressionType.NONE, txnRecords: _*)
@@ -157,7 +161,7 @@ class TransactionStateManagerTest {
     assertFalse(transactionManager.isCoordinatorFor(txnId1))
     assertFalse(transactionManager.isCoordinatorFor(txnId2))
 
-    transactionManager.loadTransactionsForPartition(partitionId, 0, _ => ())
+    transactionManager.loadTransactionsForTxnTopicPartition(partitionId, 0, (_, _, _, _, _) => ())
 
     // let the time advance to trigger the background thread loading
     scheduler.tick()
@@ -166,14 +170,14 @@ class TransactionStateManagerTest {
     val cachedPidMetadata2 = transactionManager.getTransactionState(txnId2).getOrElse(fail(txnId2 + "'s transaction state was not loaded into the cache"))
 
     // they should be equal to the latest status of the transaction
-    assertEquals(txnMetadata1, cachedPidMetadata1)
-    assertEquals(txnMetadata2, cachedPidMetadata2)
+    assertEquals(txnMetadata1, cachedPidMetadata1.transactionMetadata)
+    assertEquals(txnMetadata2, cachedPidMetadata2.transactionMetadata)
 
     // this partition should now be part of the owned partitions
     assertTrue(transactionManager.isCoordinatorFor(txnId1))
     assertTrue(transactionManager.isCoordinatorFor(txnId2))
 
-    transactionManager.removeTransactionsForPartition(partitionId)
+    transactionManager.removeTransactionsForTxnTopicPartition(partitionId)
 
     // let the time advance to trigger the background thread removing
     scheduler.tick()
@@ -187,6 +191,8 @@ class TransactionStateManagerTest {
 
   @Test
   def testAppendTransactionToLog() {
+    transactionManager.addLoadedTransactionsToCache(partitionId, coordinatorEpoch, new Pool[String, TransactionMetadata]())
+
     // first insert the initial transaction metadata
     transactionManager.addTransaction(txnId1, txnMetadata1)
 
@@ -194,78 +200,73 @@ class TransactionStateManagerTest {
     expectedError = Errors.NONE
 
     // update the metadata to ongoing with two partitions
-    val newMetadata = txnMetadata1.copy()
-    newMetadata.state = Ongoing
-    newMetadata.addPartitions(Set[TopicPartition](new TopicPartition("topic1", 0),
-      new TopicPartition("topic1", 1)))
-    txnMetadata1.prepareTransitionTo(Ongoing)
+    val newMetadata = txnMetadata1.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic1", 0),
+      new TopicPartition("topic1", 1)), time.milliseconds())
 
     // append the new metadata into log
-    transactionManager.appendTransactionToLog(txnId1, newMetadata, assertCallback)
+    transactionManager.appendTransactionToLog(txnId1, coordinatorEpoch, newMetadata, assertCallback)
 
-    assertEquals(Some(newMetadata), transactionManager.getTransactionState(txnId1))
+    assertEquals(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1)), transactionManager.getTransactionState(txnId1))
 
     // append to log again with expected failures
-    val failedMetadata = newMetadata.copy()
-    failedMetadata.addPartitions(Set[TopicPartition](new TopicPartition("topic2", 0)))
+    val failedMetadata = txnMetadata1.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic2", 0)), time.milliseconds())
 
     // test COORDINATOR_NOT_AVAILABLE cases
     expectedError = Errors.COORDINATOR_NOT_AVAILABLE
 
     prepareForTxnMessageAppend(Errors.UNKNOWN_TOPIC_OR_PARTITION)
-    transactionManager.appendTransactionToLog(txnId1, failedMetadata, assertCallback)
-    assertEquals(Some(newMetadata), transactionManager.getTransactionState(txnId1))
+    transactionManager.appendTransactionToLog(txnId1, coordinatorEpoch = 10, failedMetadata, assertCallback)
+    assertEquals(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1)), transactionManager.getTransactionState(txnId1))
 
     prepareForTxnMessageAppend(Errors.NOT_ENOUGH_REPLICAS)
-    transactionManager.appendTransactionToLog(txnId1, failedMetadata, assertCallback)
-    assertEquals(Some(newMetadata), transactionManager.getTransactionState(txnId1))
+    transactionManager.appendTransactionToLog(txnId1, coordinatorEpoch = 10, failedMetadata, assertCallback)
+    assertEquals(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1)), transactionManager.getTransactionState(txnId1))
 
     prepareForTxnMessageAppend(Errors.NOT_ENOUGH_REPLICAS_AFTER_APPEND)
-    transactionManager.appendTransactionToLog(txnId1, failedMetadata, assertCallback)
-    assertEquals(Some(newMetadata), transactionManager.getTransactionState(txnId1))
+    transactionManager.appendTransactionToLog(txnId1, coordinatorEpoch = 10, failedMetadata, assertCallback)
+    assertEquals(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1)), transactionManager.getTransactionState(txnId1))
 
     prepareForTxnMessageAppend(Errors.REQUEST_TIMED_OUT)
-    transactionManager.appendTransactionToLog(txnId1, failedMetadata, assertCallback)
-    assertEquals(Some(newMetadata), transactionManager.getTransactionState(txnId1))
+    transactionManager.appendTransactionToLog(txnId1, coordinatorEpoch = 10, failedMetadata, assertCallback)
+    assertEquals(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1)), transactionManager.getTransactionState(txnId1))
 
     // test NOT_COORDINATOR cases
     expectedError = Errors.NOT_COORDINATOR
 
     prepareForTxnMessageAppend(Errors.NOT_LEADER_FOR_PARTITION)
-    transactionManager.appendTransactionToLog(txnId1, failedMetadata, assertCallback)
-    assertEquals(Some(newMetadata), transactionManager.getTransactionState(txnId1))
+    transactionManager.appendTransactionToLog(txnId1, coordinatorEpoch = 10, failedMetadata, assertCallback)
+    assertEquals(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1)), transactionManager.getTransactionState(txnId1))
 
     // test NOT_COORDINATOR cases
     expectedError = Errors.UNKNOWN
 
     prepareForTxnMessageAppend(Errors.MESSAGE_TOO_LARGE)
-    transactionManager.appendTransactionToLog(txnId1, failedMetadata, assertCallback)
-    assertEquals(Some(newMetadata), transactionManager.getTransactionState(txnId1))
+    transactionManager.appendTransactionToLog(txnId1, coordinatorEpoch = 10, failedMetadata, assertCallback)
+    assertEquals(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1)), transactionManager.getTransactionState(txnId1))
 
     prepareForTxnMessageAppend(Errors.RECORD_LIST_TOO_LARGE)
-    transactionManager.appendTransactionToLog(txnId1, failedMetadata, assertCallback)
-    assertEquals(Some(newMetadata), transactionManager.getTransactionState(txnId1))
+    transactionManager.appendTransactionToLog(txnId1, coordinatorEpoch = 10, failedMetadata, assertCallback)
+    assertEquals(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1)), transactionManager.getTransactionState(txnId1))
   }
 
-  @Test(expected = classOf[IllegalStateException])
+  @Test
   def testAppendTransactionToLogWhileProducerFenced() = {
+    transactionManager.addLoadedTransactionsToCache(partitionId, 0, new Pool[String, TransactionMetadata]())
+
     // first insert the initial transaction metadata
     transactionManager.addTransaction(txnId1, txnMetadata1)
 
     prepareForTxnMessageAppend(Errors.NONE)
-    expectedError = Errors.INVALID_PRODUCER_EPOCH
+    expectedError = Errors.NOT_COORDINATOR
 
-    val newMetadata = txnMetadata1.copy()
-    newMetadata.state = Ongoing
-    newMetadata.addPartitions(Set[TopicPartition](new TopicPartition("topic1", 0),
-      new TopicPartition("topic1", 1)))
-    txnMetadata1.prepareTransitionTo(Ongoing)
+    val newMetadata = txnMetadata1.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic1", 0),
+      new TopicPartition("topic1", 1)), time.milliseconds())
 
     // modify the cache while trying to append the new metadata
     txnMetadata1.producerEpoch = (txnMetadata1.producerEpoch + 1).toShort
 
     // append the new metadata into log
-    transactionManager.appendTransactionToLog(txnId1, newMetadata, assertCallback)
+    transactionManager.appendTransactionToLog(txnId1, coordinatorEpoch = 10, newMetadata, assertCallback)
   }
 
   @Test(expected = classOf[IllegalStateException])
@@ -276,38 +277,29 @@ class TransactionStateManagerTest {
     prepareForTxnMessageAppend(Errors.NONE)
     expectedError = Errors.INVALID_PRODUCER_EPOCH
 
-    val newMetadata = txnMetadata1.copy()
-    newMetadata.state = Ongoing
-    newMetadata.addPartitions(Set[TopicPartition](new TopicPartition("topic1", 0),
-      new TopicPartition("topic1", 1)))
-    txnMetadata1.prepareTransitionTo(Ongoing)
+    val newMetadata = txnMetadata1.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic1", 0),
+      new TopicPartition("topic1", 1)), time.milliseconds())
 
     // modify the cache while trying to append the new metadata
     txnMetadata1.pendingState = None
 
     // append the new metadata into log
-    transactionManager.appendTransactionToLog(txnId1, newMetadata, assertCallback)
-  }
-
-  @Test
-  def shouldReturnEpochForTransactionId(): Unit = {
-    val coordinatorEpoch = 10
-    EasyMock.expect(replicaManager.getLog(EasyMock.anyObject(classOf[TopicPartition]))).andReturn(None)
-    EasyMock.replay(replicaManager)
-    transactionManager.loadTransactionsForPartition(partitionId, coordinatorEpoch, _ => ())
-    val epoch = transactionManager.coordinatorEpochFor(txnId1).get
-    assertEquals(coordinatorEpoch, epoch)
+    transactionManager.appendTransactionToLog(txnId1, coordinatorEpoch = 10, newMetadata, assertCallback)
   }
 
   @Test
   def shouldReturnNoneIfTransactionIdPartitionNotOwned(): Unit = {
-    assertEquals(None, transactionManager.coordinatorEpochFor(txnId1))
+    assertEquals(None, transactionManager.getTransactionState(txnId1))
   }
 
   @Test
   def shouldOnlyConsiderTransactionsInTheOngoingStateForExpiry(): Unit = {
+    for (partitionId <- 0 until numPartitions) {
+      transactionManager.addLoadedTransactionsToCache(partitionId, 0, new Pool[String, TransactionMetadata]())
+    }
+
     txnMetadata1.state = Ongoing
-    txnMetadata1.transactionStartTime = time.milliseconds()
+    txnMetadata1.txnStartTimestamp = time.milliseconds()
     transactionManager.addTransaction(txnId1, txnMetadata1)
     transactionManager.addTransaction(txnId2, txnMetadata2)
 
@@ -333,7 +325,7 @@ class TransactionStateManagerTest {
 
     time.sleep(2000)
     val expiring = transactionManager.transactionsToExpire()
-    assertEquals(List(TransactionalIdAndMetadata(txnId1, txnMetadata1)), expiring)
+    assertEquals(List(TransactionalIdAndProducerIdEpoch(txnId1, txnMetadata1.producerId, txnMetadata1.producerEpoch)), expiring)
   }
 
   @Test
@@ -351,17 +343,25 @@ class TransactionStateManagerTest {
     txnMetadata1.addPartitions(Set[TopicPartition](new TopicPartition("topic1", 0),
       new TopicPartition("topic1", 1)))
 
-    txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1))
+    txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit()))
     val startOffset = 0L
     val records = MemoryRecords.withRecords(startOffset, CompressionType.NONE, txnRecords: _*)
 
     prepareTxnLog(topicPartition, 0, records)
 
-    var receivedArgs: WriteTxnMarkerArgs = null
-    transactionManager.loadTransactionsForPartition(partitionId, 0, markerArgs => receivedArgs = markerArgs)
+    var txnId: String = null
+    def rememberTxnMarkers(transactionalId: String,
+                           coordinatorEpoch: Int,
+                           command: TransactionResult,
+                           metadata: TransactionMetadata,
+                           newMetadata: TransactionMetadataTransition): Unit = {
+      txnId = transactionalId
+    }
+
+    transactionManager.loadTransactionsForTxnTopicPartition(partitionId, 0, rememberTxnMarkers)
     scheduler.tick()
 
-    assertEquals(txnId1, receivedArgs.transactionalId)
+    assertEquals(txnId1, txnId)
   }
 
   private def assertCallback(error: Errors): Unit = {
@@ -414,5 +414,4 @@ class TransactionStateManagerTest {
 
     EasyMock.replay(replicaManager)
   }
-
 }


[3/3] kafka git commit: KAFKA-5130: Refactor transaction coordinator's in-memory cache; plus fixes on transaction metadata synchronization

Posted by gu...@apache.org.
KAFKA-5130: Refactor transaction coordinator's in-memory cache; plus fixes on transaction metadata synchronization

1. Collapsed the `ownedPartitions`, `pendingTxnMap` and the `transactionMetadataCache` into a single in-memory structure, which is a two-layered map: first keyed by the txn log partition id, and then valued with that partition's current coordinatorEpoch plus another map keyed by the transactional id (a rough sketch follows).
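
A rough, self-contained sketch of the two-layered layout in (1); the type and method names below are illustrative stand-ins, not the actual fields of `TransactionStateManager`:

    import scala.collection.concurrent.TrieMap
    import scala.collection.mutable

    // outer key: txn log partition id; value: that partition's coordinator epoch plus
    // the per-transactional-id metadata map (illustrative names)
    final case class TxnMetadataCacheEntry[M](coordinatorEpoch: Int,
                                              metadataPerTransactionalId: mutable.Map[String, M])

    final class TwoLayeredTxnCache[M] {
      private val cache = TrieMap.empty[Int, TxnMetadataCacheEntry[M]]

      // partition immigration: install the (possibly bumped) coordinator epoch and an empty map
      def addLoadedPartition(txnLogPartitionId: Int, coordinatorEpoch: Int): Unit =
        cache.put(txnLogPartitionId, TxnMetadataCacheEntry(coordinatorEpoch, mutable.Map.empty[String, M]))

      // partition emigration: drop the epoch and all of its metadata in one step
      def removePartition(txnLogPartitionId: Int): Unit =
        cache.remove(txnLogPartitionId)

      // lookups return the coordinator epoch together with the metadata so callers can
      // re-validate the epoch before acting on the metadata
      def get(txnLogPartitionId: Int, transactionalId: String): Option[(Int, M)] =
        cache.get(txnLogPartitionId).flatMap { entry =>
          entry.metadataPerTransactionalId.get(transactionalId).map(m => (entry.coordinatorEpoch, m))
        }
    }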

2. Use `transactionalId` across the modules of the transaction coordinator, attaching this id to the transaction marker entries.

3. Use two keys, `transactionalId` and `txnLogPartitionId`, in the writeMarkerPurgatory, and pass the `transactionalId` along with the TxnMarkerEntry, so that `TransactionMarkerRequestCompletionHandler` can use it to access the two-layered map upon getting responses (sketched below).
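
A hedged illustration of (3): pairing the transactional id with each marker entry (loosely mirroring `TxnIdAndMarkerEntry` in this patch) lets the response handler re-derive the txn log partition and look the metadata up again in the two-layered cache. Everything below, including the partitioning function, is illustrative rather than the coordinator's actual API:

    // illustrative pairing of a transactional id with its marker entry
    final case class TxnIdAndMarker[E](transactionalId: String, markerEntry: E)

    // illustrative stand-in for partitionFor(transactionalId)
    def txnLogPartitionFor(transactionalId: String, numTxnLogPartitions: Int): Int =
      (transactionalId.hashCode & 0x7fffffff) % numTxnLogPartitions

    // on a broker response, recover (coordinatorEpoch, metadata) for the marker's owner,
    // if the corresponding txn log partition is still hosted by this coordinator
    def onMarkerResponse[E, M](entry: TxnIdAndMarker[E],
                               numTxnLogPartitions: Int,
                               lookup: (Int, String) => Option[(Int, M)]): Option[(Int, M)] =
      lookup(txnLogPartitionFor(entry.transactionalId, numTxnLogPartitions), entry.transactionalId)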

4. Use one queue per `broker-id` and `txnLogPartitionId` (sketched below). When the end point associated with the `broker-id` may have been updated, update the Node without clearing the queue and rely on the requests to retry in the next round.
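
A minimal sketch of the per-(`broker-id`, `txnLogPartitionId`) queueing in (4), with the entry type left abstract and all names illustrative:

    import java.util.concurrent.{BlockingQueue, LinkedBlockingQueue}
    import scala.collection.concurrent.TrieMap

    final class PerBrokerMarkerQueues[E] {
      // broker id -> (txn log partition id -> pending marker entries)
      private val queuesPerBroker = TrieMap.empty[Int, TrieMap[Int, BlockingQueue[E]]]

      def enqueue(brokerId: Int, txnLogPartitionId: Int, entry: E): Unit = {
        val perPartition = queuesPerBroker.getOrElseUpdate(brokerId, TrieMap.empty[Int, BlockingQueue[E]])
        perPartition.getOrElseUpdate(txnLogPartitionId, new LinkedBlockingQueue[E]()).add(entry)
      }

      // emigration of a txn log partition drops only that partition's pending markers,
      // leaving the rest of each broker's queues untouched
      def removeTxnLogPartition(txnLogPartitionId: Int): Unit =
        queuesPerBroker.values.foreach(_.remove(txnLogPartitionId))
    }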

5. Centralize the error handling callback for appending-metadata-to-log and sending-markers-to-brokers in `TransactionStateManager#appendTransactionToLog`, and `TransactionMarkerChannelManager#addTxnMarkersToSend`.

6. Always update the in-memory transaction metadata AFTER the txn log entry has been appended and replicated, double checking against the cache that nothing has changed since the log append (a toy model follows). The only exception is when initializing the pid for the first time, in which case we put a dummy entry into the cache but set its pendingState to `Empty` (so it is valid to transition from `Empty` to `Empty`), allowing it to be updated after the log append has completed.
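
A toy model of the ordering in (6), with all names illustrative rather than the coordinator's actual API: prepare the transition under the metadata lock, append to the txn log, and only update the cached metadata after re-validating that the coordinator epoch and producer id/epoch are unchanged:

    final case class PendingTransition(producerId: Long, producerEpoch: Short, newState: String)

    final class TxnMetadataModel(var producerId: Long, var producerEpoch: Short, var state: String) {
      def prepareTransition(newState: String): PendingTransition =
        PendingTransition(producerId, producerEpoch, newState)
      def completeTransition(t: PendingTransition): Unit = state = t.newState
    }

    def appendThenUpdate(currentCoordinatorEpoch: () => Option[Int],
                         expectedCoordinatorEpoch: Int,
                         metadata: TxnMetadataModel,
                         appendToLog: PendingTransition => Boolean): Boolean = {
      val transition = metadata.synchronized(metadata.prepareTransition("Ongoing"))
      if (!appendToLog(transition)) {
        false                                      // append/replication failed: cache untouched
      } else metadata.synchronized {
        currentCoordinatorEpoch() match {
          case Some(epoch) if epoch == expectedCoordinatorEpoch &&
              metadata.producerId == transition.producerId &&
              metadata.producerEpoch == transition.producerEpoch =>
            metadata.completeTransition(transition)
            true                                   // nothing changed underneath: safe to update
          case _ =>
            false                                  // epoch changed or producer fenced: give up
        }
      }
    }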

Author: Guozhang Wang <wa...@gmail.com>

Reviewers: Ismael Juma, Damian Guy, Jason Gustafson, Jun Rao

Closes #2964 from guozhangwang/K5130-refactor-tc-inmemory-cache


Project: http://git-wip-us.apache.org/repos/asf/kafka/repo
Commit: http://git-wip-us.apache.org/repos/asf/kafka/commit/794e6dbd
Tree: http://git-wip-us.apache.org/repos/asf/kafka/tree/794e6dbd
Diff: http://git-wip-us.apache.org/repos/asf/kafka/diff/794e6dbd

Branch: refs/heads/trunk
Commit: 794e6dbd14f040d21d3402c5eda22cfa8f5c4b3d
Parents: 7baa58d
Author: Guozhang Wang <wa...@gmail.com>
Authored: Fri May 12 15:01:01 2017 -0700
Committer: Guozhang Wang <wa...@gmail.com>
Committed: Fri May 12 15:01:01 2017 -0700

----------------------------------------------------------------------
 .../requests/WriteTxnMarkersResponse.java       |   2 +
 .../kafka/common/InterBrokerSendThread.scala    |   3 +
 .../transaction/TransactionCoordinator.scala    | 525 ++++++++++---------
 .../transaction/TransactionLog.scala            |  14 +-
 .../transaction/TransactionMarkerChannel.scala  | 186 -------
 .../TransactionMarkerChannelManager.scala       | 243 +++++++--
 ...nsactionMarkerRequestCompletionHandler.scala | 130 +++--
 .../transaction/TransactionMetadata.scala       | 200 +++++--
 .../transaction/TransactionStateManager.scala   | 385 +++++++-------
 .../scala/kafka/server/DelayedOperation.scala   |   3 +-
 .../main/scala/kafka/server/KafkaConfig.scala   |   8 +-
 .../main/scala/kafka/server/MetadataCache.scala |  18 +
 .../scala/kafka/server/ReplicaManager.scala     |   6 +-
 .../TransactionCoordinatorIntegrationTest.scala |   6 +
 .../TransactionCoordinatorTest.scala            | 380 +++++---------
 .../transaction/TransactionLogTest.scala        |   8 +-
 .../TransactionMarkerChannelManagerTest.scala   | 304 +++++------
 .../TransactionMarkerChannelTest.scala          | 179 -------
 ...tionMarkerRequestCompletionHandlerTest.scala | 195 +++++--
 .../TransactionStateManagerTest.scala           | 139 +++--
 20 files changed, 1430 insertions(+), 1504 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/kafka/blob/794e6dbd/clients/src/main/java/org/apache/kafka/common/requests/WriteTxnMarkersResponse.java
----------------------------------------------------------------------
diff --git a/clients/src/main/java/org/apache/kafka/common/requests/WriteTxnMarkersResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/WriteTxnMarkersResponse.java
index 916dbab..00133a6 100644
--- a/clients/src/main/java/org/apache/kafka/common/requests/WriteTxnMarkersResponse.java
+++ b/clients/src/main/java/org/apache/kafka/common/requests/WriteTxnMarkersResponse.java
@@ -46,6 +46,8 @@ public class WriteTxnMarkersResponse extends AbstractResponse {
     //   NotEnoughReplicas
     //   NotEnoughReplicasAfterAppend
     //   InvalidRequiredAcks
+    //   TransactionCoordinatorFenced
+    //   RequestTimeout
 
     private final Map<Long, Map<TopicPartition, Errors>> errors;
 

http://git-wip-us.apache.org/repos/asf/kafka/blob/794e6dbd/core/src/main/scala/kafka/common/InterBrokerSendThread.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/kafka/common/InterBrokerSendThread.scala b/core/src/main/scala/kafka/common/InterBrokerSendThread.scala
index 217aa80..ac14243 100644
--- a/core/src/main/scala/kafka/common/InterBrokerSendThread.scala
+++ b/core/src/main/scala/kafka/common/InterBrokerSendThread.scala
@@ -33,6 +33,9 @@ class InterBrokerSendThread(name: String,
                             isInterruptible: Boolean = true)
   extends ShutdownableThread(name, isInterruptible) {
 
+  // visible for testing
+  def generateRequests(): Iterable[RequestAndCompletionHandler] = requestGenerator()
+
   override def doWork() {
     val now = time.milliseconds()
     var pollTimeout = Long.MaxValue

http://git-wip-us.apache.org/repos/asf/kafka/blob/794e6dbd/core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala b/core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala
index 982e009..7632f3f 100644
--- a/core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala
+++ b/core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala
@@ -48,19 +48,19 @@ object TransactionCoordinator {
       config.transactionTransactionsExpiredTransactionCleanupIntervalMs)
 
     val pidManager = new ProducerIdManager(config.brokerId, zkUtils)
-    val logManager = new TransactionStateManager(config.brokerId, zkUtils, scheduler, replicaManager, txnConfig, time)
-    val txnMarkerPurgatory = DelayedOperationPurgatory[DelayedTxnMarker]("txn-marker-purgatory", config.brokerId)
-    val transactionMarkerChannelManager = TransactionMarkerChannelManager(config, metrics, metadataCache, txnMarkerPurgatory, time)
+    val txnStateManager = new TransactionStateManager(config.brokerId, zkUtils, scheduler, replicaManager, txnConfig, time)
+    val txnMarkerPurgatory = DelayedOperationPurgatory[DelayedTxnMarker]("txn-marker-purgatory", config.brokerId, reaperEnabled = false)
+    val txnMarkerChannelManager = TransactionMarkerChannelManager(config, metrics, metadataCache, txnStateManager, txnMarkerPurgatory, time)
 
-    new TransactionCoordinator(config.brokerId, pidManager, logManager, transactionMarkerChannelManager, txnMarkerPurgatory, scheduler, time)
+    new TransactionCoordinator(config.brokerId, scheduler, pidManager, txnStateManager, txnMarkerChannelManager, txnMarkerPurgatory, time)
   }
 
   private def initTransactionError(error: Errors): InitPidResult = {
     InitPidResult(RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_EPOCH, error)
   }
 
-  private def initTransactionMetadata(txnMetadata: TransactionMetadata): InitPidResult = {
-    InitPidResult(txnMetadata.pid, txnMetadata.producerEpoch, Errors.NONE)
+  private def initTransactionMetadata(txnMetadata: TransactionMetadataTransition): InitPidResult = {
+    InitPidResult(txnMetadata.producerId, txnMetadata.producerEpoch, Errors.NONE)
   }
 }
 
@@ -73,18 +73,18 @@ object TransactionCoordinator {
  * Producers with no specific transactional id may talk to a random broker as their coordinators.
  */
 class TransactionCoordinator(brokerId: Int,
+                             scheduler: Scheduler,
                              pidManager: ProducerIdManager,
                              txnManager: TransactionStateManager,
                              txnMarkerChannelManager: TransactionMarkerChannelManager,
                              txnMarkerPurgatory: DelayedOperationPurgatory[DelayedTxnMarker],
-                             scheduler: Scheduler,
                              time: Time) extends Logging {
   this.logIdent = "[Transaction Coordinator " + brokerId + "]: "
 
   import TransactionCoordinator._
 
   type InitPidCallback = InitPidResult => Unit
-  type TxnMetadataUpdateCallback = Errors => Unit
+  type AddPartitionsCallback = Errors => Unit
   type EndTxnCallback = Errors => Unit
 
   /* Active flag of the coordinator */
@@ -93,90 +93,113 @@ class TransactionCoordinator(brokerId: Int,
   def handleInitPid(transactionalId: String,
                     transactionTimeoutMs: Int,
                     responseCallback: InitPidCallback): Unit = {
-      if (transactionalId == null || transactionalId.isEmpty) {
-        // if the transactional id is not specified, then always blindly accept the request
-        // and return a new pid from the pid manager
-        val pid = pidManager.nextPid()
-        responseCallback(InitPidResult(pid, epoch = 0, Errors.NONE))
-      } else if (!txnManager.isCoordinatorFor(transactionalId)) {
-        // check if it is the assigned coordinator for the transactional id
-        responseCallback(initTransactionError(Errors.NOT_COORDINATOR))
-      } else if (txnManager.isCoordinatorLoadingInProgress(transactionalId)) {
-        responseCallback(initTransactionError(Errors.COORDINATOR_LOAD_IN_PROGRESS))
-      } else if (!txnManager.validateTransactionTimeoutMs(transactionTimeoutMs)) {
-        // check transactionTimeoutMs is not larger than the broker configured maximum allowed value
-        responseCallback(initTransactionError(Errors.INVALID_TRANSACTION_TIMEOUT))
-      } else {
-        // only try to get a new pid and update the cache if the transactional id is unknown
-        txnManager.getTransactionState(transactionalId) match {
-          case None =>
-            val pid = pidManager.nextPid()
-            val newMetadata: TransactionMetadata = new TransactionMetadata(pid = pid,
-              producerEpoch = 0,
-              txnTimeoutMs = transactionTimeoutMs,
-              state = Empty,
-              topicPartitions = collection.mutable.Set.empty[TopicPartition],
-              lastUpdateTimestamp = time.milliseconds())
-
-            val metadata = txnManager.addTransaction(transactionalId, newMetadata)
-
-            // there might be a concurrent thread that has just updated the mapping
-            // with the transactional id at the same time; in this case we will
-            // treat it as the metadata has existed and update it accordingly
-            metadata synchronized {
-              if (!metadata.eq(newMetadata))
-                initPidWithExistingMetadata(transactionalId, transactionTimeoutMs, responseCallback, metadata)
-              else
-                appendMetadataToLog(transactionalId, metadata, responseCallback)
 
+    if (transactionalId == null || transactionalId.isEmpty) {
+      // if the transactional id is not specified, then always blindly accept the request
+      // and return a new pid from the pid manager
+      val pid = pidManager.nextPid()
+      responseCallback(InitPidResult(pid, epoch = 0, Errors.NONE))
+    } else if (!txnManager.isCoordinatorFor(transactionalId)) {
+      // check if it is the assigned coordinator for the transactional id
+      responseCallback(initTransactionError(Errors.NOT_COORDINATOR))
+    } else if (txnManager.isCoordinatorLoadingInProgress(transactionalId)) {
+      responseCallback(initTransactionError(Errors.COORDINATOR_LOAD_IN_PROGRESS))
+    } else if (!txnManager.validateTransactionTimeoutMs(transactionTimeoutMs)) {
+      // check transactionTimeoutMs is not larger than the broker configured maximum allowed value
+      responseCallback(initTransactionError(Errors.INVALID_TRANSACTION_TIMEOUT))
+    } else {
+      // only try to get a new pid and update the cache if the transactional id is unknown
+      val result: Either[InitPidResult, (Int, TransactionMetadataTransition)] = txnManager.getTransactionState(transactionalId) match {
+        case None =>
+          val pid = pidManager.nextPid()
+          val now = time.milliseconds()
+          val createdMetadata = new TransactionMetadata(producerId = pid,
+            producerEpoch = 0,
+            txnTimeoutMs = transactionTimeoutMs,
+            state = Empty,
+            topicPartitions = collection.mutable.Set.empty[TopicPartition],
+            txnLastUpdateTimestamp = now)
+
+          val epochAndMetadata = txnManager.addTransaction(transactionalId, createdMetadata)
+          val coordinatorEpoch = epochAndMetadata.coordinatorEpoch
+          val txnMetadata = epochAndMetadata.transactionMetadata
+
+          // there might be a concurrent thread that has just updated the mapping
+          // with the transactional id at the same time (hence reference equality will fail);
+          // in this case we will treat it as the metadata has existed already
+          txnMetadata synchronized {
+            if (!txnMetadata.eq(createdMetadata)) {
+              initPidWithExistingMetadata(transactionalId, transactionTimeoutMs, coordinatorEpoch, txnMetadata)
+            } else {
+              Right(coordinatorEpoch, txnMetadata.prepareNewPid(time.milliseconds()))
             }
-          case Some(metadata) =>
-            initPidWithExistingMetadata(transactionalId, transactionTimeoutMs, responseCallback, metadata)
-        }
+          }
+
+        case Some(existingEpochAndMetadata) =>
+          val coordinatorEpoch = existingEpochAndMetadata.coordinatorEpoch
+          val txnMetadata = existingEpochAndMetadata.transactionMetadata
+
+          txnMetadata synchronized {
+            initPidWithExistingMetadata(transactionalId, transactionTimeoutMs, coordinatorEpoch, txnMetadata)
+          }
       }
-  }
 
-  private def appendMetadataToLog(transactionalId: String,
-                             metadata: TransactionMetadata,
-                             initPidCallback: InitPidCallback): Unit ={
-    def callback(errors: Errors): Unit = {
-      if (errors == Errors.NONE)
-        initPidCallback(initTransactionMetadata(metadata))
-      else
-        initPidCallback(initTransactionError(errors))
+      result match {
+        case Left(pidResult) =>
+          responseCallback(pidResult)
+
+        case Right((coordinatorEpoch, newMetadata)) =>
+          if (newMetadata.txnState == Ongoing) {
+            // abort the ongoing transaction and then return CONCURRENT_TRANSACTIONS to let client wait and retry
+            def sendRetriableErrorCallback(error: Errors): Unit = {
+              if (error != Errors.NONE) {
+                responseCallback(initTransactionError(error))
+              } else {
+                responseCallback(initTransactionError(Errors.CONCURRENT_TRANSACTIONS))
+              }
+            }
+
+            handleEndTransaction(transactionalId,
+              newMetadata.producerId,
+              newMetadata.producerEpoch,
+              TransactionResult.ABORT,
+              sendRetriableErrorCallback)
+          } else {
+            def sendPidResponseCallback(error: Errors): Unit = {
+              if (error == Errors.NONE)
+                responseCallback(initTransactionMetadata(newMetadata))
+              else
+                responseCallback(initTransactionError(error))
+            }
+
+            txnManager.appendTransactionToLog(transactionalId, coordinatorEpoch, newMetadata, sendPidResponseCallback)
+          }
+      }
     }
-    txnManager.appendTransactionToLog(transactionalId, metadata, callback)
   }
 
-
   private def initPidWithExistingMetadata(transactionalId: String,
                                           transactionTimeoutMs: Int,
-                                          responseCallback: InitPidCallback,
-                                          metadata: TransactionMetadata) = {
-
-    metadata synchronized {
-      if (metadata.state == Ongoing) {
-        // abort the ongoing transaction
-        handleEndTransaction(transactionalId,
-          metadata.pid,
-          metadata.producerEpoch,
-          TransactionResult.ABORT,
-          (errors: Errors) => {
-            if (errors != Errors.NONE) {
-              responseCallback(initTransactionError(errors))
-            } else {
-              responseCallback(initTransactionError(Errors.CONCURRENT_TRANSACTIONS))
-            }
-          })
-      } else if (metadata.state == PrepareAbort || metadata.state == PrepareCommit) {
-        responseCallback(initTransactionError(Errors.CONCURRENT_TRANSACTIONS))
-      } else {
-        metadata.producerEpoch = (metadata.producerEpoch + 1).toShort
-        metadata.txnTimeoutMs = transactionTimeoutMs
-        metadata.topicPartitions.clear()
-        metadata.lastUpdateTimestamp = time.milliseconds()
-        metadata.state = Empty
-        appendMetadataToLog(transactionalId, metadata, responseCallback)
+                                          coordinatorEpoch: Int,
+                                          txnMetadata: TransactionMetadata): Either[InitPidResult, (Int, TransactionMetadataTransition)] = {
+
+    if (txnMetadata.pendingTransitionInProgress) {
+      // return a retriable exception to let the client backoff and retry
+      Left(initTransactionError(Errors.CONCURRENT_TRANSACTIONS))
+    } else {
+      // caller should have synchronized on txnMetadata already
+      txnMetadata.state match {
+        case PrepareAbort | PrepareCommit =>
+          // reply to client and let client backoff and retry
+          Left(initTransactionError(Errors.CONCURRENT_TRANSACTIONS))
+
+        case CompleteAbort | CompleteCommit | Empty =>
+          // try to append and then update
+          Right(coordinatorEpoch, txnMetadata.prepareIncrementProducerEpoch(transactionTimeoutMs, time.milliseconds()))
+
+        case Ongoing =>
+          // indicate to abort the current ongoing txn first
+          Right(coordinatorEpoch, txnMetadata.prepareNoTransit())
       }
     }
   }
@@ -196,189 +219,192 @@ class TransactionCoordinator(brokerId: Int,
                                        pid: Long,
                                        epoch: Short,
                                        partitions: collection.Set[TopicPartition],
-                                       responseCallback: TxnMetadataUpdateCallback): Unit = {
-    val errors = validateTransactionalId(transactionalId)
-    if (errors != Errors.NONE)
-      responseCallback(errors)
-    else {
+                                       responseCallback: AddPartitionsCallback): Unit = {
+    val error = validateTransactionalId(transactionalId)
+    if (error != Errors.NONE) {
+      responseCallback(error)
+    } else {
       // try to update the transaction metadata and append the updated metadata to txn log;
       // if there is no such metadata treat it as invalid pid mapping error.
-      val (error, newMetadata) = txnManager.getTransactionState(transactionalId) match {
+      val result: Either[Errors, (Int, TransactionMetadataTransition)] = txnManager.getTransactionState(transactionalId) match {
         case None =>
-          (Errors.INVALID_PID_MAPPING, null)
+          Left(Errors.INVALID_PID_MAPPING)
+
+        case Some(epochAndMetadata) =>
+          val coordinatorEpoch = epochAndMetadata.coordinatorEpoch
+          val txnMetadata = epochAndMetadata.transactionMetadata
 
-        case Some(metadata) =>
           // generate the new transaction metadata with added partitions
-          metadata synchronized {
-            if (metadata.pid != pid) {
-              (Errors.INVALID_PID_MAPPING, null)
-            } else if (metadata.producerEpoch != epoch) {
-              (Errors.INVALID_PRODUCER_EPOCH, null)
-            } else if (metadata.pendingState.isDefined) {
+          txnMetadata synchronized {
+            if (txnMetadata.producerId != pid) {
+              Left(Errors.INVALID_PID_MAPPING)
+            } else if (txnMetadata.producerEpoch != epoch) {
+              Left(Errors.INVALID_PRODUCER_EPOCH)
+            } else if (txnMetadata.pendingTransitionInProgress) {
               // return a retriable exception to let the client backoff and retry
-              (Errors.CONCURRENT_TRANSACTIONS, null)
-            } else if (metadata.state == PrepareCommit || metadata.state == PrepareAbort) {
-              (Errors.CONCURRENT_TRANSACTIONS, null)
+              Left(Errors.CONCURRENT_TRANSACTIONS)
+            } else if (txnMetadata.state == PrepareCommit || txnMetadata.state == PrepareAbort) {
+              Left(Errors.CONCURRENT_TRANSACTIONS)
+            } else if (partitions.subsetOf(txnMetadata.topicPartitions)) {
+              // this is an optimization: if the partitions are already in the metadata reply OK immediately
+              Left(Errors.NONE)
             } else {
-              if (metadata.state == CompleteAbort || metadata.state == CompleteCommit)
-                metadata.topicPartitions.clear()
-              if (partitions.subsetOf(metadata.topicPartitions)) {
-                // this is an optimization: if the partitions are already in the metadata reply OK immediately
-                (Errors.NONE, null)
-              } else {
-                val now = time.milliseconds()
-                val newMetadata = new TransactionMetadata(pid,
-                  epoch,
-                  metadata.txnTimeoutMs,
-                  Ongoing,
-                  metadata.topicPartitions ++ partitions,
-                  if (metadata.state == Empty || metadata.state == CompleteCommit || metadata.state == CompleteAbort)
-                    now
-                  else metadata.transactionStartTime,
-                  now)
-                metadata.prepareTransitionTo(Ongoing)
-                (Errors.NONE, newMetadata)
-              }
+              Right(coordinatorEpoch, txnMetadata.prepareAddPartitions(partitions.toSet, time.milliseconds()))
             }
           }
       }
 
-      if (newMetadata != null) {
-        txnManager.appendTransactionToLog(transactionalId, newMetadata, responseCallback)
-      } else {
-        responseCallback(error)
+      result match {
+        case Left(err) =>
+          responseCallback(err)
+
+        case Right((coordinatorEpoch, newMetadata)) =>
+          txnManager.appendTransactionToLog(transactionalId, coordinatorEpoch, newMetadata, responseCallback)
       }
     }
   }
 
-  def handleTxnImmigration(transactionStateTopicPartitionId: Int, coordinatorEpoch: Int) {
-      txnManager.loadTransactionsForPartition(transactionStateTopicPartitionId, coordinatorEpoch, writeTxnMarkers)
+  def handleTxnImmigration(txnTopicPartitionId: Int, coordinatorEpoch: Int) {
+      txnManager.loadTransactionsForTxnTopicPartition(txnTopicPartitionId, coordinatorEpoch, txnMarkerChannelManager.addTxnMarkersToSend)
   }
 
-  def handleTxnEmigration(transactionStateTopicPartitionId: Int) {
-      txnManager.removeTransactionsForPartition(transactionStateTopicPartitionId)
-      txnMarkerChannelManager.removeStateForPartition(transactionStateTopicPartitionId)
+  def handleTxnEmigration(txnTopicPartitionId: Int) {
+      txnManager.removeTransactionsForTxnTopicPartition(txnTopicPartitionId)
+      txnMarkerChannelManager.removeMarkersForTxnTopicPartition(txnTopicPartitionId)
   }
 
   def handleEndTransaction(transactionalId: String,
                            pid: Long,
                            epoch: Short,
-                           command: TransactionResult,
+                           txnMarkerResult: TransactionResult,
                            responseCallback: EndTxnCallback): Unit = {
-    val errors = validateTransactionalId(transactionalId)
-    if (errors != Errors.NONE)
-      responseCallback(errors)
-    else
-      txnManager.getTransactionState(transactionalId) match {
+    val error = validateTransactionalId(transactionalId)
+    if (error != Errors.NONE)
+      responseCallback(error)
+    else {
+      val preAppendResult: Either[Errors, (Int, TransactionMetadataTransition)] = txnManager.getTransactionState(transactionalId) match {
         case None =>
-          responseCallback(Errors.INVALID_PID_MAPPING)
-        case Some(metadata) =>
-          metadata synchronized {
-            if (metadata.pid != pid)
-              responseCallback(Errors.INVALID_PID_MAPPING)
-            else if (metadata.producerEpoch != epoch)
-              responseCallback(Errors.INVALID_PRODUCER_EPOCH)
-            else metadata.state match {
+          Left(Errors.INVALID_PID_MAPPING)
+
+        case Some(epochAndTxnMetadata) =>
+          val txnMetadata = epochAndTxnMetadata.transactionMetadata
+          val coordinatorEpoch = epochAndTxnMetadata.coordinatorEpoch
+
+          txnMetadata synchronized {
+            if (txnMetadata.producerId != pid)
+              Left(Errors.INVALID_PID_MAPPING)
+            else if (txnMetadata.producerEpoch != epoch)
+              Left(Errors.INVALID_PRODUCER_EPOCH)
+            else if (txnMetadata.pendingTransitionInProgress)
+              Left(Errors.CONCURRENT_TRANSACTIONS)
+            else txnMetadata.state match {
               case Ongoing =>
-                commitOrAbort(transactionalId, pid, epoch, command, responseCallback, metadata)
+                val nextState = if (txnMarkerResult == TransactionResult.COMMIT)
+                  PrepareCommit
+                else
+                  PrepareAbort
+                Right(coordinatorEpoch, txnMetadata.prepareAbortOrCommit(nextState, time.milliseconds()))
               case CompleteCommit =>
-                if (command == TransactionResult.COMMIT)
-                  responseCallback(Errors.NONE)
+                if (txnMarkerResult == TransactionResult.COMMIT)
+                  Left(Errors.NONE)
                 else
-                  responseCallback(Errors.INVALID_TXN_STATE)
+                  Left(Errors.INVALID_TXN_STATE)
               case CompleteAbort =>
-                if (command == TransactionResult.ABORT)
-                  responseCallback(Errors.NONE)
+                if (txnMarkerResult == TransactionResult.ABORT)
+                  Left(Errors.NONE)
                 else
-                  responseCallback(Errors.INVALID_TXN_STATE)
-              case _ =>
-                responseCallback(Errors.INVALID_TXN_STATE)
+                  Left(Errors.INVALID_TXN_STATE)
+              case PrepareCommit =>
+                if (txnMarkerResult == TransactionResult.COMMIT)
+                  Left(Errors.CONCURRENT_TRANSACTIONS)
+                else
+                  Left(Errors.INVALID_TXN_STATE)
+              case PrepareAbort =>
+                if (txnMarkerResult == TransactionResult.ABORT)
+                  Left(Errors.CONCURRENT_TRANSACTIONS)
+                else
+                  Left(Errors.INVALID_TXN_STATE)
+              case Empty =>
+                Left(Errors.INVALID_TXN_STATE)
             }
           }
       }
-  }
 
-  private def commitOrAbort(transactionalId: String,
-                            pid: Long,
-                            epoch: Short,
-                            command: TransactionResult,
-                            responseCallback: EndTxnCallback,
-                            metadata: TransactionMetadata) = {
-    val nextState = if (command == TransactionResult.COMMIT) PrepareCommit else PrepareAbort
-    val newMetadata = new TransactionMetadata(pid,
-      epoch,
-      metadata.txnTimeoutMs,
-      nextState,
-      metadata.topicPartitions,
-      metadata.transactionStartTime,
-      time.milliseconds())
-    metadata.prepareTransitionTo(nextState)
-
-    def logAppendCallback(errors: Errors): Unit = {
-      // we can respond to the client immediately and continue to write the txn markers if
-      // the log append was successful
-      responseCallback(errors)
-      if (errors == Errors.NONE)
-        txnManager.coordinatorEpochFor(transactionalId) match {
-          case Some(coordinatorEpoch) =>
-            writeTxnMarkers(WriteTxnMarkerArgs(transactionalId, pid, epoch, nextState, newMetadata, coordinatorEpoch))
-          case None =>
-            // this one should be completed by the new coordinator
-            warn(s"no longer the coordinator for transactionalId: $transactionalId")
-        }
-    }
-    txnManager.appendTransactionToLog(transactionalId, newMetadata, logAppendCallback)
-  }
+      preAppendResult match {
+        case Left(err) =>
+          responseCallback(err)
+
+        case Right((coordinatorEpoch, newMetadata)) =>
+          def sendTxnMarkersCallback(error: Errors): Unit = {
+            if (error == Errors.NONE) {
+              val preSendResult: Either[Errors, (TransactionMetadata, TransactionMetadataTransition)] = txnManager.getTransactionState(transactionalId) match {
+                case Some(epochAndMetadata) =>
+                  if (epochAndMetadata.coordinatorEpoch == coordinatorEpoch) {
+
+                    val txnMetadata = epochAndMetadata.transactionMetadata
+                    txnMetadata synchronized {
+                      if (txnMetadata.producerId != pid)
+                        Left(Errors.INVALID_PID_MAPPING)
+                      else if (txnMetadata.producerEpoch != epoch)
+                        Left(Errors.INVALID_PRODUCER_EPOCH)
+                      else if (txnMetadata.pendingTransitionInProgress)
+                        Left(Errors.CONCURRENT_TRANSACTIONS)
+                      else txnMetadata.state match {
+                        case Empty| Ongoing | CompleteCommit | CompleteAbort =>
+                          Left(Errors.INVALID_TXN_STATE)
+                        case PrepareCommit =>
+                          if (txnMarkerResult != TransactionResult.COMMIT)
+                            Left(Errors.INVALID_TXN_STATE)
+                          else
+                            Right(txnMetadata, txnMetadata.prepareComplete(time.milliseconds()))
+                        case PrepareAbort =>
+                          if (txnMarkerResult != TransactionResult.ABORT)
+                            Left(Errors.INVALID_TXN_STATE)
+                          else
+                            Right(txnMetadata, txnMetadata.prepareComplete(time.milliseconds()))
+                      }
+                    }
+                  } else {
+                    info(s"Updating $transactionalId's transaction state to $newMetadata with coordinator epoch $coordinatorEpoch for $transactionalId failed since the transaction coordinator epoch " +
+                      s"has been changed to ${epochAndMetadata.coordinatorEpoch} after the transaction metadata has been successfully appended to the log")
+
+                    Left(Errors.NOT_COORDINATOR)
+                  }
+
+                case None =>
+                  if (txnManager.isCoordinatorFor(transactionalId)) {
+                    throw new IllegalStateException("Cannot find the metadata in coordinator's cache while it is still the leader of the txn topic partition")
+                  } else {
+                    // this transactional id no longer exists, maybe the corresponding partition has already been migrated out.
+                    info(s"Updating $transactionalId's transaction state to $newMetadata with coordinator epoch $coordinatorEpoch for $transactionalId failed after the transaction message " +
+                      s"has been appended to the log. The partition ${partitionFor(transactionalId)} may have migrated as the metadata is no longer in the cache")
+
+                    Left(Errors.NOT_COORDINATOR)
+                  }
+              }
+
+              preSendResult match {
+                case Left(err) =>
+                  responseCallback(err)
+
+                case Right((txnMetadata, newPreSendMetadata)) =>
+                  // we can respond to the client immediately and continue to write the txn markers if
+                  // the log append was successful
+                  responseCallback(Errors.NONE)
 
-  private def writeTxnMarkers(markerArgs: WriteTxnMarkerArgs): Unit = {
-    def completionCallback(error: Errors): Unit = {
-      error match {
-        case Errors.NONE =>
-          txnManager.getTransactionState(markerArgs.transactionalId) match {
-            case Some(preparedCommitMetadata) =>
-              val completedState = if (markerArgs.nextState == PrepareCommit) CompleteCommit else CompleteAbort
-              val committedMetadata = new TransactionMetadata(markerArgs.pid,
-                markerArgs.epoch,
-                preparedCommitMetadata.txnTimeoutMs,
-                completedState,
-                preparedCommitMetadata.topicPartitions,
-                preparedCommitMetadata.transactionStartTime,
-                time.milliseconds())
-              preparedCommitMetadata.prepareTransitionTo(completedState)
-
-              def writeCommittedTransactionCallback(error: Errors): Unit = {
-                error match {
-                  case Errors.NONE =>
-                    txnMarkerChannelManager.removeCompleted(txnManager.partitionFor(markerArgs.transactionalId),
-                      markerArgs.pid)
-                  case Errors.NOT_COORDINATOR =>
-                    // this one should be completed by the new coordinator
-                    warn(s"no longer the coordinator for transactionalId: ${markerArgs.transactionalId}")
-                  case _ =>
-                    warn(s"error: $error caught for transactionalId: ${markerArgs.transactionalId} when appending state: $completedState. Retrying.")
-                    // retry until success
-                    txnManager.appendTransactionToLog(markerArgs.transactionalId, committedMetadata, writeCommittedTransactionCallback)
-                }
+                  txnMarkerChannelManager.addTxnMarkersToSend(transactionalId, coordinatorEpoch, txnMarkerResult, txnMetadata, newPreSendMetadata)
               }
-              txnManager.appendTransactionToLog(markerArgs.transactionalId, committedMetadata, writeCommittedTransactionCallback)
-            case None =>
-              // this one should be completed by the new coordinator
-              warn(s"no longer the coordinator for transactionalId: ${markerArgs.transactionalId}")
+            } else {
+              info(s"Updating $transactionalId's transaction state to $newMetadata with coordinator epoch $coordinatorEpoch for $transactionalId failed since the transaction message " +
+                s"cannot be appended to the log. Returning error code $error to the client")
+
+              responseCallback(error)
+            }
           }
-        case Errors.NOT_COORDINATOR =>
-          warn(s"no longer the coordinator for transactionalId: ${markerArgs.transactionalId}")
-        case _ =>
-          warn(s"error: $error caught when writing transaction markers for transactionalId: ${markerArgs.transactionalId}. retrying")
-          txnMarkerChannelManager.addTxnMarkerRequest(txnManager.partitionFor(markerArgs.transactionalId),
-            markerArgs.newMetadata,
-            markerArgs.coordinatorEpoch,
-            completionCallback)
+
+          txnManager.appendTransactionToLog(transactionalId, coordinatorEpoch, newMetadata, sendTxnMarkersCallback)
       }
     }
-    txnMarkerChannelManager.addTxnMarkerRequest(txnManager.partitionFor(markerArgs.transactionalId),
-      markerArgs.newMetadata,
-      markerArgs.coordinatorEpoch,
-      completionCallback)
   }
 
   def transactionTopicConfigs: Properties = txnManager.transactionTopicConfigs
@@ -386,30 +412,15 @@ class TransactionCoordinator(brokerId: Int,
   def partitionFor(transactionalId: String): Int = txnManager.partitionFor(transactionalId)
 
   private def expireTransactions(): Unit = {
-
-    txnManager.transactionsToExpire().foreach{ idAndMetadata =>
-      idAndMetadata.metadata synchronized {
-        if (!txnManager.isCoordinatorLoadingInProgress(idAndMetadata.transactionalId)
-          && idAndMetadata.metadata.pendingState.isEmpty) {
-          // bump the producerEpoch so that any further requests for this transactionalId will be fenced
-          idAndMetadata.metadata.producerEpoch = (idAndMetadata.metadata.producerEpoch + 1).toShort
-          idAndMetadata.metadata.prepareTransitionTo(Ongoing)
-          txnManager.appendTransactionToLog(idAndMetadata.transactionalId, idAndMetadata.metadata, (errors: Errors) => {
-            if (errors != Errors.NONE)
-              warn(s"failed to append transactionalId ${idAndMetadata.transactionalId} to log during transaction expiry. errors:$errors")
-            else
-              handleEndTransaction(idAndMetadata.transactionalId,
-                idAndMetadata.metadata.pid,
-                idAndMetadata.metadata.producerEpoch,
-                TransactionResult.ABORT,
-                (errors: Errors) => {
-                  if (errors != Errors.NONE)
-                    warn(s"rollback of transactionalId: ${idAndMetadata.transactionalId} failed during transaction expiry. errors: $errors")
-                }
-              )
-          })
-        }
-      }
+    txnManager.transactionsToExpire().foreach { txnIdAndPidEpoch =>
+      handleEndTransaction(txnIdAndPidEpoch.transactionalId,
+        txnIdAndPidEpoch.producerId,
+        txnIdAndPidEpoch.producerEpoch,
+        TransactionResult.ABORT,
+        (error: Errors) => {
+          if (error != Errors.NONE)
+            warn(s"Rollback ongoing transaction of transactionalId: ${txnIdAndPidEpoch.transactionalId} aborted due to ${error.exceptionName()}")
+        })
     }
   }
 
@@ -421,8 +432,8 @@ class TransactionCoordinator(brokerId: Int,
     scheduler.startup()
     scheduler.schedule("transaction-expiration",
       expireTransactions,
-      TransactionManager.DefaultRemoveExpiredTransactionsIntervalMs,
-      TransactionManager.DefaultRemoveExpiredTransactionsIntervalMs
+      TransactionStateManager.DefaultRemoveExpiredTransactionsIntervalMs,
+      TransactionStateManager.DefaultRemoveExpiredTransactionsIntervalMs
     )
     if (enablePidExpiration)
       txnManager.enablePidExpiration()
@@ -449,9 +460,3 @@ class TransactionCoordinator(brokerId: Int,
 }
 
 case class InitPidResult(pid: Long, epoch: Short, error: Errors)
-case class WriteTxnMarkerArgs(transactionalId: String,
-                              pid: Long,
-                              epoch: Short,
-                              nextState: TransactionState,
-                              newMetadata: TransactionMetadata,
-                              coordinatorEpoch: Int)

http://git-wip-us.apache.org/repos/asf/kafka/blob/794e6dbd/core/src/main/scala/kafka/coordinator/transaction/TransactionLog.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/kafka/coordinator/transaction/TransactionLog.scala b/core/src/main/scala/kafka/coordinator/transaction/TransactionLog.scala
index 4a0dc71..a180502 100644
--- a/core/src/main/scala/kafka/coordinator/transaction/TransactionLog.scala
+++ b/core/src/main/scala/kafka/coordinator/transaction/TransactionLog.scala
@@ -141,18 +141,18 @@ object TransactionLog {
     *
     * @return value payload bytes
     */
-  private[coordinator] def valueToBytes(txnMetadata: TransactionMetadata): Array[Byte] = {
+  private[coordinator] def valueToBytes(txnMetadata: TransactionMetadataTransition): Array[Byte] = {
     val value = new Struct(CURRENT_VALUE_SCHEMA)
-    value.set(VALUE_SCHEMA_PID_FIELD, txnMetadata.pid)
+    value.set(VALUE_SCHEMA_PID_FIELD, txnMetadata.producerId)
     value.set(VALUE_SCHEMA_EPOCH_FIELD, txnMetadata.producerEpoch)
     value.set(VALUE_SCHEMA_TXN_TIMEOUT_FIELD, txnMetadata.txnTimeoutMs)
-    value.set(VALUE_SCHEMA_TXN_STATUS_FIELD, txnMetadata.state.byte)
-    value.set(VALUE_SCHEMA_TXN_ENTRY_TIMESTAMP_FIELD, txnMetadata.lastUpdateTimestamp)
-    value.set(VALUE_SCHEMA_TXN_START_TIMESTAMP_FIELD, txnMetadata.transactionStartTime)
+    value.set(VALUE_SCHEMA_TXN_STATUS_FIELD, txnMetadata.txnState.byte)
+    value.set(VALUE_SCHEMA_TXN_ENTRY_TIMESTAMP_FIELD, txnMetadata.txnLastUpdateTimestamp)
+    value.set(VALUE_SCHEMA_TXN_START_TIMESTAMP_FIELD, txnMetadata.txnStartTimestamp)
 
-    if (txnMetadata.state == Empty) {
+    if (txnMetadata.txnState == Empty) {
       if (txnMetadata.topicPartitions.nonEmpty)
-        throw new IllegalStateException(s"Transaction is not expected to have any partitions since its state is ${txnMetadata.state}: $txnMetadata")
+        throw new IllegalStateException(s"Transaction is not expected to have any partitions since its state is ${txnMetadata.txnState}: $txnMetadata")
 
       value.set(VALUE_SCHEMA_TXN_PARTITIONS_FIELD, null)
     } else {

http://git-wip-us.apache.org/repos/asf/kafka/blob/794e6dbd/core/src/main/scala/kafka/coordinator/transaction/TransactionMarkerChannel.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/kafka/coordinator/transaction/TransactionMarkerChannel.scala b/core/src/main/scala/kafka/coordinator/transaction/TransactionMarkerChannel.scala
deleted file mode 100644
index e60bd40..0000000
--- a/core/src/main/scala/kafka/coordinator/transaction/TransactionMarkerChannel.scala
+++ /dev/null
@@ -1,186 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package kafka.coordinator.transaction
-
-import java.util
-import java.util.concurrent.{BlockingQueue, LinkedBlockingQueue}
-
-import kafka.common.{BrokerEndPointNotAvailableException, RequestAndCompletionHandler}
-import kafka.server.{DelayedOperationPurgatory, MetadataCache}
-import kafka.utils.Logging
-import org.apache.kafka.clients.NetworkClient
-import org.apache.kafka.common.network.ListenerName
-import org.apache.kafka.common.requests.{TransactionResult, WriteTxnMarkersRequest}
-import org.apache.kafka.common.requests.WriteTxnMarkersRequest.TxnMarkerEntry
-import org.apache.kafka.common.utils.Time
-import org.apache.kafka.common.{Node, TopicPartition}
-
-import scala.collection.{concurrent, immutable, mutable}
-import collection.JavaConverters._
-
-class TransactionMarkerChannel(interBrokerListenerName: ListenerName,
-                               metadataCache: MetadataCache,
-                               networkClient: NetworkClient,
-                               time: Time) extends Logging {
-
-  // we need the txnTopicPartition so we can clean up when Transaction Log partitions emigrate
-  case class PendingTxnKey(txnTopicPartition: Int, producerId: Long)
-
-  class BrokerRequestQueue(private var destination: Node) {
-
-    // keep track of the requests per txn topic partition so we can easily clear the queue
-    // during partition emigration
-    private val requestsPerTxnTopicPartition: concurrent.Map[Int, BlockingQueue[TxnMarkerEntry]]
-      = concurrent.TrieMap.empty[Int, BlockingQueue[TxnMarkerEntry]]
-
-    def removeRequestsForPartition(partition: Int): Unit = {
-      requestsPerTxnTopicPartition.remove(partition)
-    }
-
-    def maybeUpdateNode(node: Node): Unit = {
-      destination = node
-    }
-
-    def addRequests(txnTopicPartition: Int, txnMarkerEntry: TxnMarkerEntry): Unit = {
-      val queue = requestsPerTxnTopicPartition.getOrElseUpdate(txnTopicPartition, new LinkedBlockingQueue[TxnMarkerEntry]())
-      queue.add(txnMarkerEntry)
-    }
-
-    def eachMetadataPartition[B](f:(Int, BlockingQueue[TxnMarkerEntry]) => B): mutable.Iterable[B] =
-      requestsPerTxnTopicPartition.filter{ case(_, queue) => !queue.isEmpty}
-        .map{case(partition:Int, queue:BlockingQueue[TxnMarkerEntry]) => f(partition, queue)}
-
-
-    def node: Node = destination
-
-    def totalQueuedRequests(): Int =
-      requestsPerTxnTopicPartition.map { case(_, queue) => queue.size()}
-        .sum
-
-  }
-
-  private val brokerStateMap: concurrent.Map[Int, BrokerRequestQueue] = concurrent.TrieMap.empty[Int, BrokerRequestQueue]
-  private val pendingTxnMap: concurrent.Map[PendingTxnKey, TransactionMetadata] = concurrent.TrieMap.empty[PendingTxnKey, TransactionMetadata]
-
-  // TODO: What is reasonable for this
-  private val brokerNotAliveBackoffMs = 10
-
-  // visible for testing
-  private[transaction] def queueForBroker(brokerId: Int) = {
-    brokerStateMap.get(brokerId)
-  }
-
-  private[transaction]
-  def drainQueuedTransactionMarkers(txnMarkerPurgatory: DelayedOperationPurgatory[DelayedTxnMarker]): Iterable[RequestAndCompletionHandler] = {
-    brokerStateMap.flatMap {case (brokerId: Int, brokerRequestQueue: BrokerRequestQueue) =>
-      brokerRequestQueue.eachMetadataPartition{ case(partitionId, queue) =>
-        val markersToSend: java.util.List[TxnMarkerEntry] = new util.ArrayList[TxnMarkerEntry]()
-        queue.drainTo(markersToSend)
-        val requestCompletionHandler = new TransactionMarkerRequestCompletionHandler(this, txnMarkerPurgatory, partitionId, markersToSend, brokerId)
-        RequestAndCompletionHandler(brokerRequestQueue.node, new WriteTxnMarkersRequest.Builder(markersToSend), requestCompletionHandler)
-      }
-    }
-  }
-
-
-  def addOrUpdateBroker(broker: Node) {
-    brokerStateMap.putIfAbsent(broker.id(), new BrokerRequestQueue(broker)) match {
-      case Some(brokerQueue) => brokerQueue.maybeUpdateNode(broker)
-      case None => // nothing to do
-    }
-  }
-
-  private[transaction] def addRequestForBroker(brokerId: Int, metadataPartition: Int, txnMarkerEntry: TxnMarkerEntry) {
-    val brokerQueue = brokerStateMap(brokerId)
-    brokerQueue.addRequests(metadataPartition, txnMarkerEntry)
-    trace(s"Added marker $txnMarkerEntry for broker $brokerId")
-  }
-
-  def addRequestToSend(metadataPartition: Int, pid: Long, epoch: Short, result: TransactionResult, coordinatorEpoch: Int, topicPartitions: immutable.Set[TopicPartition]): Unit = {
-    val partitionsByDestination: immutable.Map[Int, immutable.Set[TopicPartition]] = topicPartitions.groupBy { topicPartition: TopicPartition =>
-      val currentBrokers = mutable.Set.empty[Int]
-      var brokerId:Option[Int] = None
-
-      while(brokerId.isEmpty) {
-        val leaderForPartition = metadataCache.getPartitionInfo(topicPartition.topic, topicPartition.partition)
-        leaderForPartition match {
-          case Some(partitionInfo) =>
-            val leaderId = partitionInfo.leaderIsrAndControllerEpoch.leaderAndIsr.leader
-            if (currentBrokers.add(leaderId)) {
-              try {
-                metadataCache.getAliveEndpoint(leaderId, interBrokerListenerName) match {
-                  case Some(broker) =>
-                    addOrUpdateBroker(broker)
-                    brokerId = Some(leaderId)
-                  case None =>
-                    currentBrokers.remove(leaderId)
-                    trace(s"alive endpoint for broker with id: $leaderId not available. retrying")
-
-                }
-              } catch {
-                case _:BrokerEndPointNotAvailableException =>
-                  currentBrokers.remove(leaderId)
-                  trace(s"alive endpoint for broker with id: $leaderId not available. retrying")
-              }
-            }
-          case None =>
-            trace(s"couldn't find leader for partition: $topicPartition")
-        }
-        if (brokerId.isEmpty)
-          time.sleep(brokerNotAliveBackoffMs)
-      }
-      brokerId.get
-    }
-
-    for ((brokerId: Int, topicPartitions: immutable.Set[TopicPartition]) <- partitionsByDestination) {
-      val txnMarker = new TxnMarkerEntry(pid, epoch, coordinatorEpoch, result, topicPartitions.toList.asJava)
-      addRequestForBroker(brokerId, metadataPartition, txnMarker)
-    }
-    networkClient.wakeup()
-  }
-
-  def maybeAddPendingRequest(metadataPartition: Int, metadata: TransactionMetadata): Boolean = {
-    val existingMetadataToWrite = pendingTxnMap.putIfAbsent(PendingTxnKey(metadataPartition, metadata.pid), metadata)
-    existingMetadataToWrite.isEmpty
-  }
-
-  def removeCompletedTxn(metadataPartition: Int, pid: Long): Unit = {
-    pendingTxnMap.remove(PendingTxnKey(metadataPartition, pid))
-  }
-
-  def pendingTxnMetadata(metadataPartition: Int, pid: Long): Option[TransactionMetadata] = {
-    pendingTxnMap.get(PendingTxnKey(metadataPartition, pid))
-  }
-
-  def close(): Unit = {
-    brokerStateMap.clear()
-    pendingTxnMap.clear()
-    networkClient.close()
-  }
-
-  def removeStateForPartition(partition: Int): mutable.Iterable[Long] = {
-    brokerStateMap.foreach { case(_, brokerQueue) =>
-      brokerQueue.removeRequestsForPartition(partition)
-    }
-    pendingTxnMap.filter { case (key: PendingTxnKey, _) => key.txnTopicPartition == partition }
-      .map { case (key: PendingTxnKey, _) =>
-        pendingTxnMap.remove(key)
-        key.producerId
-      }
-  }
-
-}

http://git-wip-us.apache.org/repos/asf/kafka/blob/794e6dbd/core/src/main/scala/kafka/coordinator/transaction/TransactionMarkerChannelManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/kafka/coordinator/transaction/TransactionMarkerChannelManager.scala b/core/src/main/scala/kafka/coordinator/transaction/TransactionMarkerChannelManager.scala
index 1b7ea56..b7a2e80 100644
--- a/core/src/main/scala/kafka/coordinator/transaction/TransactionMarkerChannelManager.scala
+++ b/core/src/main/scala/kafka/coordinator/transaction/TransactionMarkerChannelManager.scala
@@ -17,28 +17,33 @@
 package kafka.coordinator.transaction
 
 
+import java.util
+import java.util.concurrent.{BlockingQueue, LinkedBlockingQueue}
+
 import kafka.common.{InterBrokerSendThread, RequestAndCompletionHandler}
 import kafka.server.{DelayedOperationPurgatory, KafkaConfig, MetadataCache}
 import kafka.utils.Logging
 import org.apache.kafka.clients._
+import org.apache.kafka.common.{Node, TopicPartition}
 import org.apache.kafka.common.metrics.Metrics
 import org.apache.kafka.common.network._
-import org.apache.kafka.common.requests.TransactionResult
+import org.apache.kafka.common.requests.{TransactionResult, WriteTxnMarkersRequest}
 import org.apache.kafka.common.security.JaasContext
 import org.apache.kafka.common.utils.Time
-
 import org.apache.kafka.common.protocol.Errors
+import org.apache.kafka.common.requests.WriteTxnMarkersRequest.TxnMarkerEntry
 
 import collection.JavaConverters._
+import scala.collection.{concurrent, immutable, mutable}
 
 object TransactionMarkerChannelManager {
   def apply(config: KafkaConfig,
             metrics: Metrics,
             metadataCache: MetadataCache,
+            txnStateManager: TransactionStateManager,
             txnMarkerPurgatory: DelayedOperationPurgatory[DelayedTxnMarker],
             time: Time): TransactionMarkerChannelManager = {
 
-
     val channelBuilder = ChannelBuilders.clientChannelBuilder(
       config.interBrokerSecurityProtocol,
       JaasContext.Type.SERVER,
@@ -47,7 +52,6 @@ object TransactionMarkerChannelManager {
       config.saslMechanismInterBrokerProtocol,
       config.saslInterBrokerHandshakeRequestEnable
     )
-    val threadName = "TxnMarkerSenderThread-" + config.brokerId
     val selector = new Selector(
       NetworkReceive.UNLIMITED,
       config.connectionsMaxIdleMs,
@@ -61,7 +65,7 @@ object TransactionMarkerChannelManager {
     val networkClient = new NetworkClient(
       selector,
       new ManualMetadataUpdater(),
-      threadName,
+      s"broker-${config.brokerId}-txn-marker-sender",
       1,
       50,
       Selectable.USE_DEFAULT_BUFFER_SIZE,
@@ -71,79 +75,226 @@ object TransactionMarkerChannelManager {
       false,
       new ApiVersions
     )
-    val channel = new TransactionMarkerChannel(config.interBrokerListenerName, metadataCache, networkClient, time)
-
-    val sendThread: InterBrokerSendThread = {
-      networkClient.wakeup()
-      new InterBrokerSendThread(threadName, networkClient, requestGenerator(channel, txnMarkerPurgatory), time)
-    }
 
     new TransactionMarkerChannelManager(config,
       metadataCache,
+      networkClient,
+      txnStateManager,
       txnMarkerPurgatory,
-      sendThread,
-      channel)
+      time)
   }
 
-  private[transaction] def requestGenerator(transactionMarkerChannel: TransactionMarkerChannel,
-                                            txnMarkerPurgatory: DelayedOperationPurgatory[DelayedTxnMarker]): () => Iterable[RequestAndCompletionHandler] = {
-    () => transactionMarkerChannel.drainQueuedTransactionMarkers(txnMarkerPurgatory)
+  private[transaction] def requestGenerator(transactionMarkerChannelManager: TransactionMarkerChannelManager): () => Iterable[RequestAndCompletionHandler] = {
+    () => transactionMarkerChannelManager.drainQueuedTransactionMarkers()
   }
 }
 
+class TxnMarkerQueue(@volatile private var destination: Node) {
+
+  // keep track of the requests per txn topic partition so we can easily clear the queue
+  // during partition emigration
+  private val markersPerTxnTopicPartition: concurrent.Map[Int, BlockingQueue[TxnIdAndMarkerEntry]]
+  = concurrent.TrieMap.empty[Int, BlockingQueue[TxnIdAndMarkerEntry]]
+
+  def removeMarkersForTxnTopicPartition(partition: Int): Option[BlockingQueue[TxnIdAndMarkerEntry]] = {
+    markersPerTxnTopicPartition.remove(partition)
+  }
+
+  def maybeUpdateNode(node: Node): Unit = {
+    destination = node
+  }
+
+  def addMarkers(txnTopicPartition: Int, txnIdAndMarker: TxnIdAndMarkerEntry): Unit = {
+    val queue = markersPerTxnTopicPartition.getOrElseUpdate(txnTopicPartition, new LinkedBlockingQueue[TxnIdAndMarkerEntry]())
+    queue.add(txnIdAndMarker)
+  }
+
+  def forEachTxnTopicPartition[B](f:(Int, BlockingQueue[TxnIdAndMarkerEntry]) => B): mutable.Iterable[B] =
+    markersPerTxnTopicPartition.filter { case(_, queue) => !queue.isEmpty }
+      .map { case(partition:Int, queue:BlockingQueue[TxnIdAndMarkerEntry]) => f(partition, queue) }
+
+  def node: Node = destination
 
+  // TODO: this function is only used for metrics recording, which is not yet added
+  def totalNumMarkers(): Int = markersPerTxnTopicPartition.map { case(_, queue) => queue.size()}.sum
+
+  // visible for testing
+  def totalNumMarkers(txnTopicPartition: Int): Int = markersPerTxnTopicPartition.get(txnTopicPartition).fold(0)(_.size())
+}
 
 class TransactionMarkerChannelManager(config: KafkaConfig,
                                       metadataCache: MetadataCache,
+                                      networkClient: NetworkClient,
+                                      txnStateManager: TransactionStateManager,
                                       txnMarkerPurgatory: DelayedOperationPurgatory[DelayedTxnMarker],
-                                      interBrokerSendThread: InterBrokerSendThread,
-                                      transactionMarkerChannel: TransactionMarkerChannel) extends Logging {
+                                      time: Time) extends Logging {
+
+  private val markersQueuePerBroker: concurrent.Map[Int, TxnMarkerQueue] = concurrent.TrieMap.empty[Int, TxnMarkerQueue]
+
+  private val interBrokerListenerName: ListenerName = config.interBrokerListenerName
 
-  type WriteTxnMarkerCallback = Errors => Unit
+  // TODO: what is a reasonable value for this backoff?
+  private val brokerNotAliveBackoffMs = 10
+
+  private val txnMarkerSendThread: InterBrokerSendThread = {
+    new InterBrokerSendThread("TxnMarkerSenderThread-" + config.brokerId, networkClient, drainQueuedTransactionMarkers, time)
+  }
 
   def start(): Unit = {
-    interBrokerSendThread.start()
+    txnMarkerSendThread.start()
+    networkClient.wakeup()    // FIXME: is this really required?
   }
 
   def shutdown(): Unit = {
-    interBrokerSendThread.shutdown()
-    transactionMarkerChannel.close()
+    txnMarkerSendThread.shutdown()
+    markersQueuePerBroker.clear()
+  }
+
+  // visible for testing
+  private[transaction] def queueForBroker(brokerId: Int) = {
+    markersQueuePerBroker.get(brokerId)
+  }
+
+  // visible for testing
+  private[transaction] def senderThread = txnMarkerSendThread
+
+  private[transaction] def addMarkersForBroker(broker: Node, txnTopicPartition: Int, txnIdAndMarker: TxnIdAndMarkerEntry) {
+    val brokerId = broker.id
+
+    // we do not synchronize on the update of the broker node with the enqueuing,
+    // since even if there is a race condition we will just retry
+    val brokerRequestQueue = markersQueuePerBroker.getOrElseUpdate(brokerId, new TxnMarkerQueue(broker))
+    brokerRequestQueue.maybeUpdateNode(broker)
+    brokerRequestQueue.addMarkers(txnTopicPartition, txnIdAndMarker)
+
+    trace(s"Added marker ${txnIdAndMarker.txnMarkerEntry} for transactional id ${txnIdAndMarker.txnId} to destination broker $brokerId")
   }
 
+  private[transaction] def drainQueuedTransactionMarkers(): Iterable[RequestAndCompletionHandler] = {
+    markersQueuePerBroker.map { case (brokerId: Int, brokerRequestQueue: TxnMarkerQueue) =>
+      val txnIdAndMarkerEntries: java.util.List[TxnIdAndMarkerEntry] = new util.ArrayList[TxnIdAndMarkerEntry]()
+      brokerRequestQueue.forEachTxnTopicPartition { case (_, queue) =>
+        queue.drainTo(txnIdAndMarkerEntries)
+      }
+      (brokerRequestQueue.node, txnIdAndMarkerEntries)
+    }
+      .filter { case (_, entries) => !entries.isEmpty}
+      .map { case (node, entries) =>
+        val markersToSend: java.util.List[TxnMarkerEntry] = entries.asScala.map(_.txnMarkerEntry).asJava
+        val requestCompletionHandler = new TransactionMarkerRequestCompletionHandler(node.id, txnStateManager, this, entries)
+        RequestAndCompletionHandler(node, new WriteTxnMarkersRequest.Builder(markersToSend), requestCompletionHandler)
+      }
+  }
+
+  def addTxnMarkersToSend(transactionalId: String,
+                          coordinatorEpoch: Int,
+                          txnResult: TransactionResult,
+                          txnMetadata: TransactionMetadata,
+                          newMetadata: TransactionMetadataTransition): Unit = {
+
+    def appendToLogCallback(error: Errors): Unit = {
+      error match {
+        case Errors.NONE =>
+          trace(s"Completed sending transaction markers for $transactionalId as $txnResult")
+
+          txnStateManager.getTransactionState(transactionalId) match {
+            case Some(epochAndMetadata) =>
+              if (epochAndMetadata.coordinatorEpoch == coordinatorEpoch) {
+                debug(s"Updating $transactionalId's transaction state to $txnMetadata with coordinator epoch $coordinatorEpoch succeeded")
 
-  def addTxnMarkerRequest(txnTopicPartition: Int, metadata: TransactionMetadata, coordinatorEpoch: Int, completionCallback: WriteTxnMarkerCallback): Unit = {
-    val metadataToWrite = metadata synchronized metadata.copy()
+                // try to append to the transaction log
+                def retryAppendCallback(error: Errors): Unit =
+                  error match {
+                    case Errors.NONE =>
+                      trace(s"Completed transaction for $transactionalId with coordinator epoch $coordinatorEpoch, final state after commit: ${txnMetadata.state}")
 
-    if (!transactionMarkerChannel.maybeAddPendingRequest(txnTopicPartition, metadata))
-      // TODO: Not sure this is the correct response here?
-      completionCallback(Errors.INVALID_TXN_STATE)
-    else {
-      val delayedTxnMarker = new DelayedTxnMarker(metadataToWrite, completionCallback)
-      txnMarkerPurgatory.tryCompleteElseWatch(delayedTxnMarker, Seq(metadata.pid))
+                    case Errors.NOT_COORDINATOR =>
+                      info(s"No longer the coordinator for transactionalId: $transactionalId while trying to append to transaction log, skip writing to transaction log")
 
-      val result = metadataToWrite.state match {
-        case PrepareCommit => TransactionResult.COMMIT
-        case PrepareAbort => TransactionResult.ABORT
-        case s => throw new IllegalStateException("Unexpected txn metadata state while writing markers: " + s)
+                    case Errors.COORDINATOR_NOT_AVAILABLE =>
+                      warn(s"Failed updating transaction state for $transactionalId when appending to transaction log due to ${error.exceptionName}. Retrying")
+
+                      // retry appending
+                      txnStateManager.appendTransactionToLog(transactionalId, coordinatorEpoch, newMetadata, retryAppendCallback)
+
+                    case errors: Errors =>
+                      throw new IllegalStateException(s"Unexpected error ${errors.exceptionName} while appending to transaction log for $transactionalId")
+                  }
+
+                txnStateManager.appendTransactionToLog(transactionalId, coordinatorEpoch, newMetadata, retryAppendCallback)
+              } else {
+                info(s"Updating $transactionalId's transaction state to $txnMetadata with coordinator epoch $coordinatorEpoch failed after the transaction markers " +
+                  s"have been sent to brokers. The cached metadata has been changed to $epochAndMetadata since preparing to send markers")
+              }
+
+            case None =>
+              // this transactional id no longer exists, maybe the corresponding partition has already been migrated out.
+              // we will stop appending the completed log entry to transaction topic as the new leader should be doing it.
+              info(s"Updating $transactionalId's transaction state to $txnMetadata with coordinator epoch $coordinatorEpoch failed after the transaction markers " +
+                s"have been sent to brokers. The partition ${txnStateManager.partitionFor(transactionalId)} may have migrated as the metadata is no longer in the cache")
+          }
+
+        case other =>
+          throw new IllegalStateException(s"Unexpected error ${other.exceptionName} before appending to txn log for $transactionalId")
       }
-      transactionMarkerChannel.addRequestToSend(txnTopicPartition,
-        metadataToWrite.pid,
-        metadataToWrite.producerEpoch,
-        result,
-        coordinatorEpoch,
-        metadataToWrite.topicPartitions.toSet)
     }
+
+    // watch for both the transactional id and the transaction topic partition id,
+    // so we can cancel all the delayed operations for the same partition id;
+    // NOTE: this is only possible because the hash codes of the Int and String keys never overlap
+
+    // TODO: if the delayed txn marker always has an infinite timeout, we can replace it with a map
+    val delayedTxnMarker = new DelayedTxnMarker(txnMetadata, appendToLogCallback)
+    val txnTopicPartition = txnStateManager.partitionFor(transactionalId)
+    txnMarkerPurgatory.tryCompleteElseWatch(delayedTxnMarker, Seq(transactionalId, txnTopicPartition))
+
+    addTxnMarkersToBrokerQueue(transactionalId, txnMetadata.producerId, txnMetadata.producerEpoch, txnResult, coordinatorEpoch, txnMetadata.topicPartitions.toSet)
   }
 
-  def removeCompleted(txnTopicPartition: Int, pid: Long): Unit = {
-    transactionMarkerChannel.removeCompletedTxn(txnTopicPartition, pid)
+  def addTxnMarkersToBrokerQueue(transactionalId: String, pid: Long, epoch: Short, result: TransactionResult, coordinatorEpoch: Int, topicPartitions: immutable.Set[TopicPartition]): Unit = {
+    val txnTopicPartition = txnStateManager.partitionFor(transactionalId)
+    val partitionsByDestination: immutable.Map[Node, immutable.Set[TopicPartition]] = topicPartitions.groupBy { topicPartition: TopicPartition =>
+      var brokerNode: Option[Node] = None
+
+      // TODO: instead of retrying until it succeeds, we can first put it into an unknown-broker queue and let the sender thread look for its broker and migrate them
+      while (brokerNode.isEmpty) {
+        brokerNode = metadataCache.getPartitionLeaderEndpoint(topicPartition.topic, topicPartition.partition, interBrokerListenerName)
+
+        if (brokerNode.isEmpty) {
+          trace(s"Couldn't find leader endpoint for partition: $topicPartition, retrying.")
+          time.sleep(brokerNotAliveBackoffMs)
+        }
+      }
+      brokerNode.get
+    }
+
+    for ((broker: Node, topicPartitions: immutable.Set[TopicPartition]) <- partitionsByDestination) {
+      val txnIdAndMarker = TxnIdAndMarkerEntry(transactionalId, new TxnMarkerEntry(pid, epoch, coordinatorEpoch, result, topicPartitions.toList.asJava))
+      addMarkersForBroker(broker, txnTopicPartition, txnIdAndMarker)
+    }
+
+    networkClient.wakeup()
   }
 
-  def removeStateForPartition(transactionStateTopicPartitionId: Int): Unit = {
-    transactionMarkerChannel.removeStateForPartition(transactionStateTopicPartitionId)
-      .foreach{pid =>
-        txnMarkerPurgatory.cancelForKey(pid)
+  def removeMarkersForTxnTopicPartition(txnTopicPartitionId: Int): Unit = {
+    txnMarkerPurgatory.cancelForKey(txnTopicPartitionId)
+    markersQueuePerBroker.foreach { case(_, brokerQueue) =>
+      brokerQueue.removeMarkersForTxnTopicPartition(txnTopicPartitionId).foreach { queue =>
+        for (entry: TxnIdAndMarkerEntry <- queue.asScala)
+          removeMarkersForTxnId(entry.txnId)
       }
+    }
+  }
+
+  def removeMarkersForTxnId(transactionalId: String): Unit = {
+    // we do not need to clear the queue since it should have
+    // already been drained by the sender thread
+    txnMarkerPurgatory.cancelForKey(transactionalId)
   }
 
+  def completeSendMarkersForTxnId(transactionalId: String): Unit = {
+    txnMarkerPurgatory.checkAndComplete(transactionalId)
+  }
 }
+
+case class TxnIdAndMarkerEntry(txnId: String, txnMarkerEntry: TxnMarkerEntry)
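
For readers following the new queueing model, the TxnMarkerQueue above keeps one BlockingQueue per
transaction topic partition inside a concurrent TrieMap, so a whole partition's pending markers can be
dropped on emigration while the send thread drains everything else per destination broker. The following
standalone sketch illustrates that pattern; it is not part of the patch, and the Marker and Broker types
are hypothetical stand-ins for TxnIdAndMarkerEntry and Node.

import java.util
import java.util.concurrent.{BlockingQueue, LinkedBlockingQueue}

import scala.collection.JavaConverters._
import scala.collection.concurrent

// Hypothetical stand-ins for TxnIdAndMarkerEntry and Node.
case class Marker(txnId: String, payload: String)
case class Broker(id: Int, host: String)

// One queue of markers per txn topic partition for a single destination broker,
// so an emigrated partition's pending markers can be removed wholesale.
class SimpleMarkerQueue(val destination: Broker) {
  private val perPartition: concurrent.Map[Int, BlockingQueue[Marker]] =
    concurrent.TrieMap.empty[Int, BlockingQueue[Marker]]

  def add(txnTopicPartition: Int, marker: Marker): Unit =
    perPartition.getOrElseUpdate(txnTopicPartition, new LinkedBlockingQueue[Marker]()).add(marker)

  def removePartition(txnTopicPartition: Int): Option[BlockingQueue[Marker]] =
    perPartition.remove(txnTopicPartition)

  // Drain everything currently queued into one batch destined for this broker.
  def drainAll(): util.List[Marker] = {
    val batch = new util.ArrayList[Marker]()
    perPartition.values.foreach(_.drainTo(batch))
    batch
  }
}

object MarkerQueueDemo extends App {
  val queue = new SimpleMarkerQueue(Broker(1, "host1"))
  queue.add(0, Marker("txn-a", "commit"))
  queue.add(1, Marker("txn-b", "abort"))

  // The inter-broker send thread would periodically drain each broker's queue and
  // turn every non-empty batch into a single WriteTxnMarkers request for that broker.
  val batch = queue.drainAll().asScala
  println(s"would send ${batch.size} markers to broker ${queue.destination.id}")
}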

http://git-wip-us.apache.org/repos/asf/kafka/blob/794e6dbd/core/src/main/scala/kafka/coordinator/transaction/TransactionMarkerRequestCompletionHandler.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/kafka/coordinator/transaction/TransactionMarkerRequestCompletionHandler.scala b/core/src/main/scala/kafka/coordinator/transaction/TransactionMarkerRequestCompletionHandler.scala
index 5d68325..5978a97 100644
--- a/core/src/main/scala/kafka/coordinator/transaction/TransactionMarkerRequestCompletionHandler.scala
+++ b/core/src/main/scala/kafka/coordinator/transaction/TransactionMarkerRequestCompletionHandler.scala
@@ -22,25 +22,23 @@ import kafka.utils.Logging
 import org.apache.kafka.clients.{ClientResponse, RequestCompletionHandler}
 import org.apache.kafka.common.TopicPartition
 import org.apache.kafka.common.protocol.Errors
-import org.apache.kafka.common.requests.WriteTxnMarkersRequest.TxnMarkerEntry
 import org.apache.kafka.common.requests.WriteTxnMarkersResponse
 
 import scala.collection.mutable
 import collection.JavaConversions._
 
-class TransactionMarkerRequestCompletionHandler(transactionMarkerChannel: TransactionMarkerChannel,
-                                                txnMarkerPurgatory: DelayedOperationPurgatory[DelayedTxnMarker],
-                                                txnTopicPartition: Int,
-                                                txnMarkerEntries: java.util.List[TxnMarkerEntry],
-                                                brokerId: Int) extends RequestCompletionHandler with Logging {
+class TransactionMarkerRequestCompletionHandler(brokerId: Int,
+                                                txnStateManager: TransactionStateManager,
+                                                txnMarkerChannelManager: TransactionMarkerChannelManager,
+                                                txnIdAndMarkerEntries: java.util.List[TxnIdAndMarkerEntry]) extends RequestCompletionHandler with Logging {
   override def onComplete(response: ClientResponse): Unit = {
     val correlationId = response.requestHeader.correlationId
     if (response.wasDisconnected) {
       trace(s"Cancelled request $response due to node ${response.destination} being disconnected")
       // re-enqueue the markers
-      for (txnMarker: TxnMarkerEntry <- txnMarkerEntries) {
-        transactionMarkerChannel.addRequestToSend(
-          txnTopicPartition,
+      for (txnIdAndMarker: TxnIdAndMarkerEntry <- txnIdAndMarkerEntries) {
+        val txnMarker = txnIdAndMarker.txnMarkerEntry
+        txnMarkerChannelManager.addTxnMarkersToBrokerQueue(txnIdAndMarker.txnId,
           txnMarker.producerId(),
           txnMarker.producerEpoch(),
           txnMarker.transactionResult(),
@@ -52,43 +50,89 @@ class TransactionMarkerRequestCompletionHandler(transactionMarkerChannel: Transa
 
       val writeTxnMarkerResponse = response.responseBody.asInstanceOf[WriteTxnMarkersResponse]
 
-      for (txnMarker: TxnMarkerEntry <- txnMarkerEntries) {
-        val errors = writeTxnMarkerResponse.errors(txnMarker.producerId())
+      for (txnIdAndMarker: TxnIdAndMarkerEntry <- txnIdAndMarkerEntries) {
+        val transactionalId = txnIdAndMarker.txnId
+        val txnMarker = txnIdAndMarker.txnMarkerEntry
+        val errors = writeTxnMarkerResponse.errors(txnMarker.producerId)
 
         if (errors == null)
-          throw new IllegalStateException("WriteTxnMarkerResponse does not contain expected error map for pid " + txnMarker.producerId())
-
-        val retryPartitions: mutable.Set[TopicPartition] = mutable.Set.empty[TopicPartition]
-        for ((topicPartition: TopicPartition, error: Errors) <- errors) {
-          error match {
-            case Errors.NONE =>
-              transactionMarkerChannel.pendingTxnMetadata(txnTopicPartition, txnMarker.producerId()) match {
-                case None =>
-                  // TODO: probably need to respond with Errors.NOT_COORDINATOR
-                  throw new IllegalArgumentException(s"transaction metadata not found during write txn marker request. partition ${txnTopicPartition} has likely emigrated")
-                case Some(metadata) =>
-                  // do not synchronize on this metadata since it will only be accessed by the sender thread
-                  metadata.topicPartitions -= topicPartition
+          throw new IllegalStateException(s"WriteTxnMarkerResponse does not contain expected error map for pid ${txnMarker.producerId}")
+
+        txnStateManager.getTransactionState(transactionalId) match {
+          case None =>
+            info(s"Transaction topic partition for $transactionalId has likely emigrated, as the corresponding metadata no longer exists " +
+              s"in the cache; cancel sending transaction markers $txnMarker to the brokers")
+
+            // txn topic partition has likely emigrated, just cancel it from the purgatory
+            txnMarkerChannelManager.removeMarkersForTxnId(transactionalId)
+
+          case Some(epochAndMetadata) =>
+            val txnMetadata = epochAndMetadata.transactionMetadata
+            val retryPartitions: mutable.Set[TopicPartition] = mutable.Set.empty[TopicPartition]
+            var abortSending: Boolean = false
+
+            if (epochAndMetadata.coordinatorEpoch != txnMarker.coordinatorEpoch) {
+              // coordinator epoch has changed, just cancel it from the purgatory
+              info(s"Transaction coordinator epoch for $transactionalId has changed from ${txnMarker.coordinatorEpoch} to " +
+                s"${epochAndMetadata.coordinatorEpoch}; cancel sending transaction markers $txnMarker to the brokers")
+
+              txnMarkerChannelManager.removeMarkersForTxnId(transactionalId)
+              abortSending = true
+            } else {
+              txnMetadata synchronized {
+                for ((topicPartition: TopicPartition, error: Errors) <- errors) {
+                  error match {
+                    case Errors.NONE =>
+
+                      txnMetadata.removePartition(topicPartition)
+
+                    case Errors.CORRUPT_MESSAGE |
+                         Errors.MESSAGE_TOO_LARGE |
+                         Errors.RECORD_LIST_TOO_LARGE |
+                         Errors.INVALID_REQUIRED_ACKS => // these are all unexpected and fatal errors
+
+                      throw new IllegalStateException(s"Received fatal error ${error.exceptionName} while sending txn marker for $transactionalId")
+
+                    case Errors.UNKNOWN_TOPIC_OR_PARTITION |
+                         Errors.NOT_LEADER_FOR_PARTITION |
+                         Errors.NOT_ENOUGH_REPLICAS |
+                         Errors.NOT_ENOUGH_REPLICAS_AFTER_APPEND => // these are retriable errors
+
+                      info(s"Sending $transactionalId's transaction marker for partition $topicPartition has failed with error ${error.exceptionName}, retrying " +
+                        s"with current coordinator epoch ${epochAndMetadata.coordinatorEpoch}")
+
+                      retryPartitions += topicPartition
+
+                    case Errors.INVALID_PRODUCER_EPOCH |
+                         Errors.TRANSACTION_COORDINATOR_FENCED => // producer or coordinator epoch has changed, this txn can now be ignored
+
+                      info(s"Sending $transactionalId's transaction marker for partition $topicPartition has permanently failed with error ${error.exceptionName} " +
+                        s"with the current coordinator epoch ${epochAndMetadata.coordinatorEpoch}; cancel sending any more transaction markers $txnMarker to the brokers")
+
+                      txnMarkerChannelManager.removeMarkersForTxnId(transactionalId)
+                      abortSending = true
+
+                    case other =>
+                      throw new IllegalStateException(s"Unexpected error ${other.exceptionName} while sending txn marker for $transactionalId")
+                  }
+                }
+              }
+            }
+
+            if (!abortSending) {
+              if (retryPartitions.nonEmpty) {
+                // re-enqueue with possible new leaders of the partitions
+                txnMarkerChannelManager.addTxnMarkersToBrokerQueue(
+                  transactionalId,
+                  txnMarker.producerId(),
+                  txnMarker.producerEpoch(),
+                  txnMarker.transactionResult,
+                  txnMarker.coordinatorEpoch(),
+                  retryPartitions.toSet)
+              } else {
+                txnMarkerChannelManager.completeSendMarkersForTxnId(transactionalId)
               }
-            case Errors.UNKNOWN_TOPIC_OR_PARTITION | Errors.NOT_LEADER_FOR_PARTITION |
-                 Errors.NOT_ENOUGH_REPLICAS | Errors.NOT_ENOUGH_REPLICAS_AFTER_APPEND =>
-              retryPartitions += topicPartition
-            case _ =>
-              throw new IllegalStateException("Writing txn marker request failed permanently for pid " + txnMarker.producerId())
-          }
-
-          if (retryPartitions.nonEmpty) {
-            // re-enqueue with possible new leaders of the partitions
-            transactionMarkerChannel.addRequestToSend(
-              txnTopicPartition,
-              txnMarker.producerId(),
-              txnMarker.producerEpoch(),
-              txnMarker.transactionResult,
-              txnMarker.coordinatorEpoch(),
-              retryPartitions.toSet)
-          }
-          val completed = txnMarkerPurgatory.checkAndComplete(txnMarker.producerId())
-          trace(s"Competed $completed transactions for producerId ${txnMarker.producerId()}")
+            }
         }
       }
     }
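
The onComplete logic above amounts to a per-partition triage of the WriteTxnMarkers response errors:
NONE removes the partition from the transaction metadata, the four retriable errors re-enqueue the
partition, and the two epoch-related errors cancel any further marker sends for the transactional id.
Below is a compact standalone sketch of that classification (a hypothetical helper, not part of the
patch) using the same Errors constants:

import org.apache.kafka.common.TopicPartition
import org.apache.kafka.common.protocol.Errors

import scala.collection.mutable

object MarkerErrorTriage {
  // Returns (completed partitions, partitions to retry, whether to abort sending for this txn id).
  def triage(errors: Map[TopicPartition, Errors]): (Set[TopicPartition], Set[TopicPartition], Boolean) = {
    val completed = mutable.Set.empty[TopicPartition]
    val retry = mutable.Set.empty[TopicPartition]
    var abortSending = false

    errors.foreach { case (topicPartition, error) =>
      error match {
        case Errors.NONE =>
          completed += topicPartition

        case Errors.UNKNOWN_TOPIC_OR_PARTITION | Errors.NOT_LEADER_FOR_PARTITION |
             Errors.NOT_ENOUGH_REPLICAS | Errors.NOT_ENOUGH_REPLICAS_AFTER_APPEND =>
          retry += topicPartition

        case Errors.INVALID_PRODUCER_EPOCH | Errors.TRANSACTION_COORDINATOR_FENCED =>
          abortSending = true

        case other =>
          // fatal or unexpected errors are treated the same way as in the handler above
          throw new IllegalStateException(s"Unexpected error ${other.exceptionName} while sending txn marker")
      }
    }

    (completed.toSet, retry.toSet, abortSending)
  }
}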


[2/3] kafka git commit: KAFKA-5130: Refactor transaction coordinator's in-memory cache; plus fixes on transaction metadata synchronization

Posted by gu...@apache.org.
http://git-wip-us.apache.org/repos/asf/kafka/blob/794e6dbd/core/src/main/scala/kafka/coordinator/transaction/TransactionMetadata.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/kafka/coordinator/transaction/TransactionMetadata.scala b/core/src/main/scala/kafka/coordinator/transaction/TransactionMetadata.scala
index a81e47b..a76617e 100644
--- a/core/src/main/scala/kafka/coordinator/transaction/TransactionMetadata.scala
+++ b/core/src/main/scala/kafka/coordinator/transaction/TransactionMetadata.scala
@@ -19,9 +19,9 @@ package kafka.coordinator.transaction
 import kafka.utils.nonthreadsafe
 import org.apache.kafka.common.TopicPartition
 
-import scala.collection.mutable
+import scala.collection.{immutable, mutable}
 
-private[coordinator] sealed trait TransactionState { def byte: Byte }
+private[transaction] sealed trait TransactionState { def byte: Byte }
 
 /**
  * Transaction has not existed yet
@@ -29,7 +29,7 @@ private[coordinator] sealed trait TransactionState { def byte: Byte }
  * transition: received AddPartitionsToTxnRequest => Ongoing
  *             received AddOffsetsToTxnRequest => Ongoing
  */
-private[coordinator] case object Empty extends TransactionState { val byte: Byte = 0 }
+private[transaction] case object Empty extends TransactionState { val byte: Byte = 0 }
 
 /**
  * Transaction has started and ongoing
@@ -39,37 +39,37 @@ private[coordinator] case object Empty extends TransactionState { val byte: Byte
  *             received AddPartitionsToTxnRequest => Ongoing
  *             received AddOffsetsToTxnRequest => Ongoing
  */
-private[coordinator] case object Ongoing extends TransactionState { val byte: Byte = 1 }
+private[transaction] case object Ongoing extends TransactionState { val byte: Byte = 1 }
 
 /**
  * Group is preparing to commit
  *
  * transition: received acks from all partitions => CompleteCommit
  */
-private[coordinator] case object PrepareCommit extends TransactionState { val byte: Byte = 2}
+private[transaction] case object PrepareCommit extends TransactionState { val byte: Byte = 2}
 
 /**
  * Group is preparing to abort
  *
  * transition: received acks from all partitions => CompleteAbort
  */
-private[coordinator] case object PrepareAbort extends TransactionState { val byte: Byte = 3 }
+private[transaction] case object PrepareAbort extends TransactionState { val byte: Byte = 3 }
 
 /**
  * Group has completed commit
  *
  * Will soon be removed from the ongoing transaction cache
  */
-private[coordinator] case object CompleteCommit extends TransactionState { val byte: Byte = 4 }
+private[transaction] case object CompleteCommit extends TransactionState { val byte: Byte = 4 }
 
 /**
  * Group has completed abort
  *
  * Will soon be removed from the ongoing transaction cache
  */
-private[coordinator] case object CompleteAbort extends TransactionState { val byte: Byte = 5 }
+private[transaction] case object CompleteAbort extends TransactionState { val byte: Byte = 5 }
 
-private[coordinator] object TransactionMetadata {
+private[transaction] object TransactionMetadata {
   def apply(pid: Long, epoch: Short, txnTimeoutMs: Int, timestamp: Long) = new TransactionMetadata(pid, epoch, txnTimeoutMs, Empty, collection.mutable.Set.empty[TopicPartition], timestamp, timestamp)
 
   def apply(pid: Long, epoch: Short, txnTimeoutMs: Int, state: TransactionState, timestamp: Long) = new TransactionMetadata(pid, epoch, txnTimeoutMs, state, collection.mutable.Set.empty[TopicPartition], timestamp, timestamp)
@@ -89,7 +89,7 @@ private[coordinator] object TransactionMetadata {
   def isValidTransition(oldState: TransactionState, newState: TransactionState): Boolean = TransactionMetadata.validPreviousStates(newState).contains(oldState)
 
   private val validPreviousStates: Map[TransactionState, Set[TransactionState]] =
-    Map(Empty -> Set(),
+    Map(Empty -> Set(Empty, CompleteCommit, CompleteAbort),
       Ongoing -> Set(Ongoing, Empty, CompleteCommit, CompleteAbort),
       PrepareCommit -> Set(Ongoing),
       PrepareAbort -> Set(Ongoing),
@@ -97,24 +97,33 @@ private[coordinator] object TransactionMetadata {
       CompleteAbort -> Set(PrepareAbort))
 }
 
+// this is an immutable object representing the target transition of the transaction metadata
+private[transaction] case class TransactionMetadataTransition(producerId: Long,
+                                                              producerEpoch: Short,
+                                                              txnTimeoutMs: Int,
+                                                              txnState: TransactionState,
+                                                              topicPartitions: immutable.Set[TopicPartition],
+                                                              txnStartTimestamp: Long,
+                                                              txnLastUpdateTimestamp: Long)
+
 /**
   *
-  * @param pid                   producer id
+  * @param producerId            producer id
   * @param producerEpoch         current epoch of the producer
   * @param txnTimeoutMs          timeout to be used to abort long running transactions
-  * @param state                 the current state of the transaction
-  * @param topicPartitions       set of partitions that are part of this transaction
-  * @param transactionStartTime  time the transaction was started, i.e., when first partition is added
-  * @param lastUpdateTimestamp   updated when any operation updates the TransactionMetadata. To be used for expiration
+  * @param state                 current state of the transaction
+  * @param topicPartitions       current set of partitions that are part of this transaction
+  * @param txnStartTimestamp     time the transaction was started, i.e., when first partition is added
+  * @param txnLastUpdateTimestamp   updated when any operation updates the TransactionMetadata. To be used for expiration
   */
 @nonthreadsafe
-private[coordinator] class TransactionMetadata(val pid: Long,
+private[transaction] class TransactionMetadata(val producerId: Long,
                                                var producerEpoch: Short,
                                                var txnTimeoutMs: Int,
                                                var state: TransactionState,
                                                val topicPartitions: mutable.Set[TopicPartition],
-                                               var transactionStartTime: Long = -1,
-                                               var lastUpdateTimestamp: Long) {
+                                               var txnStartTimestamp: Long = -1,
+                                               var txnLastUpdateTimestamp: Long) {
 
   // pending state is used to indicate the state that this transaction is going to
   // transit to, and for blocking future attempts to transit it again if it is not legal;
@@ -125,50 +134,171 @@ private[coordinator] class TransactionMetadata(val pid: Long,
     topicPartitions ++= partitions
   }
 
-  def prepareTransitionTo(newState: TransactionState): Boolean = {
+  def removePartition(topicPartition: TopicPartition): Unit = {
+    if (pendingState.isDefined || (state != PrepareCommit && state != PrepareAbort))
+      throw new IllegalStateException(s"Transaction metadata's current state is $state, and its pending state is $pendingState " +
+        s"while trying to remove partitions whose txn marker has been sent; this is not expected")
+
+    topicPartitions -= topicPartition
+  }
+
+  def prepareNoTransit(): TransactionMetadataTransition =
+    // do not call transitTo as it will set the pending state
+    TransactionMetadataTransition(producerId, producerEpoch, txnTimeoutMs, state, topicPartitions.toSet, txnStartTimestamp, txnLastUpdateTimestamp)
+
+  def prepareIncrementProducerEpoch(newTxnTimeoutMs: Int,
+                                    updateTimestamp: Long): TransactionMetadataTransition = {
+
+    prepareTransitionTo(Empty, (producerEpoch + 1).toShort, newTxnTimeoutMs, immutable.Set.empty[TopicPartition], -1, updateTimestamp)
+  }
+
+  def prepareNewPid(updateTimestamp: Long): TransactionMetadataTransition = {
+
+    prepareTransitionTo(Empty, producerEpoch, txnTimeoutMs, immutable.Set.empty[TopicPartition], -1, updateTimestamp)
+  }
+
+  def prepareAddPartitions(addedTopicPartitions: immutable.Set[TopicPartition],
+                           updateTimestamp: Long): TransactionMetadataTransition = {
+
+    if (state == Empty || state == CompleteCommit || state == CompleteAbort) {
+      prepareTransitionTo(Ongoing, producerEpoch, txnTimeoutMs, (topicPartitions ++ addedTopicPartitions).toSet, updateTimestamp, updateTimestamp)
+    } else {
+      prepareTransitionTo(Ongoing, producerEpoch, txnTimeoutMs, (topicPartitions ++ addedTopicPartitions).toSet, txnStartTimestamp, updateTimestamp)
+    }
+  }
+
+  def prepareAbortOrCommit(newState: TransactionState,
+                           updateTimestamp: Long): TransactionMetadataTransition = {
+
+    prepareTransitionTo(newState, producerEpoch, txnTimeoutMs, topicPartitions.toSet, txnStartTimestamp, updateTimestamp)
+  }
+
+  def prepareComplete(updateTimestamp: Long): TransactionMetadataTransition = {
+    val newState = if (state == PrepareCommit) CompleteCommit else CompleteAbort
+    prepareTransitionTo(newState, producerEpoch, txnTimeoutMs, topicPartitions.toSet, txnStartTimestamp, updateTimestamp)
+  }
+
+  // visible for testing only
+  def copy(): TransactionMetadata = {
+    val cloned = new TransactionMetadata(producerId, producerEpoch, txnTimeoutMs, state,
+      mutable.Set.empty ++ topicPartitions.toSet, txnStartTimestamp, txnLastUpdateTimestamp)
+    cloned.pendingState = pendingState
+
+    cloned
+  }
+
+  private def prepareTransitionTo(newState: TransactionState,
+                                  newEpoch: Short,
+                                  newTxnTimeoutMs: Int,
+                                  newTopicPartitions: immutable.Set[TopicPartition],
+                                  newTxnStartTimestamp: Long,
+                                  updateTimestamp: Long): TransactionMetadataTransition = {
     if (pendingState.isDefined)
-      throw new IllegalStateException(s"Preparing transaction state transition to $newState while it already a pending state ${pendingState.get}")
+      throw new IllegalStateException(s"Preparing transaction state transition to $newState " +
+        s"while it already has a pending state ${pendingState.get}")
 
     // check that the new state transition is valid and update the pending state if necessary
     if (TransactionMetadata.validPreviousStates(newState).contains(state)) {
       pendingState = Some(newState)
-      true
+
+      TransactionMetadataTransition(producerId, newEpoch, newTxnTimeoutMs, newState, newTopicPartitions, newTxnStartTimestamp, updateTimestamp)
     } else {
-      false
+      throw new IllegalStateException(s"Preparing transaction state transition to $newState failed since the current state" +
+        s" $state is not a valid previous state of the target state $newState")
     }
   }
 
-  def completeTransitionTo(newState: TransactionState): Boolean = {
+  def completeTransitionTo(newMetadata: TransactionMetadataTransition): Unit = {
+    // metadata transition is valid only if all the following conditions are met:
+    //
+    // 1. the new state is already indicated in the pending state.
+    // 2. the pid is the same (i.e. this field should never be changed)
+    // 3. the epoch should be either the same value or old value + 1.
+    // 4. the last update time is no smaller than the old value.
+    // 5. the old partitions set is a subset of the new partitions set.
+    //
+    // plus, we should only try to update the metadata after the corresponding log entry has been successfully written and replicated (see TransactionStateManager#appendTransactionToLog)
+    //
+    // if valid, transition is done via overwriting the whole object to ensure synchronization
+
     val toState = pendingState.getOrElse(throw new IllegalStateException("Completing transaction state transition while it does not have a pending state"))
-    if (toState != newState) {
-      false
+
+    if (toState != newMetadata.txnState ||
+      producerId != newMetadata.producerId ||
+      txnLastUpdateTimestamp > newMetadata.txnLastUpdateTimestamp) {
+
+      throw new IllegalStateException("Completing transaction state transition failed due to unexpected metadata state")
     } else {
+      val updated = toState match {
+        case Empty => // from initPid
+          if (producerEpoch > newMetadata.producerEpoch ||
+            producerEpoch < newMetadata.producerEpoch - 1 ||
+            newMetadata.topicPartitions.nonEmpty ||
+            newMetadata.txnStartTimestamp != -1) {
+
+            throw new IllegalStateException("Completing transaction state transition failed due to unexpected metadata")
+          } else {
+            txnTimeoutMs = newMetadata.txnTimeoutMs
+            producerEpoch = newMetadata.producerEpoch
+          }
+
+        case Ongoing => // from addPartitions
+          if (producerEpoch != newMetadata.producerEpoch ||
+            !topicPartitions.subsetOf(newMetadata.topicPartitions) ||
+            txnTimeoutMs != newMetadata.txnTimeoutMs ||
+            txnStartTimestamp > newMetadata.txnStartTimestamp) {
+
+            throw new IllegalStateException("Completing transaction state transition failed due to unexpected metadata")
+          } else {
+            txnStartTimestamp = newMetadata.txnStartTimestamp
+            addPartitions(newMetadata.topicPartitions)
+          }
+
+        case PrepareAbort | PrepareCommit => // from endTxn
+          if (producerEpoch != newMetadata.producerEpoch ||
+            !topicPartitions.toSet.equals(newMetadata.topicPartitions) ||
+            txnTimeoutMs != newMetadata.txnTimeoutMs ||
+            txnStartTimestamp != newMetadata.txnStartTimestamp) {
+
+            throw new IllegalStateException("Completing transaction state transition failed due to unexpected metadata")
+          }
+
+        case CompleteAbort | CompleteCommit => // from write markers
+          if (producerEpoch != newMetadata.producerEpoch ||
+            txnTimeoutMs != newMetadata.txnTimeoutMs ||
+            newMetadata.txnStartTimestamp == -1) {
+
+            throw new IllegalStateException("Completing transaction state transition failed due to unexpected metadata")
+          } else {
+            txnStartTimestamp = newMetadata.txnStartTimestamp
+            topicPartitions.clear()
+          }
+      }
+
+      txnLastUpdateTimestamp = newMetadata.txnLastUpdateTimestamp
       pendingState = None
       state = toState
-      true
     }
   }
 
-  def copy(): TransactionMetadata =
-    new TransactionMetadata(pid, producerEpoch, txnTimeoutMs, state, collection.mutable.Set.empty[TopicPartition] ++ topicPartitions, transactionStartTime, lastUpdateTimestamp)
+  def pendingTransitionInProgress: Boolean = pendingState.isDefined
 
-  override def toString = s"TransactionMetadata($pendingState, $pid, $producerEpoch, $txnTimeoutMs, $state, $topicPartitions, $transactionStartTime, $lastUpdateTimestamp)"
+  override def toString = s"TransactionMetadata($pendingState, $producerId, $producerEpoch, $txnTimeoutMs, $state, $topicPartitions, $txnStartTimestamp, $txnLastUpdateTimestamp)"
 
   override def equals(that: Any): Boolean = that match {
     case other: TransactionMetadata =>
-      pid == other.pid &&
+      producerId == other.producerId &&
       producerEpoch == other.producerEpoch &&
       txnTimeoutMs == other.txnTimeoutMs &&
       state.equals(other.state) &&
       topicPartitions.equals(other.topicPartitions) &&
-      transactionStartTime == other.transactionStartTime &&
-      lastUpdateTimestamp == other.lastUpdateTimestamp
+      txnStartTimestamp == other.txnStartTimestamp &&
+      txnLastUpdateTimestamp == other.txnLastUpdateTimestamp
     case _ => false
   }
 
-
   override def hashCode(): Int = {
-    val state = Seq(pid, txnTimeoutMs, topicPartitions, lastUpdateTimestamp)
-    state.map(_.hashCode()).foldLeft(0)((a, b) => 31 * a + b)
+    val fields = Seq(producerId, producerEpoch, txnTimeoutMs, state, topicPartitions, txnStartTimestamp, txnLastUpdateTimestamp)
+    fields.map(_.hashCode()).foldLeft(0)((a, b) => 31 * a + b)
   }
 }
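
The validPreviousStates table above (including the new Empty -> {Empty, CompleteCommit, CompleteAbort}
entry) fully determines which transitions prepareTransitionTo accepts: a move to a target state is legal
only when the current state is listed as a valid previous state of that target. A minimal standalone
sketch of the same check, with the state names and table copied from the patch:

object TxnStateMachineSketch extends App {
  sealed trait State
  case object Empty extends State
  case object Ongoing extends State
  case object PrepareCommit extends State
  case object PrepareAbort extends State
  case object CompleteCommit extends State
  case object CompleteAbort extends State

  // Mirrors TransactionMetadata.validPreviousStates in this patch.
  val validPreviousStates: Map[State, Set[State]] = Map(
    Empty          -> Set(Empty, CompleteCommit, CompleteAbort),
    Ongoing        -> Set(Ongoing, Empty, CompleteCommit, CompleteAbort),
    PrepareCommit  -> Set(Ongoing),
    PrepareAbort   -> Set(Ongoing),
    CompleteCommit -> Set(PrepareCommit),
    CompleteAbort  -> Set(PrepareAbort))

  def canTransit(current: State, target: State): Boolean =
    validPreviousStates(target).contains(current)

  println(canTransit(Ongoing, PrepareCommit))       // true
  println(canTransit(CompleteAbort, PrepareCommit)) // false: the txn must become Ongoing again first
}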

http://git-wip-us.apache.org/repos/asf/kafka/blob/794e6dbd/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala b/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala
index f5dc3c0..7a03fc3 100644
--- a/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala
+++ b/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala
@@ -33,17 +33,19 @@ import org.apache.kafka.common.protocol.Errors
 import org.apache.kafka.common.record.{FileRecords, MemoryRecords, SimpleRecord}
 import org.apache.kafka.common.requests.IsolationLevel
 import org.apache.kafka.common.requests.ProduceResponse.PartitionResponse
+import org.apache.kafka.common.requests.TransactionResult
 import org.apache.kafka.common.utils.{Time, Utils}
 
 import scala.collection.mutable
 import scala.collection.JavaConverters._
 
 
-object TransactionManager {
+object TransactionStateManager {
   // default transaction management config values
-  val DefaultTransactionalIdExpirationMs = TimeUnit.DAYS.toMillis(7).toInt
-  val DefaultTransactionsMaxTimeoutMs = TimeUnit.MINUTES.toMillis(15).toInt
-  val DefaultRemoveExpiredTransactionsIntervalMs = TimeUnit.MINUTES.toMillis(1).toInt
+  // TODO: this needs to be replaced by the config values
+  val DefaultTransactionsMaxTimeoutMs: Int = TimeUnit.MINUTES.toMillis(15).toInt
+  val DefaultTransactionalIdExpirationMs: Int = TimeUnit.DAYS.toMillis(7).toInt
+  val DefaultRemoveExpiredTransactionsIntervalMs: Int = TimeUnit.MINUTES.toMillis(1).toInt
 }
 
 /**
@@ -62,7 +64,7 @@ class TransactionStateManager(brokerId: Int,
 
   this.logIdent = "[Transaction Log Manager " + brokerId + "]: "
 
-  type WriteTxnMarkers = WriteTxnMarkerArgs => Unit
+  type SendTxnMarkersCallback = (String, Int, TransactionResult, TransactionMetadata, TransactionMetadataTransition) => Unit
 
   /** shutting down flag */
   private val shuttingDown = new AtomicBoolean(false)
@@ -70,40 +72,72 @@ class TransactionStateManager(brokerId: Int,
   /** lock protecting access to loading and owned partition sets */
   private val stateLock = new ReentrantLock()
 
-  /** partitions of transaction topic that are assigned to this manager, partition lock should be called BEFORE accessing this set */
-  private val ownedPartitions: mutable.Map[Int, Int] = mutable.Map()
-
   /** partitions of transaction topic that are being loaded, partition lock should be called BEFORE accessing this set */
   private val loadingPartitions: mutable.Set[Int] = mutable.Set()
 
-  /** transaction metadata cache indexed by transactional id */
-  private val transactionMetadataCache = new Pool[String, TransactionMetadata]
+  /** transaction metadata cache indexed by assigned transaction topic partition ids */
+  private val transactionMetadataCache: mutable.Map[Int, TxnMetadataCacheEntry] = mutable.Map()
 
   /** number of partitions for the transaction log topic */
   private val transactionTopicPartitionCount = getTransactionTopicPartitionCount
 
+  // this is best-effort expiration and hence we do not grab the lock on the metadata when checking its state;
+  // we will get the lock when actually trying to transition the transaction metadata to abort later.
+  def transactionsToExpire(): Iterable[TransactionalIdAndProducerIdEpoch] = {
+    val now = time.milliseconds()
+    transactionMetadataCache.flatMap { case (_, entry) =>
+        entry.metadataPerTransactionalId.filter { case (txnId, txnMetadata) =>
+          if (isCoordinatorLoadingInProgress(txnId) || txnMetadata.pendingTransitionInProgress) {
+            false
+          } else {
+            txnMetadata.state match {
+              case Ongoing =>
+                txnMetadata.txnStartTimestamp + txnMetadata.txnTimeoutMs < now
+              case _ => false
+            }
+          }
+        }.map { case (txnId, txnMetadata) =>
+          TransactionalIdAndProducerIdEpoch(txnId, txnMetadata.producerId, txnMetadata.producerEpoch)
+        }
+    }
+  }
+
   def enablePidExpiration() {
-    if (!scheduler.isStarted)
-      scheduler.startup()
     // TODO: add pid expiration logic
   }
 
   /**
    * Get the transaction metadata associated with the given transactional id, or null if not found
    */
-  def getTransactionState(transactionalId: String): Option[TransactionMetadata] = {
-    Option(transactionMetadataCache.get(transactionalId))
+  def getTransactionState(transactionalId: String): Option[CoordinatorEpochAndTxnMetadata] = {
+    val partitionId = partitionFor(transactionalId)
+
+    transactionMetadataCache.get(partitionId).flatMap { cacheEntry =>
+      cacheEntry.metadataPerTransactionalId.get(transactionalId) match {
+        case null => None
+        case txnMetadata => Some(CoordinatorEpochAndTxnMetadata(cacheEntry.coordinatorEpoch, txnMetadata))
+      }
+    }
   }
 
   /**
    * Add a new transaction metadata, or retrieve the metadata if it already exists with the associated transactional id
+   * along with the current coordinator epoch for that belonging transaction topic partition
    */
-  def addTransaction(transactionalId: String, txnMetadata: TransactionMetadata): TransactionMetadata = {
-    val currentTxnMetadata = transactionMetadataCache.putIfNotExists(transactionalId, txnMetadata)
-    if (currentTxnMetadata != null) {
-      currentTxnMetadata
-    } else {
-      txnMetadata
+  def addTransaction(transactionalId: String, txnMetadata: TransactionMetadata): CoordinatorEpochAndTxnMetadata = {
+    val partitionId = partitionFor(transactionalId)
+
+    transactionMetadataCache.get(partitionId) match {
+      case Some(txnMetadataCacheEntry) =>
+        val currentTxnMetadata = txnMetadataCacheEntry.metadataPerTransactionalId.putIfNotExists(transactionalId, txnMetadata)
+        if (currentTxnMetadata != null) {
+          CoordinatorEpochAndTxnMetadata(txnMetadataCacheEntry.coordinatorEpoch, currentTxnMetadata)
+        } else {
+          CoordinatorEpochAndTxnMetadata(txnMetadataCacheEntry.coordinatorEpoch, txnMetadata)
+        }
+
+      case None =>
+        throw new IllegalStateException(s"The metadata cache entry for txn partition $partitionId does not exist.")
     }
   }
 
@@ -129,13 +163,13 @@ class TransactionStateManager(brokerId: Int,
 
   def partitionFor(transactionalId: String): Int = Utils.abs(transactionalId.hashCode) % transactionTopicPartitionCount
 
-  def coordinatorEpochFor(transactionId: String): Option[Int] = inLock (stateLock) {
-    ownedPartitions.get(partitionFor(transactionId))
+  def isCoordinatorFor(txnTopicPartitionId: Int): Boolean = inLock(stateLock) {
+    transactionMetadataCache.contains(txnTopicPartitionId)
   }
 
   def isCoordinatorFor(transactionalId: String): Boolean = inLock(stateLock) {
     val partitionId = partitionFor(transactionalId)
-    ownedPartitions.contains(partitionId)
+    transactionMetadataCache.contains(partitionId)
   }
 
   def isCoordinatorLoadingInProgress(transactionalId: String): Boolean = inLock(stateLock) {
@@ -143,19 +177,6 @@ class TransactionStateManager(brokerId: Int,
     loadingPartitions.contains(partitionId)
   }
 
-
-  def transactionsToExpire(): Iterable[TransactionalIdAndMetadata] = {
-    val now = time.milliseconds()
-    transactionMetadataCache.filter { case (_, metadata) =>
-      metadata.state match {
-        case Ongoing =>
-          metadata.transactionStartTime + metadata.txnTimeoutMs < now
-        case _ => false
-      }
-    }.map {case (id, metadata) =>
-      TransactionalIdAndMetadata(id, metadata)
-    }
-  }
   /**
    * Gets the partition count of the transaction log topic from ZooKeeper.
    * If the topic does not exist, the default partition count is returned.
@@ -164,162 +185,159 @@ class TransactionStateManager(brokerId: Int,
     zkUtils.getTopicPartitionCount(Topic.TransactionStateTopicName).getOrElse(config.transactionLogNumPartitions)
   }
 
-  private def loadTransactionMetadata(topicPartition: TopicPartition, writeTxnMarkers: WriteTxnMarkers) {
-    def highWaterMark = replicaManager.getLogEndOffset(topicPartition).getOrElse(-1L)
+  private def loadTransactionMetadata(topicPartition: TopicPartition, coordinatorEpoch: Int): Pool[String, TransactionMetadata] =  {
+    def logEndOffset = replicaManager.getLogEndOffset(topicPartition).getOrElse(-1L)
 
     val startMs = time.milliseconds()
+    val loadedTransactions = new Pool[String, TransactionMetadata]
+
     replicaManager.getLog(topicPartition) match {
       case None =>
         warn(s"Attempted to load offsets and group metadata from $topicPartition, but found no log")
 
       case Some(log) =>
         lazy val buffer = ByteBuffer.allocate(config.transactionLogLoadBufferSize)
-        val loadedTransactions = mutable.Map.empty[String, TransactionMetadata]
-        val removedTransactionalIds = mutable.Set.empty[String]
 
         // loop breaks if leader changes at any time during the load, since getHighWatermark is -1
         var currOffset = log.logStartOffset
-        while (currOffset < highWaterMark
-                && loadingPartitions.contains(topicPartition.partition())
-                && !shuttingDown.get()) {
-          buffer.clear()
-          val fetchDataInfo = log.read(currOffset, config.transactionLogLoadBufferSize, maxOffset = None,
-            minOneMessage = true, isolationLevel = IsolationLevel.READ_UNCOMMITTED)
-          val memRecords = fetchDataInfo.records match {
-            case records: MemoryRecords => records
-            case fileRecords: FileRecords =>
-              buffer.clear()
-              val bufferRead = fileRecords.readInto(buffer, 0)
-              MemoryRecords.readableRecords(bufferRead)
-          }
-
-          memRecords.batches.asScala.foreach { batch =>
-            for (record <- batch.asScala) {
-              require(record.hasKey, "Transaction state log's key should not be null")
-              TransactionLog.readMessageKey(record.key) match {
-
-                case txnKey: TxnKey =>
-                  // load transaction metadata along with transaction state
-                  val transactionalId: String = txnKey.transactionalId
-                  if (!record.hasValue) {
-                    loadedTransactions.remove(transactionalId)
-                    removedTransactionalIds.add(transactionalId)
-                  } else {
-                    val txnMetadata = TransactionLog.readMessageValue(record.value)
-                    loadedTransactions.put(transactionalId, txnMetadata)
-                    removedTransactionalIds.remove(transactionalId)
-                  }
-
-                case unknownKey =>
-                  // TODO: Metrics
-                  throw new IllegalStateException(s"Unexpected message key $unknownKey while loading offsets and group metadata")
-              }
 
-              currOffset = batch.nextOffset
+        try {
+          while (currOffset < logEndOffset
+            && loadingPartitions.contains(topicPartition.partition())
+            && !shuttingDown.get()) {
+            val fetchDataInfo = log.read(currOffset, config.transactionLogLoadBufferSize, maxOffset = None,
+              minOneMessage = true, isolationLevel = IsolationLevel.READ_UNCOMMITTED)
+            val memRecords = fetchDataInfo.records match {
+              case records: MemoryRecords => records
+              case fileRecords: FileRecords =>
+                buffer.clear()
+                val bufferRead = fileRecords.readInto(buffer, 0)
+                MemoryRecords.readableRecords(bufferRead)
             }
-          }
 
-          loadedTransactions.foreach {
-            case (transactionalId, txnMetadata) =>
-              val currentTxnMetadata = addTransaction(transactionalId, txnMetadata)
-              if (!txnMetadata.eq(currentTxnMetadata)) {
-                // treat this as a fatal failure as this should never happen
-                fatal(s"Attempt to load $transactionalId's metadata $txnMetadata failed " +
-                  s"because there is already a different cached transaction metadata $currentTxnMetadata.")
+            memRecords.batches.asScala.foreach { batch =>
+              for (record <- batch.asScala) {
+                require(record.hasKey, "Transaction state log's key should not be null")
+                TransactionLog.readMessageKey(record.key) match {
+
+                  case txnKey: TxnKey =>
+                    // load transaction metadata along with transaction state
+                    val transactionalId: String = txnKey.transactionalId
+                    if (!record.hasValue) {
+                      loadedTransactions.remove(transactionalId)
+                    } else {
+                      val txnMetadata = TransactionLog.readMessageValue(record.value)
+                      loadedTransactions.put(transactionalId, txnMetadata)
+                    }
+
+                  case unknownKey =>
+                    // TODO: Metrics
+                    throw new IllegalStateException(s"Unexpected message key $unknownKey while loading transaction metadata")
+                }
 
-                throw new KafkaException("Loading transaction topic partition failed.")
-              }
-              // if state is PrepareCommit or PrepareAbort we need to complete the transaction
-              if (currentTxnMetadata.state == PrepareCommit || currentTxnMetadata.state == PrepareAbort) {
-                writeTxnMarkers(WriteTxnMarkerArgs(transactionalId,
-                  txnMetadata.pid,
-                  txnMetadata.producerEpoch,
-                  txnMetadata.state,
-                  txnMetadata,
-                  coordinatorEpochFor(transactionalId).get
-                ))
+                currOffset = batch.nextOffset
               }
-          }
-
-          removedTransactionalIds.foreach { transactionalId =>
-            if (transactionMetadataCache.contains(transactionalId)) {
-              // the cache already contains a transaction which should be removed,
-              // treat this as a fatal failure as this should never happen
-              fatal(s"Unexpected to see $transactionalId's metadata while " +
-                s"loading partition $topicPartition since its latest state is a tombstone")
-
-              throw new KafkaException("Loading transaction topic partition failed.")
             }
-          }
 
-          info(s"Finished loading ${loadedTransactions.size} transaction metadata from $topicPartition in ${time.milliseconds() - startMs} milliseconds")
+            info(s"Finished loading ${loadedTransactions.size} transaction metadata from $topicPartition in ${time.milliseconds() - startMs} milliseconds")
+          }
+        } catch {
+          case t: Throwable => error(s"Error loading transactions from transaction log $topicPartition", t)
         }
     }
+
+    loadedTransactions
+  }
+
+  /**
+    * Add a transaction topic partition into the cache
+    */
+  def addLoadedTransactionsToCache(txnTopicPartition: Int, coordinatorEpoch: Int, metadataPerTransactionalId: Pool[String, TransactionMetadata]): Unit = {
+    val txnMetadataCacheEntry = TxnMetadataCacheEntry(coordinatorEpoch, metadataPerTransactionalId)
+    val currentTxnMetadataCacheEntry = transactionMetadataCache.put(txnTopicPartition, txnMetadataCacheEntry)
+
+    if (currentTxnMetadataCacheEntry.isDefined) {
+      val coordinatorEpoch = currentTxnMetadataCacheEntry.get.coordinatorEpoch
+      val metadataPerTxnId = currentTxnMetadataCacheEntry.get.metadataPerTransactionalId
+      info(s"The metadata cache for txn partition $txnTopicPartition already exists with epoch $coordinatorEpoch " +
+        s"and ${metadataPerTxnId.size} entries while trying to add to it; " +
+        s"it is likely that another process for loading from the transaction log has just executed before this one")
+
+      throw new IllegalStateException(s"The metadata cache entry for txn partition $txnTopicPartition already exists while trying to add to it.")
+    }
   }
 
   /**
    * When this broker becomes a leader for a transaction log partition, load this partition and
    * populate the transaction metadata cache with the transactional ids.
    */
-  def loadTransactionsForPartition(partition: Int, coordinatorEpoch: Int, writeTxnMarkers: WriteTxnMarkers) {
+  def loadTransactionsForTxnTopicPartition(partitionId: Int, coordinatorEpoch: Int, sendTxnMarkers: SendTxnMarkersCallback) {
     validateTransactionTopicPartitionCountIsStable()
 
-    val topicPartition = new TopicPartition(Topic.TransactionStateTopicName, partition)
+    val topicPartition = new TopicPartition(Topic.TransactionStateTopicName, partitionId)
 
     inLock(stateLock) {
-      ownedPartitions.put(partition, coordinatorEpoch)
-      loadingPartitions.add(partition)
+      loadingPartitions.add(partitionId)
     }
 
     def loadTransactions() {
       info(s"Loading transaction metadata from $topicPartition")
-      try {
-        loadTransactionMetadata(topicPartition, writeTxnMarkers)
-      } catch {
-        case t: Throwable => error(s"Error loading transactions from transaction log $topicPartition", t)
-      } finally {
-        inLock(stateLock) {
-          loadingPartitions.remove(partition)
-        }
+      val loadedTransactions = loadTransactionMetadata(topicPartition, coordinatorEpoch)
+
+      loadedTransactions.foreach {
+        case (transactionalId, txnMetadata) =>
+          val result = txnMetadata synchronized {
+            // if state is PrepareCommit or PrepareAbort we need to complete the transaction
+            txnMetadata.state match {
+              case PrepareAbort =>
+                Some(TransactionResult.ABORT, txnMetadata.prepareComplete(time.milliseconds()))
+              case PrepareCommit =>
+                Some(TransactionResult.COMMIT, txnMetadata.prepareComplete(time.milliseconds()))
+              case _ =>
+                // nothing needs to be done
+                None
+            }
+          }
+
+          result.foreach { case (command, newMetadata) =>
+            sendTxnMarkers(transactionalId, coordinatorEpoch, command, txnMetadata, newMetadata)
+          }
+      }
+
+      inLock(stateLock) {
+        addLoadedTransactionsToCache(topicPartition.partition, coordinatorEpoch, loadedTransactions)
+        loadingPartitions.remove(partitionId)
       }
     }
 
-    scheduler.schedule(topicPartition.toString, loadTransactions _)
+    scheduler.schedule(s"load-txns-for-partition-$topicPartition", loadTransactions _)
   }
 
   /**
    * When this broker becomes a follower for a transaction log partition, clear out the cache for corresponding transactional ids
    * that belong to that partition.
    */
-  def removeTransactionsForPartition(partition: Int) {
+  def removeTransactionsForTxnTopicPartition(partitionId: Int) {
     validateTransactionTopicPartitionCountIsStable()
 
-    val topicPartition = new TopicPartition(Topic.TransactionStateTopicName, partition)
-
-    inLock(stateLock) {
-      ownedPartitions.remove(partition)
-      loadingPartitions.remove(partition)
-    }
+    val topicPartition = new TopicPartition(Topic.TransactionStateTopicName, partitionId)
 
     def removeTransactions() {
-      var numTxnsRemoved = 0
-
       inLock(stateLock) {
-        for (transactionalId <- transactionMetadataCache.keys) {
-          if (partitionFor(transactionalId) == partition) {
-            // we do not need to worry about whether the transactional id has any ongoing transaction or not since
-            // the new leader will handle it
-            transactionMetadataCache.remove(transactionalId)
-            numTxnsRemoved += 1
-          }
+        transactionMetadataCache.remove(partitionId) match {
+          case Some(txnMetadataCacheEntry) =>
+            info(s"Removed ${txnMetadataCacheEntry.metadataPerTransactionalId.size} cached transaction metadata for $topicPartition on follower transition")
+
+          case None =>
+            info(s"Trying to remove cached transaction metadata for $topicPartition on follower transition but there are no entries remaining; " +
+              s"it is likely that another process removing the cached entries has just executed before this one")
         }
 
-        if (numTxnsRemoved > 0)
-          info(s"Removed $numTxnsRemoved cached transaction metadata for $topicPartition on follower transition")
+        loadingPartitions.remove(partitionId)
       }
     }
 
-    scheduler.schedule(topicPartition.toString, removeTransactions _)
+    scheduler.schedule(s"remove-txns-for-partition-$topicPartition", removeTransactions _)
   }
 
   private def validateTransactionTopicPartitionCountIsStable(): Unit = {
@@ -330,12 +348,13 @@ class TransactionStateManager(brokerId: Int,
 
   // TODO: check broker message format and error if < V2
   def appendTransactionToLog(transactionalId: String,
-                             txnMetadata: TransactionMetadata,
-                             responseCallback: Errors => Unit) {
+                             coordinatorEpoch: Int,
+                             newMetadata: TransactionMetadataTransition,
+                             responseCallback: Errors => Unit): Unit = {
 
     // generate the message for this transaction metadata
     val keyBytes = TransactionLog.keyToBytes(transactionalId)
-    val valueBytes = TransactionLog.valueToBytes(txnMetadata)
+    val valueBytes = TransactionLog.valueToBytes(newMetadata)
     val timestamp = time.milliseconds()
 
     val records = MemoryRecords.withRecords(TransactionLog.EnforcedCompressionType, new SimpleRecord(timestamp, keyBytes, valueBytes))
@@ -355,7 +374,7 @@ class TransactionStateManager(brokerId: Int,
       var responseError = if (status.error == Errors.NONE) {
         Errors.NONE
       } else {
-        debug(s"Transaction state update $txnMetadata for $transactionalId failed when appending to log " +
+        debug(s"Transaction state update $newMetadata for $transactionalId failed when appending to log " +
           s"due to ${status.error.exceptionName}")
 
         // transform the log append error code to the corresponding coordinator error code
@@ -365,14 +384,14 @@ class TransactionStateManager(brokerId: Int,
                | Errors.NOT_ENOUGH_REPLICAS_AFTER_APPEND
                | Errors.REQUEST_TIMED_OUT => // note that for timed out request we return NOT_AVAILABLE error code to let client retry
 
-            debug(s"Appending transaction message $txnMetadata for $transactionalId failed due to " +
+            info(s"Appending transaction message $newMetadata for $transactionalId failed due to " +
               s"${status.error.exceptionName}, returning ${Errors.COORDINATOR_NOT_AVAILABLE} to the client")
 
             Errors.COORDINATOR_NOT_AVAILABLE
 
           case Errors.NOT_LEADER_FOR_PARTITION =>
 
-            debug(s"Appending transaction message $txnMetadata for $transactionalId failed due to " +
+            info(s"Appending transaction message $newMetadata for $transactionalId failed due to " +
               s"${status.error.exceptionName}, returning ${Errors.NOT_COORDINATOR} to the client")
 
             Errors.NOT_COORDINATOR
@@ -380,13 +399,13 @@ class TransactionStateManager(brokerId: Int,
           case Errors.MESSAGE_TOO_LARGE
                | Errors.RECORD_LIST_TOO_LARGE =>
 
-            error(s"Appending transaction message $txnMetadata for $transactionalId failed due to " +
+            error(s"Appending transaction message $newMetadata for $transactionalId failed due to " +
               s"${status.error.exceptionName}, returning UNKNOWN error code to the client")
 
             Errors.UNKNOWN
 
           case other =>
-            error(s"Appending metadata message $txnMetadata for $transactionalId failed due to " +
+            error(s"Appending metadata message $newMetadata for $transactionalId failed due to " +
               s"unexpected error: ${status.error.message}")
 
             other
@@ -394,44 +413,42 @@ class TransactionStateManager(brokerId: Int,
       }
 
       if (responseError == Errors.NONE) {
-        def completeStateTransition(metadata: TransactionMetadata, newState: TransactionState): Boolean = {
-          // there is no transition in this case
-          if (metadata.state == Empty && newState == Empty)
-            true
-          else
-            metadata.completeTransitionTo(txnMetadata.state)
-        }
         // now try to update the cache: we need to update the status in-place instead of
         // overwriting the whole object to ensure synchronization
-          getTransactionState(transactionalId) match {
-            case Some(metadata) =>
-              metadata synchronized {
-                if (metadata.pid == txnMetadata.pid &&
-                  metadata.producerEpoch == txnMetadata.producerEpoch &&
-                  metadata.txnTimeoutMs == txnMetadata.txnTimeoutMs &&
-                  completeStateTransition(metadata, txnMetadata.state)) {
-                  // only topic-partition lists could possibly change (state should have transited in the above condition)
-                  metadata.addPartitions(txnMetadata.topicPartitions.toSet)
-                } else {
-                  throw new IllegalStateException(s"Completing transaction state transition to $txnMetadata while its current state is $metadata.")
-                }
+        getTransactionState(transactionalId) match {
+          case Some(epochAndMetadata) =>
+            val metadata = epochAndMetadata.transactionMetadata
+
+            metadata synchronized {
+              if (epochAndMetadata.coordinatorEpoch != coordinatorEpoch) {
+                // the cache may have been changed due to txn topic partition emigration and immigration,
+                // in this case directly return NOT_COORDINATOR to the client and let it re-discover the transaction coordinator
+                info(s"Updating $transactionalId's transaction state to $newMetadata with coordinator epoch $coordinatorEpoch failed after the transaction message " +
+                  s"has been appended to the log. The cached coordinator epoch has changed to ${epochAndMetadata.coordinatorEpoch}")
+
+                responseError = Errors.NOT_COORDINATOR
+              } else {
+                metadata.completeTransitionTo(newMetadata)
+
+                debug(s"Updating $transactionalId's transaction state to $newMetadata with coordinator epoch $coordinatorEpoch succeeded")
               }
+            }
 
-            case None =>
-              // this transactional id no longer exists, maybe the corresponding partition has already been migrated out.
-              // return NOT_COORDINATOR to let the client retry
-              debug(s"Updating $transactionalId's transaction state to $txnMetadata for $transactionalId failed after the transaction message " +
-                s"has been appended to the log. The partition for $transactionalId may have migrated as the metadata is no longer in the cache")
+          case None =>
+            // this transactional id no longer exists, maybe the corresponding partition has already been migrated out.
+            // return NOT_COORDINATOR to let the client re-discover the transaction coordinator
+            info(s"Updating $transactionalId's transaction state to $newMetadata with coordinator epoch $coordinatorEpoch failed after the transaction message " +
+              s"has been appended to the log. The partition ${partitionFor(transactionalId)} may have migrated as the metadata is no longer in the cache")
 
-              responseError = Errors.NOT_COORDINATOR
-          }
+            responseError = Errors.NOT_COORDINATOR
+        }
       }
 
       responseCallback(responseError)
     }
 
     replicaManager.appendRecords(
-      txnMetadata.txnTimeoutMs.toLong,
+      newMetadata.txnTimeoutMs.toLong,
       TransactionLog.EnforcedRequiredAcks,
       internalTopicsAllowed = true,
       isFromClient = false,
@@ -441,25 +458,25 @@ class TransactionStateManager(brokerId: Int,
 
   def shutdown() {
     shuttingDown.set(true)
-    if (scheduler.isStarted)
-      scheduler.shutdown()
-
-    transactionMetadataCache.clear()
-
-    ownedPartitions.clear()
     loadingPartitions.clear()
+    transactionMetadataCache.clear()
 
     info("Shutdown complete")
   }
 }
 
-private[transaction] case class TransactionConfig(transactionalIdExpirationMs: Int = TransactionManager.DefaultTransactionalIdExpirationMs,
-                                                  transactionMaxTimeoutMs: Int = TransactionManager.DefaultTransactionsMaxTimeoutMs,
+
+private[transaction] case class TxnMetadataCacheEntry(coordinatorEpoch: Int, metadataPerTransactionalId: Pool[String, TransactionMetadata])
+
+private[transaction] case class CoordinatorEpochAndTxnMetadata(coordinatorEpoch: Int, transactionMetadata: TransactionMetadata)
+
+private[transaction] case class TransactionConfig(transactionalIdExpirationMs: Int = TransactionStateManager.DefaultTransactionalIdExpirationMs,
+                                                  transactionMaxTimeoutMs: Int = TransactionStateManager.DefaultTransactionsMaxTimeoutMs,
                                                   transactionLogNumPartitions: Int = TransactionLog.DefaultNumPartitions,
                                                   transactionLogReplicationFactor: Short = TransactionLog.DefaultReplicationFactor,
                                                   transactionLogSegmentBytes: Int = TransactionLog.DefaultSegmentBytes,
                                                   transactionLogLoadBufferSize: Int = TransactionLog.DefaultLoadBufferSize,
                                                   transactionLogMinInsyncReplicas: Int = TransactionLog.DefaultMinInSyncReplicas,
-                                                  removeExpiredTransactionsIntervalMs: Int = TransactionManager.DefaultRemoveExpiredTransactionsIntervalMs)
+                                                  removeExpiredTransactionsIntervalMs: Int = TransactionStateManager.DefaultRemoveExpiredTransactionsIntervalMs)
 
-case class TransactionalIdAndMetadata(transactionalId: String, metadata: TransactionMetadata)
+case class TransactionalIdAndProducerIdEpoch(transactionalId: String, producerId: Long, producerEpoch: Short)

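For orientation, here is a minimal sketch (not part of the patch) of the cache layout this change moves to: the metadata cache is keyed by transaction-log partition, each entry carrying the coordinator epoch for that partition plus a per-transactional-id map, and lookups return the epoch together with the metadata. kafka.utils.Pool is approximated with scala.collection.mutable.Map and TransactionMetadata is reduced to a hypothetical stub; all "*Sketch" names below are stand-ins, not the real classes.

import scala.collection.mutable

// hypothetical stand-in for TransactionMetadata; the real class also tracks state,
// topic partitions, timeouts and timestamps as shown in the diff above
case class TxnMetadataStub(producerId: Long, producerEpoch: Short)

case class TxnMetadataCacheEntrySketch(coordinatorEpoch: Int,
                                       metadataPerTransactionalId: mutable.Map[String, TxnMetadataStub])

case class CoordinatorEpochAndTxnMetadataSketch(coordinatorEpoch: Int, transactionMetadata: TxnMetadataStub)

object TxnCacheSketch {
  // transaction-log partition id -> everything loaded for that partition
  private val transactionMetadataCache = mutable.Map.empty[Int, TxnMetadataCacheEntrySketch]

  // immigration: install the freshly loaded partition entry exactly once
  def addLoadedTransactionsToCache(partitionId: Int, coordinatorEpoch: Int,
                                   loaded: mutable.Map[String, TxnMetadataStub]): Unit =
    if (transactionMetadataCache.put(partitionId, TxnMetadataCacheEntrySketch(coordinatorEpoch, loaded)).isDefined)
      throw new IllegalStateException(s"cache entry for txn partition $partitionId already exists")

  // emigration: drop one partition entry instead of scanning every transactional id
  def removeTransactionsForTxnTopicPartition(partitionId: Int): Unit =
    transactionMetadataCache.remove(partitionId)

  // lookups now return the coordinator epoch together with the metadata
  def getTransactionState(transactionalId: String, partitionFor: String => Int): Option[CoordinatorEpochAndTxnMetadataSketch] =
    transactionMetadataCache.get(partitionFor(transactionalId)).flatMap { entry =>
      entry.metadataPerTransactionalId.get(transactionalId)
        .map(md => CoordinatorEpochAndTxnMetadataSketch(entry.coordinatorEpoch, md))
    }
}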
http://git-wip-us.apache.org/repos/asf/kafka/blob/794e6dbd/core/src/main/scala/kafka/server/DelayedOperation.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/kafka/server/DelayedOperation.scala b/core/src/main/scala/kafka/server/DelayedOperation.scala
index c0efc53..6401600 100644
--- a/core/src/main/scala/kafka/server/DelayedOperation.scala
+++ b/core/src/main/scala/kafka/server/DelayedOperation.scala
@@ -118,7 +118,8 @@ object DelayedOperationPurgatory {
 
   def apply[T <: DelayedOperation](purgatoryName: String,
                                    brokerId: Int = 0,
-                                   purgeInterval: Int = 1000): DelayedOperationPurgatory[T] = {
+                                   purgeInterval: Int = 1000,
+                                   reaperEnabled: Boolean = true): DelayedOperationPurgatory[T] = {
     val timer = new SystemTimer(purgatoryName)
     new DelayedOperationPurgatory[T](purgatoryName, timer, brokerId, purgeInterval)
   }

http://git-wip-us.apache.org/repos/asf/kafka/blob/794e6dbd/core/src/main/scala/kafka/server/KafkaConfig.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/kafka/server/KafkaConfig.scala b/core/src/main/scala/kafka/server/KafkaConfig.scala
index 76f6380..690d167 100755
--- a/core/src/main/scala/kafka/server/KafkaConfig.scala
+++ b/core/src/main/scala/kafka/server/KafkaConfig.scala
@@ -23,7 +23,7 @@ import kafka.api.{ApiVersion, KAFKA_0_10_0_IV1}
 import kafka.cluster.EndPoint
 import kafka.consumer.ConsumerConfig
 import kafka.coordinator.group.OffsetConfig
-import kafka.coordinator.transaction.{TransactionLog, TransactionManager}
+import kafka.coordinator.transaction.{TransactionLog, TransactionStateManager}
 import kafka.message.{BrokerCompressionCodec, CompressionCodec, Message, MessageSet}
 import kafka.utils.CoreUtils
 import org.apache.kafka.clients.CommonClientConfigs
@@ -158,14 +158,14 @@ object Defaults {
   val OffsetCommitRequiredAcks = OffsetConfig.DefaultOffsetCommitRequiredAcks
 
   /** ********* Transaction management configuration ***********/
-  val TransactionalIdExpirationMs = TransactionManager.DefaultTransactionalIdExpirationMs
-  val TransactionsMaxTimeoutMs = TransactionManager.DefaultTransactionsMaxTimeoutMs
+  val TransactionalIdExpirationMs = TransactionStateManager.DefaultTransactionalIdExpirationMs
+  val TransactionsMaxTimeoutMs = TransactionStateManager.DefaultTransactionsMaxTimeoutMs
   val TransactionsTopicMinISR = TransactionLog.DefaultMinInSyncReplicas
   val TransactionsLoadBufferSize = TransactionLog.DefaultLoadBufferSize
   val TransactionsTopicReplicationFactor = TransactionLog.DefaultReplicationFactor
   val TransactionsTopicPartitions = TransactionLog.DefaultNumPartitions
   val TransactionsTopicSegmentBytes = TransactionLog.DefaultSegmentBytes
-  val TransactionsExpiredTransactionCleanupIntervalMS = TransactionManager.DefaultRemoveExpiredTransactionsIntervalMs
+  val TransactionsExpiredTransactionCleanupIntervalMS = TransactionStateManager.DefaultRemoveExpiredTransactionsIntervalMs
 
   /** ********* Quota Configuration ***********/
   val ProducerQuotaBytesPerSecondDefault = ClientQuotaManagerConfig.QuotaBytesPerSecondDefault

http://git-wip-us.apache.org/repos/asf/kafka/blob/794e6dbd/core/src/main/scala/kafka/server/MetadataCache.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/kafka/server/MetadataCache.scala b/core/src/main/scala/kafka/server/MetadataCache.scala
index 1b334ac..2e4c19a 100755
--- a/core/src/main/scala/kafka/server/MetadataCache.scala
+++ b/core/src/main/scala/kafka/server/MetadataCache.scala
@@ -159,6 +159,24 @@ class MetadataCache(brokerId: Int) extends Logging {
     }
   }
 
+  def getPartitionLeaderEndpoint(topic: String, partitionId: Int, listenerName: ListenerName): Option[Node] = {
+    inReadLock(partitionMetadataLock) {
+      cache.get(topic).flatMap(_.get(partitionId)) match {
+        case Some(partitionInfo) =>
+          val leaderId = partitionInfo.leaderIsrAndControllerEpoch.leaderAndIsr.leader
+          try {
+            getAliveEndpoint(leaderId, listenerName)
+          } catch {
+            case e: BrokerEndPointNotAvailableException =>
+              None
+          }
+
+        case None =>
+          None
+      }
+    }
+  }
+
   def getControllerId: Option[Int] = controllerId
 
   // This method returns the deleted TopicPartitions received from UpdateMetadataRequest

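A hedged usage sketch of the new MetadataCache.getPartitionLeaderEndpoint helper (only the method itself comes from the patch; the helper name, the in-scope metadataCache instance and the PLAINTEXT listener are assumptions): a caller such as the transaction marker channel manager can resolve the current leader node for a data partition before queueing a WriteTxnMarkers request, and receives None when the leader or its endpoint on that listener is unavailable.

import kafka.server.MetadataCache
import org.apache.kafka.common.Node
import org.apache.kafka.common.network.ListenerName
import org.apache.kafka.common.protocol.SecurityProtocol

// fragment; assumes a populated MetadataCache instance is passed in
def leaderFor(metadataCache: MetadataCache, topic: String, partitionId: Int): Option[Node] = {
  val listenerName = ListenerName.forSecurityProtocol(SecurityProtocol.PLAINTEXT)
  metadataCache.getPartitionLeaderEndpoint(topic, partitionId, listenerName)
}

// leaderFor(metadataCache, "topic1", 0) match {
//   case Some(node) => ... // enqueue the marker request for this broker
//   case None       => ... // leader unknown or endpoint missing on this listener; retry later
// }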
http://git-wip-us.apache.org/repos/asf/kafka/blob/794e6dbd/core/src/main/scala/kafka/server/ReplicaManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/kafka/server/ReplicaManager.scala b/core/src/main/scala/kafka/server/ReplicaManager.scala
index 99c1b45..9cd92f7 100644
--- a/core/src/main/scala/kafka/server/ReplicaManager.scala
+++ b/core/src/main/scala/kafka/server/ReplicaManager.scala
@@ -149,11 +149,11 @@ class ReplicaManager(val config: KafkaConfig,
   private val lastIsrPropagationMs = new AtomicLong(System.currentTimeMillis())
 
   val delayedProducePurgatory = DelayedOperationPurgatory[DelayedProduce](
-    purgatoryName = "Produce", localBrokerId, config.producerPurgatoryPurgeIntervalRequests)
+    purgatoryName = "Produce", brokerId = localBrokerId, purgeInterval = config.producerPurgatoryPurgeIntervalRequests)
   val delayedFetchPurgatory = DelayedOperationPurgatory[DelayedFetch](
-    purgatoryName = "Fetch", localBrokerId, config.fetchPurgatoryPurgeIntervalRequests)
+    purgatoryName = "Fetch", brokerId = localBrokerId, purgeInterval = config.fetchPurgatoryPurgeIntervalRequests)
   val delayedDeleteRecordsPurgatory = DelayedOperationPurgatory[DelayedDeleteRecords](
-    purgatoryName = "DeleteRecords", localBrokerId, config.deleteRecordsPurgatoryPurgeIntervalRequests)
+    purgatoryName = "DeleteRecords", brokerId = localBrokerId, purgeInterval = config.deleteRecordsPurgatoryPurgeIntervalRequests)
 
   val leaderCount = newGauge(
     "LeaderCount",

http://git-wip-us.apache.org/repos/asf/kafka/blob/794e6dbd/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorIntegrationTest.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorIntegrationTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorIntegrationTest.scala
index 20d1161..df23952 100644
--- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorIntegrationTest.scala
+++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorIntegrationTest.scala
@@ -55,6 +55,12 @@ class TransactionCoordinatorIntegrationTest extends KafkaServerTestHarness {
     val txnId = "txn"
     tc.handleInitPid(txnId, 900000, callback)
 
+    while (initPidResult == null) {
+      Utils.sleep(1)
+    }
+
+    Assert.assertEquals(Errors.NONE, initPidResult.error)
+
     @volatile var addPartitionErrors: Errors = null
     def addPartitionsCallback(errors: Errors): Unit = {
         addPartitionErrors = errors

http://git-wip-us.apache.org/repos/asf/kafka/blob/794e6dbd/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala
index a9f1bca..395bfb9 100644
--- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala
+++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala
@@ -40,6 +40,7 @@ class TransactionCoordinatorTest {
   val capturedTxn: Capture[TransactionMetadata] = EasyMock.newCapture()
   val capturedErrorsCallback: Capture[Errors => Unit] = EasyMock.newCapture()
   val brokerId = 0
+  val coordinatorEpoch = 0
   private val transactionalId = "known"
   private val pid = 10
   private val epoch:Short = 1
@@ -50,11 +51,11 @@ class TransactionCoordinatorTest {
   private val scheduler = new MockScheduler(time)
 
   val coordinator: TransactionCoordinator = new TransactionCoordinator(brokerId,
+    scheduler,
     pidManager,
     transactionManager,
     transactionMarkerChannelManager,
     txnMarkerPurgatory,
-    scheduler,
     time)
 
   var result: InitPidResult = _
@@ -76,7 +77,6 @@ class TransactionCoordinatorTest {
     EasyMock.expect(transactionManager.isCoordinatorFor(EasyMock.eq(transactionalId)))
       .andReturn(true)
       .anyTimes()
-
     EasyMock.expect(transactionManager.isCoordinatorLoadingInProgress(EasyMock.anyString()))
       .andReturn(false)
       .anyTimes()
@@ -85,7 +85,6 @@ class TransactionCoordinatorTest {
       .anyTimes()
   }
 
-
   @Test
   def shouldAcceptInitPidAndReturnNextPidWhenTransactionalIdIsEmpty(): Unit = {
     mockPidManager()
@@ -111,28 +110,30 @@ class TransactionCoordinatorTest {
   @Test
   def shouldInitPidWithEpochZeroForNewTransactionalId(): Unit = {
     initPidGenericMocks(transactionalId)
-    EasyMock.expect(transactionManager.addTransaction(EasyMock.eq(transactionalId), EasyMock.capture(capturedTxn)))
-      .andAnswer(new IAnswer[TransactionMetadata] {
-        override def answer(): TransactionMetadata = {
-          capturedTxn.getValue
+
+    EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId)))
+      .andAnswer(new IAnswer[Option[CoordinatorEpochAndTxnMetadata]] {
+        override def answer(): Option[CoordinatorEpochAndTxnMetadata] = {
+          if (capturedTxn.hasCaptured)
+            Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, capturedTxn.getValue))
+          else
+            None
         }
       })
       .once()
-    EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId)))
-      .andAnswer(new IAnswer[Option[TransactionMetadata]] {
-        override def answer(): Option[TransactionMetadata] = {
-          if (capturedTxn.hasCaptured) {
-            Some(capturedTxn.getValue)
-          } else {
-            None
-          }
+
+    EasyMock.expect(transactionManager.addTransaction(EasyMock.eq(transactionalId), EasyMock.capture(capturedTxn)))
+      .andAnswer(new IAnswer[CoordinatorEpochAndTxnMetadata] {
+        override def answer(): CoordinatorEpochAndTxnMetadata = {
+          CoordinatorEpochAndTxnMetadata(coordinatorEpoch, capturedTxn.getValue)
         }
       })
       .once()
 
     EasyMock.expect(transactionManager.appendTransactionToLog(
       EasyMock.eq(transactionalId),
-      EasyMock.capture(capturedTxn),
+      EasyMock.eq(coordinatorEpoch),
+      EasyMock.anyObject().asInstanceOf[TransactionMetadataTransition],
       EasyMock.capture(capturedErrorsCallback)))
       .andAnswer(new IAnswer[Unit] {
         override def answer(): Unit = {
@@ -143,7 +144,7 @@ class TransactionCoordinatorTest {
     EasyMock.replay(pidManager, transactionManager)
 
     coordinator.handleInitPid(transactionalId, txnTimeoutMs, initPidMockCallback)
-    assertEquals(InitPidResult(0L, 0, Errors.NONE), result)
+    assertEquals(InitPidResult(nextPid - 1, 0, Errors.NONE), result)
   }
 
   @Test
@@ -212,7 +213,7 @@ class TransactionCoordinatorTest {
     EasyMock.expect(transactionManager.isCoordinatorFor(transactionalId))
       .andReturn(true)
     EasyMock.expect(transactionManager.getTransactionState(transactionalId))
-      .andReturn(Some(new TransactionMetadata(0, 0, 0, state, mutable.Set.empty, 0, 0)))
+      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(0, 0, 0, state, mutable.Set.empty, 0, 0))))
 
     EasyMock.replay(transactionManager)
 
@@ -225,7 +226,7 @@ class TransactionCoordinatorTest {
     EasyMock.expect(transactionManager.isCoordinatorFor(transactionalId))
       .andReturn(true)
     EasyMock.expect(transactionManager.getTransactionState(transactionalId))
-      .andReturn(Some(new TransactionMetadata(0, 10, 0, PrepareCommit, mutable.Set.empty, 0, 0)))
+      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(0, 10, 0, PrepareCommit, mutable.Set.empty, 0, 0))))
 
     EasyMock.replay(transactionManager)
 
@@ -254,20 +255,23 @@ class TransactionCoordinatorTest {
   }
 
   def validateSuccessfulAddPartitions(previousState: TransactionState): Unit = {
+    val txnMetadata = new TransactionMetadata(pid, epoch, txnTimeoutMs, previousState, mutable.Set.empty, time.milliseconds(), time.milliseconds())
+
     EasyMock.expect(transactionManager.isCoordinatorFor(transactionalId))
       .andReturn(true)
     EasyMock.expect(transactionManager.getTransactionState(transactionalId))
-      .andReturn(Some(new TransactionMetadata(0, 0, 0, previousState, mutable.Set.empty, 0, 0)))
+      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))
 
     EasyMock.expect(transactionManager.appendTransactionToLog(
       EasyMock.eq(transactionalId),
-      EasyMock.eq(new TransactionMetadata(0, 0, 0, Ongoing, partitions, if (previousState == Ongoing) 0 else time.milliseconds(), time.milliseconds())),
+      EasyMock.eq(coordinatorEpoch),
+      EasyMock.anyObject().asInstanceOf[TransactionMetadataTransition],
       EasyMock.capture(capturedErrorsCallback)
     ))
 
     EasyMock.replay(transactionManager)
 
-    coordinator.handleAddPartitionsToTransaction(transactionalId, 0L, 0, partitions, errorsCallback)
+    coordinator.handleAddPartitionsToTransaction(transactionalId, pid, epoch, partitions, errorsCallback)
 
     EasyMock.verify(transactionManager)
   }
@@ -277,7 +281,7 @@ class TransactionCoordinatorTest {
     EasyMock.expect(transactionManager.isCoordinatorFor(transactionalId))
       .andReturn(true)
     EasyMock.expect(transactionManager.getTransactionState(transactionalId))
-      .andReturn(Some(new TransactionMetadata(0, 0, 0, Empty, partitions, 0, 0)))
+      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(0, 0, 0, Empty, partitions, 0, 0))))
 
     EasyMock.replay(transactionManager)
 
@@ -304,7 +308,7 @@ class TransactionCoordinatorTest {
     EasyMock.expect(transactionManager.isCoordinatorFor(transactionalId))
       .andReturn(true)
     EasyMock.expect(transactionManager.getTransactionState(transactionalId))
-      .andReturn(Some(new TransactionMetadata(10, 0, 0, Ongoing, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds())))
+      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(10, 0, 0, Ongoing, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds()))))
     EasyMock.replay(transactionManager)
 
     coordinator.handleEndTransaction(transactionalId, 0, 0, TransactionResult.COMMIT, errorsCallback)
@@ -317,7 +321,7 @@ class TransactionCoordinatorTest {
     EasyMock.expect(transactionManager.isCoordinatorFor(transactionalId))
       .andReturn(true)
     EasyMock.expect(transactionManager.getTransactionState(transactionalId))
-      .andReturn(Some(new TransactionMetadata(pid, 1, 1, Ongoing, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds())))
+      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(pid, 1, 1, Ongoing, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds()))))
     EasyMock.replay(transactionManager)
 
     coordinator.handleEndTransaction(transactionalId, pid, 0, TransactionResult.COMMIT, errorsCallback)
@@ -330,7 +334,7 @@ class TransactionCoordinatorTest {
     EasyMock.expect(transactionManager.isCoordinatorFor(transactionalId))
       .andReturn(true)
     EasyMock.expect(transactionManager.getTransactionState(transactionalId))
-      .andReturn(Some(new TransactionMetadata(pid, 1, 1, CompleteCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds())))
+      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(pid, 1, 1, CompleteCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds()))))
     EasyMock.replay(transactionManager)
 
     coordinator.handleEndTransaction(transactionalId, pid, 1, TransactionResult.COMMIT, errorsCallback)
@@ -343,7 +347,7 @@ class TransactionCoordinatorTest {
     EasyMock.expect(transactionManager.isCoordinatorFor(transactionalId))
       .andReturn(true)
     EasyMock.expect(transactionManager.getTransactionState(transactionalId))
-      .andReturn(Some(new TransactionMetadata(pid, 1, 1, CompleteAbort, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds())))
+      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(pid, 1, 1, CompleteAbort, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds()))))
     EasyMock.replay(transactionManager)
 
     coordinator.handleEndTransaction(transactionalId, pid, 1, TransactionResult.ABORT, errorsCallback)
@@ -356,7 +360,7 @@ class TransactionCoordinatorTest {
     EasyMock.expect(transactionManager.isCoordinatorFor(transactionalId))
       .andReturn(true)
     EasyMock.expect(transactionManager.getTransactionState(transactionalId))
-      .andReturn(Some(new TransactionMetadata(pid, 1, 1, CompleteAbort, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds())))
+      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(pid, 1, 1, CompleteAbort, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds()))))
     EasyMock.replay(transactionManager)
 
     coordinator.handleEndTransaction(transactionalId, pid, 1, TransactionResult.COMMIT, errorsCallback)
@@ -369,7 +373,7 @@ class TransactionCoordinatorTest {
     EasyMock.expect(transactionManager.isCoordinatorFor(transactionalId))
       .andReturn(true)
     EasyMock.expect(transactionManager.getTransactionState(transactionalId))
-      .andReturn(Some(new TransactionMetadata(pid, 1, 1, CompleteCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds())))
+      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(pid, 1, 1, CompleteCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds()))))
     EasyMock.replay(transactionManager)
 
     coordinator.handleEndTransaction(transactionalId, pid, 1, TransactionResult.ABORT, errorsCallback)
@@ -378,15 +382,15 @@ class TransactionCoordinatorTest {
   }
 
   @Test
-  def shouldReturnInvalidTxnRequestOnEndTxnRequestWhenStatusIsPrepareCommit(): Unit = {
+  def shouldReturnConcurrentTxnRequestOnEndTxnRequestWhenStatusIsPrepareCommit(): Unit = {
     EasyMock.expect(transactionManager.isCoordinatorFor(transactionalId))
       .andReturn(true)
     EasyMock.expect(transactionManager.getTransactionState(transactionalId))
-      .andReturn(Some(new TransactionMetadata(pid, 1, 1, PrepareCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds())))
+      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(pid, 1, 1, PrepareCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds()))))
     EasyMock.replay(transactionManager)
 
     coordinator.handleEndTransaction(transactionalId, pid, 1, TransactionResult.COMMIT, errorsCallback)
-    assertEquals(Errors.INVALID_TXN_STATE, error)
+    assertEquals(Errors.CONCURRENT_TRANSACTIONS, error)
     EasyMock.verify(transactionManager)
   }
 
@@ -395,7 +399,7 @@ class TransactionCoordinatorTest {
     EasyMock.expect(transactionManager.isCoordinatorFor(transactionalId))
       .andReturn(true)
     EasyMock.expect(transactionManager.getTransactionState(transactionalId))
-      .andReturn(Some(new TransactionMetadata(pid, 1, 1, PrepareAbort, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds())))
+      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(pid, 1, 1, PrepareAbort, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds()))))
     EasyMock.replay(transactionManager)
 
     coordinator.handleEndTransaction(transactionalId, pid, 1, TransactionResult.COMMIT, errorsCallback)
@@ -425,29 +429,6 @@ class TransactionCoordinatorTest {
     EasyMock.verify(transactionManager)
   }
 
-
-  @Test
-  def shouldAppendCompleteAbortToLogOnEndTxnWhenStatusIsOngoingAndResultIsAbort(): Unit = {
-    mockComplete(PrepareAbort)
-
-    EasyMock.replay(transactionManager, transactionMarkerChannelManager)
-
-    coordinator.handleEndTransaction(transactionalId, pid, epoch, TransactionResult.ABORT, errorsCallback)
-
-    EasyMock.verify(transactionManager)
-  }
-
-  @Test
-  def shouldAppendCompleteCommitToLogOnEndTxnWhenStatusIsOngoingAndResultIsCommit(): Unit = {
-    mockComplete(PrepareCommit)
-
-    EasyMock.replay(transactionManager, transactionMarkerChannelManager)
-
-    coordinator.handleEndTransaction(transactionalId, pid, epoch, TransactionResult.COMMIT, errorsCallback)
-
-    EasyMock.verify(transactionManager)
-  }
-
   @Test
   def shouldRespondWithInvalidRequestOnEndTxnWhenTransactionalIdIsNull(): Unit = {
     coordinator.handleEndTransaction(null, 0, 0, TransactionResult.COMMIT, errorsCallback)
@@ -506,18 +487,29 @@ class TransactionCoordinatorTest {
 
   @Test
   def shouldAbortTransactionOnHandleInitPidWhenExistingTransactionInOngoingState(): Unit = {
+    val txnMetadata = new TransactionMetadata(pid, epoch, txnTimeoutMs, Ongoing, partitions, 0, 0)
+
     EasyMock.expect(transactionManager.isCoordinatorFor(transactionalId))
       .andReturn(true)
+      .anyTimes()
     EasyMock.expect(transactionManager.validateTransactionTimeoutMs(EasyMock.anyInt()))
       .andReturn(true)
 
-    val metadata = new TransactionMetadata(pid, epoch, txnTimeoutMs, Ongoing, mutable.Set[TopicPartition](new TopicPartition("topic", 1)), 0, 0)
     EasyMock.expect(transactionManager.getTransactionState(transactionalId))
-      .andReturn(Some(metadata))
-      .once()
-
-    mockComplete(PrepareAbort)
+      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))
+      .anyTimes()
 
+    val originalMetadata = new TransactionMetadata(pid, epoch, txnTimeoutMs, Ongoing, partitions, 0, 0)
+    EasyMock.expect(transactionManager.appendTransactionToLog(
+      EasyMock.eq(transactionalId),
+      EasyMock.eq(coordinatorEpoch),
+      EasyMock.eq(originalMetadata.prepareAbortOrCommit(PrepareAbort, time.milliseconds())),
+      EasyMock.capture(capturedErrorsCallback)))
+      .andAnswer(new IAnswer[Unit] {
+        override def answer(): Unit = {
+          capturedErrorsCallback.getValue.apply(Errors.NONE)
+        }
+      })
 
     EasyMock.replay(transactionManager, transactionMarkerChannelManager)
 
@@ -529,8 +521,8 @@ class TransactionCoordinatorTest {
 
   @Test
   def shouldRemoveTransactionsForPartitionOnEmigration(): Unit = {
-    EasyMock.expect(transactionManager.removeTransactionsForPartition(0))
-    EasyMock.expect(transactionMarkerChannelManager.removeStateForPartition(0))
+    EasyMock.expect(transactionManager.removeTransactionsForTxnTopicPartition(0))
+    EasyMock.expect(transactionMarkerChannelManager.removeMarkersForTxnTopicPartition(0))
     EasyMock.replay(transactionManager, transactionMarkerChannelManager)
 
     coordinator.handleTxnEmigration(0)
@@ -539,114 +531,22 @@ class TransactionCoordinatorTest {
   }
 
   @Test
-  def shouldRetryOnCommitWhenTxnMarkerRequestFailsWithErrorOtherThanNotCoordinator(): Unit = {
-    val prepareMetadata = mockPrepare(PrepareCommit, runCallback = true)
-
-    EasyMock.expect(transactionManager.coordinatorEpochFor(transactionalId))
-      .andReturn(Some(0))
-
-    EasyMock.expect(transactionMarkerChannelManager.addTxnMarkerRequest(
-      EasyMock.eq(0),
-      EasyMock.anyObject(),
-      EasyMock.anyInt(),
-      EasyMock.capture(capturedErrorsCallback)
-    )).andAnswer(new IAnswer[Unit] {
-      override def answer(): Unit = {
-        capturedErrorsCallback.getValue.apply(Errors.NETWORK_EXCEPTION)
-      }
-    }).andAnswer(new IAnswer[Unit] {
-      override def answer(): Unit = {
-        capturedErrorsCallback.getValue.apply(Errors.NONE)
-      }
-    })
-
-    EasyMock.expect(transactionManager.getTransactionState(transactionalId))
-      .andReturn(Some(prepareMetadata))
-      .once()
-
-    EasyMock.replay(transactionManager, transactionMarkerChannelManager)
-
-    coordinator.handleEndTransaction(transactionalId, pid, epoch, TransactionResult.COMMIT, errorsCallback)
-
-    EasyMock.verify(transactionMarkerChannelManager)
-  }
-
-  @Test
-  def shouldNotRetryOnCommitWhenTxnMarkerRequestFailsWithNotCoordinator(): Unit = {
-    val prepareMetadata = mockPrepare(PrepareCommit, runCallback = true)
-
-    EasyMock.expect(transactionManager.coordinatorEpochFor(transactionalId))
-      .andReturn(Some(0))
-
-    EasyMock.expect(transactionMarkerChannelManager.addTxnMarkerRequest(
-      EasyMock.eq(0),
-      EasyMock.anyObject(),
-      EasyMock.anyInt(),
-      EasyMock.capture(capturedErrorsCallback)
-    )).andAnswer(new IAnswer[Unit] {
-      override def answer(): Unit = {
-        capturedErrorsCallback.getValue.apply(Errors.NOT_COORDINATOR)
-      }
-    })
-
-    EasyMock.replay(transactionManager, transactionMarkerChannelManager)
-
-    coordinator.handleEndTransaction(transactionalId, pid, epoch, TransactionResult.COMMIT, errorsCallback)
-
-    EasyMock.verify(transactionMarkerChannelManager)
-  }
-
-  @Test
-  def shouldNotRetryOnCommitWhenAppendToLogFailsWithNotCoordinator(): Unit = {
-    mockComplete(PrepareCommit, Errors.NOT_COORDINATOR)
-    EasyMock.replay(transactionManager, transactionMarkerChannelManager)
-
-    coordinator.handleEndTransaction(transactionalId, pid, epoch, TransactionResult.COMMIT, errorsCallback)
-
-    EasyMock.verify(transactionManager)
-  }
-
-  @Test
-  def shouldRetryOnCommitWhenAppendToLogFailsErrorsOtherThanNotCoordinator(): Unit = {
-    mockComplete(PrepareCommit, Errors.ILLEGAL_GENERATION)
-    EasyMock.replay(transactionManager, transactionMarkerChannelManager)
-
-    coordinator.handleEndTransaction(transactionalId, pid, epoch, TransactionResult.COMMIT, errorsCallback)
-
-    EasyMock.verify(transactionManager)
-  }
-
-  @Test
   def shouldAbortExpiredTransactionsInOngoingState(): Unit = {
-    EasyMock.expect(transactionManager.transactionsToExpire())
-    .andReturn(List(TransactionalIdAndMetadata(transactionalId,
-      new TransactionMetadata(pid, epoch, 0, Ongoing, partitions, time.milliseconds(), time.milliseconds()))))
-
-    // should bump the epoch and append to the log
-    val metadata = new TransactionMetadata(pid, (epoch + 1).toShort, 0, Ongoing, partitions, time.milliseconds(), time.milliseconds())
-    EasyMock.expect(transactionManager.appendTransactionToLog(EasyMock.eq(transactionalId),
-      EasyMock.eq(metadata),
-      EasyMock.capture(capturedErrorsCallback)))
-    .andAnswer(new IAnswer[Unit] {
-      override def answer(): Unit = {
-        capturedErrorsCallback.getValue.apply(Errors.NONE)
-      }
-    }).once()
+    val txnMetadata = new TransactionMetadata(pid, epoch, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds())
 
+    EasyMock.expect(transactionManager.transactionsToExpire())
+      .andReturn(List(TransactionalIdAndProducerIdEpoch(transactionalId, pid, epoch)))
     EasyMock.expect(transactionManager.isCoordinatorFor(transactionalId))
       .andReturn(true)
     EasyMock.expect(transactionManager.getTransactionState(transactionalId))
-      .andReturn(Some(metadata))
+      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))
       .once()
 
-    // now should perform the rollback and append the state as PrepareAbort
-    val abortMetadata = metadata.copy()
-    abortMetadata.state = PrepareAbort
-    // need to allow for the time.sleep below
-    abortMetadata.lastUpdateTimestamp = time.milliseconds() + TransactionManager.DefaultRemoveExpiredTransactionsIntervalMs
+    val newMetadata = txnMetadata.copy().prepareAbortOrCommit(PrepareAbort, time.milliseconds() + TransactionStateManager.DefaultRemoveExpiredTransactionsIntervalMs)
 
     EasyMock.expect(transactionManager.appendTransactionToLog(EasyMock.eq(transactionalId),
-      EasyMock.eq(abortMetadata),
+      EasyMock.eq(coordinatorEpoch),
+      EasyMock.eq(newMetadata),
       EasyMock.capture(capturedErrorsCallback)))
       .andAnswer(new IAnswer[Unit] {
         override def answer(): Unit = {}
@@ -656,147 +556,135 @@ class TransactionCoordinatorTest {
     EasyMock.replay(transactionManager, transactionMarkerChannelManager)
 
     coordinator.startup(false)
-    time.sleep(TransactionManager.DefaultRemoveExpiredTransactionsIntervalMs)
+    time.sleep(TransactionStateManager.DefaultRemoveExpiredTransactionsIntervalMs)
     scheduler.tick()
     EasyMock.verify(transactionManager)
   }
 
   @Test
   def shouldNotAbortExpiredTransactionsThatHaveAPendingStateTransition(): Unit = {
-    val metadata = new TransactionMetadata(pid, epoch, 0, Ongoing, partitions, time.milliseconds(), time.milliseconds())
-    metadata.prepareTransitionTo(PrepareCommit)
+    val metadata = new TransactionMetadata(pid, epoch, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds())
+    metadata.prepareAbortOrCommit(PrepareCommit, time.milliseconds())
 
     EasyMock.expect(transactionManager.transactionsToExpire())
-      .andReturn(List(TransactionalIdAndMetadata(transactionalId,
-        metadata)))
+      .andReturn(List(TransactionalIdAndProducerIdEpoch(transactionalId, pid, epoch)))
     
     EasyMock.replay(transactionManager, transactionMarkerChannelManager)
-    coordinator.startup(false)
 
-    time.sleep(TransactionManager.DefaultRemoveExpiredTransactionsIntervalMs)
+    coordinator.startup(false)
+    time.sleep(TransactionStateManager.DefaultRemoveExpiredTransactionsIntervalMs)
     scheduler.tick()
     EasyMock.verify(transactionManager)
-
   }
 
   private def validateRespondsWithConcurrentTransactionsOnInitPidWhenInPrepareState(state: TransactionState) = {
-    val transactionId = "tid"
-    EasyMock.expect(transactionManager.isCoordinatorFor(transactionId))
+    EasyMock.expect(transactionManager.isCoordinatorFor(transactionalId))
       .andReturn(true).anyTimes()
     EasyMock.expect(transactionManager.validateTransactionTimeoutMs(EasyMock.anyInt()))
       .andReturn(true).anyTimes()
 
     val metadata = new TransactionMetadata(0, 0, 0, state, mutable.Set[TopicPartition](new TopicPartition("topic", 1)), 0, 0)
-    EasyMock.expect(transactionManager.getTransactionState(transactionId))
-      .andReturn(Some(metadata)).anyTimes()
+    EasyMock.expect(transactionManager.getTransactionState(transactionalId))
+      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, metadata))).anyTimes()
 
     EasyMock.replay(transactionManager)
 
-    coordinator.handleInitPid(transactionId, 10, initPidMockCallback)
+    coordinator.handleInitPid(transactionalId, 10, initPidMockCallback)
 
     assertEquals(InitPidResult(-1, -1, Errors.CONCURRENT_TRANSACTIONS), result)
   }
 
   private def validateIncrementEpochAndUpdateMetadata(state: TransactionState) = {
-    val transactionId = "tid"
-    EasyMock.expect(transactionManager.isCoordinatorFor(transactionId))
+    EasyMock.expect(transactionManager.isCoordinatorFor(transactionalId))
       .andReturn(true)
     EasyMock.expect(transactionManager.validateTransactionTimeoutMs(EasyMock.anyInt()))
       .andReturn(true)
 
-    val metadata = new TransactionMetadata(0, 0, 0, state, mutable.Set.empty[TopicPartition], 0, 0)
-    EasyMock.expect(transactionManager.getTransactionState(transactionId))
-      .andReturn(Some(metadata))
+    val metadata = new TransactionMetadata(pid, epoch, txnTimeoutMs, state, mutable.Set.empty[TopicPartition], time.milliseconds(), time.milliseconds())
+    EasyMock.expect(transactionManager.getTransactionState(transactionalId))
+      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, metadata)))
 
+    val capturedNewMetadata: Capture[TransactionMetadataTransition] = EasyMock.newCapture()
     EasyMock.expect(transactionManager.appendTransactionToLog(
-      EasyMock.eq(transactionId),
-      EasyMock.anyObject(classOf[TransactionMetadata]),
+      EasyMock.eq(transactionalId),
+      EasyMock.eq(coordinatorEpoch),
+      EasyMock.capture(capturedNewMetadata),
       EasyMock.capture(capturedErrorsCallback)
     )).andAnswer(new IAnswer[Unit] {
       override def answer(): Unit = {
+        metadata.completeTransitionTo(capturedNewMetadata.getValue)
         capturedErrorsCallback.getValue.apply(Errors.NONE)
       }
     })
 
     EasyMock.replay(transactionManager)
 
-    coordinator.handleInitPid(transactionId, 10, initPidMockCallback)
+    val newTxnTimeoutMs = 10
+    coordinator.handleInitPid(transactionalId, newTxnTimeoutMs, initPidMockCallback)
 
-    assertEquals(InitPidResult(0, 1, Errors.NONE), result)
-    assertEquals(10, metadata.txnTimeoutMs)
-    assertEquals(time.milliseconds(), metadata.lastUpdateTimestamp)
-    assertEquals(1, metadata.producerEpoch)
-    assertEquals(0, metadata.pid)
+    assertEquals(InitPidResult(pid, (epoch + 1).toShort, Errors.NONE), result)
+    assertEquals(newTxnTimeoutMs, metadata.txnTimeoutMs)
+    assertEquals(time.milliseconds(), metadata.txnLastUpdateTimestamp)
+    assertEquals((epoch + 1).toShort, metadata.producerEpoch)
+    assertEquals(pid, metadata.producerId)
   }
 
-  private def mockPrepare(transactionState: TransactionState, runCallback: Boolean = false) = {
-    val originalMetadata = new TransactionMetadata(pid,
-      epoch,
-      txnTimeoutMs,
-      Ongoing,
-      collection.mutable.Set.empty[TopicPartition],
-      0,
-      time.milliseconds())
-
-    val prepareCommitMetadata = new TransactionMetadata(pid,
-      epoch,
-      txnTimeoutMs,
-      transactionState,
-      collection.mutable.Set.empty[TopicPartition],
-      0,
-      time.milliseconds())
+  private def mockPrepare(transactionState: TransactionState, runCallback: Boolean = false): TransactionMetadata = {
+    val now = time.milliseconds()
+    val originalMetadata = new TransactionMetadata(pid, epoch, txnTimeoutMs, Ongoing, partitions, now, now)
 
     EasyMock.expect(transactionManager.isCoordinatorFor(transactionalId))
       .andReturn(true)
+      .anyTimes()
     EasyMock.expect(transactionManager.getTransactionState(transactionalId))
-      .andReturn(Some(originalMetadata))
+      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, originalMetadata)))
       .once()
-
-    EasyMock.expect(transactionManager.appendTransactionToLog(EasyMock.eq(transactionalId),
-      EasyMock.eq(prepareCommitMetadata),
+    EasyMock.expect(transactionManager.appendTransactionToLog(
+      EasyMock.eq(transactionalId),
+      EasyMock.eq(coordinatorEpoch),
+      EasyMock.eq(originalMetadata.copy().prepareAbortOrCommit(transactionState, now)),
       EasyMock.capture(capturedErrorsCallback)))
       .andAnswer(new IAnswer[Unit] {
         override def answer(): Unit = {
-          if (runCallback) capturedErrorsCallback.getValue.apply(Errors.NONE)
+          if (runCallback)
+            capturedErrorsCallback.getValue.apply(Errors.NONE)
         }
       }).once()
-    prepareCommitMetadata
-  }
-
-  private def mockComplete(transactionState: TransactionState, appendError: Errors = Errors.NONE) = {
 
+    new TransactionMetadata(pid, epoch, txnTimeoutMs, transactionState, partitions, time.milliseconds(), time.milliseconds())
+  }
 
-    val prepareMetadata: TransactionMetadata = mockPrepare(transactionState, true)
-    val finalState = if (transactionState == PrepareAbort) CompleteAbort else CompleteCommit
+  private def mockComplete(transactionState: TransactionState, appendError: Errors = Errors.NONE): TransactionMetadata = {
+    val now = time.milliseconds()
+    val prepareMetadata = mockPrepare(transactionState, true)
 
-    EasyMock.expect(transactionManager.coordinatorEpochFor(transactionalId))
-      .andReturn(Some(0))
+    val (finalState, txnResult) = if (transactionState == PrepareAbort)
+      (CompleteAbort, TransactionResult.ABORT)
+    else
+      (CompleteCommit, TransactionResult.COMMIT)
 
-    EasyMock.expect(transactionMarkerChannelManager.addTxnMarkerRequest(
-      EasyMock.eq(0),
-      EasyMock.anyObject(),
-      EasyMock.anyInt(),
-      EasyMock.capture(capturedErrorsCallback)
-    )).andAnswer(new IAnswer[Unit] {
-      override def answer(): Unit = {
-        capturedErrorsCallback.getValue.apply(Errors.NONE)
-      }
-    })
+    val completedMetadata = new TransactionMetadata(pid, epoch, txnTimeoutMs, finalState,
+      collection.mutable.Set.empty[TopicPartition],
+      prepareMetadata.txnStartTimestamp,
+      prepareMetadata.txnLastUpdateTimestamp)
 
     EasyMock.expect(transactionManager.getTransactionState(transactionalId))
-      .andReturn(Some(prepareMetadata))
+      .andReturn(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, prepareMetadata)))
       .once()
 
-    val completedMetadata = new TransactionMetadata(pid,
-      epoch,
-      txnTimeoutMs,
-      finalState,
-      prepareMetadata.topicPartitions,
-      prepareMetadata.transactionStartTime,
-      prepareMetadata.lastUpdateTimestamp)
+    val newMetadata = prepareMetadata.copy().prepareComplete(now)
+    EasyMock.expect(transactionMarkerChannelManager.addTxnMarkersToSend(
+      EasyMock.eq(transactionalId),
+      EasyMock.eq(coordinatorEpoch),
+      EasyMock.eq(txnResult),
+      EasyMock.eq(prepareMetadata),
+      EasyMock.eq(newMetadata))
+    ).once()
 
-    val firstAnswer = EasyMock.expect(transactionManager.appendTransactionToLog(EasyMock.eq(transactionalId),
-      EasyMock.eq(completedMetadata),
+    val firstAnswer = EasyMock.expect(transactionManager.appendTransactionToLog(
+      EasyMock.eq(transactionalId),
+      EasyMock.eq(coordinatorEpoch),
+      EasyMock.eq(newMetadata),
       EasyMock.capture(capturedErrorsCallback)))
       .andAnswer(new IAnswer[Unit] {
         override def answer(): Unit = {
@@ -804,18 +692,18 @@ class TransactionCoordinatorTest {
         }
       })
 
-     if(appendError != Errors.NONE && appendError != Errors.NOT_COORDINATOR) {
-        firstAnswer.andAnswer(new IAnswer[Unit] {
-          override def answer(): Unit = {
-            capturedErrorsCallback.getValue.apply(Errors.NONE)
-          }
-        })
-     }
-
+    // let it succeed next time
+    if (appendError != Errors.NONE && appendError != Errors.NOT_COORDINATOR) {
+      firstAnswer.andAnswer(new IAnswer[Unit] {
+        override def answer(): Unit = {
+          capturedErrorsCallback.getValue.apply(Errors.NONE)
+        }
+      })
+    }
 
+    completedMetadata
   }
 
-
   def initPidMockCallback(ret: InitPidResult): Unit = {
     result = ret
   }

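As the diff above suggests, these tests exercise a two-step transition protocol: first build a TransactionMetadataTransition (via prepareAbortOrCommit, prepareComplete or prepareNoTransit), append it to the transaction log, and only then apply it in place with completeTransitionTo. The following is a minimal, self-contained sketch of that protocol; the class and method names are simplified stand-ins, not the real API.

sealed trait TxnStateSketch
case object OngoingSketch extends TxnStateSketch
case object PrepareCommitSketch extends TxnStateSketch

// stand-in for TransactionMetadataTransition: an immutable description of the intended change
case class TransitionSketch(newState: TxnStateSketch, updateTimeMs: Long)

class TxnMetadataSketch(var state: TxnStateSketch, var lastUpdateMs: Long) {
  // step 1: describe the transition without mutating the cached metadata
  def prepareTransition(target: TxnStateSketch, nowMs: Long): TransitionSketch =
    TransitionSketch(target, nowMs)

  // step 2: after the transition record has been appended to the transaction log,
  // apply it in place so cache readers only ever observe logged state
  def completeTransitionTo(t: TransitionSketch): Unit = {
    state = t.newState
    lastUpdateMs = t.updateTimeMs
  }
}

object TransitionProtocolSketch extends App {
  val metadata = new TxnMetadataSketch(OngoingSketch, 0L)
  val pending = metadata.prepareTransition(PrepareCommitSketch, System.currentTimeMillis())
  // ... appending the pending transition to the transaction log would happen here ...
  metadata.completeTransitionTo(pending)
  assert(metadata.state == PrepareCommitSketch)
}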
http://git-wip-us.apache.org/repos/asf/kafka/blob/794e6dbd/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionLogTest.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionLogTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionLogTest.scala
index cfb4a99..fe750b8 100644
--- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionLogTest.scala
+++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionLogTest.scala
@@ -43,7 +43,7 @@ class TransactionLogTest extends JUnitSuite {
     txnMetadata.addPartitions(topicPartitions)
 
     intercept[IllegalStateException] {
-      TransactionLog.valueToBytes(txnMetadata)
+      TransactionLog.valueToBytes(txnMetadata.prepareNoTransit())
     }
   }
 
@@ -71,7 +71,7 @@ class TransactionLogTest extends JUnitSuite {
         txnMetadata.addPartitions(topicPartitions)
 
       val keyBytes = TransactionLog.keyToBytes(transactionalId)
-      val valueBytes = TransactionLog.valueToBytes(txnMetadata)
+      val valueBytes = TransactionLog.valueToBytes(txnMetadata.prepareNoTransit())
 
       new SimpleRecord(keyBytes, valueBytes)
     }.toSeq
@@ -87,10 +87,10 @@ class TransactionLogTest extends JUnitSuite {
           val transactionalId = pidKey.transactionalId
           val txnMetadata = TransactionLog.readMessageValue(record.value())
 
-          assertEquals(pidMappings(transactionalId), txnMetadata.pid)
+          assertEquals(pidMappings(transactionalId), txnMetadata.producerId)
           assertEquals(epoch, txnMetadata.producerEpoch)
           assertEquals(transactionTimeoutMs, txnMetadata.txnTimeoutMs)
-          assertEquals(transactionStates(txnMetadata.pid), txnMetadata.state)
+          assertEquals(transactionStates(txnMetadata.producerId), txnMetadata.state)
 
           if (txnMetadata.state.equals(Empty))
             assertEquals(Set.empty[TopicPartition], txnMetadata.topicPartitions)