Posted to jira@kafka.apache.org by "ASF GitHub Bot (JIRA)" <ji...@apache.org> on 2018/01/09 00:16:17 UTC

[jira] [Commented] (KAFKA-6096) Add concurrent tests to exercise all paths in group/transaction managers

    [ https://issues.apache.org/jira/browse/KAFKA-6096?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=16317394#comment-16317394 ] 

ASF GitHub Bot commented on KAFKA-6096:
---------------------------------------

hachikuji closed pull request #4122: KAFKA-6096: Add multi-threaded tests for group coordinator, txn manager
URL: https://github.com/apache/kafka/pull/4122
 
 
   

This is a PR merged from a forked repository. Because GitHub hides the
original diff once a pull request from a fork is merged, the diff is
reproduced below for the sake of provenance:

diff --git a/core/src/test/scala/unit/kafka/coordinator/AbstractCoordinatorConcurrencyTest.scala b/core/src/test/scala/unit/kafka/coordinator/AbstractCoordinatorConcurrencyTest.scala
new file mode 100644
index 00000000000..0ecc3f538b1
--- /dev/null
+++ b/core/src/test/scala/unit/kafka/coordinator/AbstractCoordinatorConcurrencyTest.scala
@@ -0,0 +1,226 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package kafka.coordinator
+
+import java.util.{ Collections, Random }
+import java.util.concurrent.{ ConcurrentHashMap, Executors }
+import java.util.concurrent.atomic.AtomicInteger
+import java.util.concurrent.locks.Lock
+
+import kafka.coordinator.AbstractCoordinatorConcurrencyTest._
+import kafka.log.Log
+import kafka.server._
+import kafka.utils._
+import kafka.utils.timer.MockTimer
+import kafka.zk.KafkaZkClient
+import org.apache.kafka.common.TopicPartition
+import org.apache.kafka.common.protocol.Errors
+import org.apache.kafka.common.record.{ MemoryRecords, RecordBatch, RecordsProcessingStats }
+import org.apache.kafka.common.requests.ProduceResponse.PartitionResponse
+import org.easymock.EasyMock
+import org.junit.{ After, Before }
+
+import scala.collection._
+import scala.collection.JavaConverters._
+
+abstract class AbstractCoordinatorConcurrencyTest[M <: CoordinatorMember] {
+
+  val nThreads = 5
+
+  val time = new MockTime
+  val timer = new MockTimer
+  val executor = Executors.newFixedThreadPool(nThreads)
+  val scheduler = new MockScheduler(time)
+  var replicaManager: TestReplicaManager = _
+  var zkClient: KafkaZkClient = _
+  val serverProps = TestUtils.createBrokerConfig(nodeId = 0, zkConnect = "")
+  val random = new Random
+
+  @Before
+  def setUp() {
+
+    replicaManager = EasyMock.partialMockBuilder(classOf[TestReplicaManager]).createMock()
+    replicaManager.createDelayedProducePurgatory(timer)
+
+    zkClient = EasyMock.createNiceMock(classOf[KafkaZkClient])
+  }
+
+  @After
+  def tearDown() {
+    EasyMock.reset(replicaManager)
+    if (executor != null)
+      executor.shutdownNow()
+  }
+
+  /**
+    * Verify that concurrent operations run in the normal sequence produce the expected results.
+    */
+  def verifyConcurrentOperations(createMembers: String => Set[M], operations: Seq[Operation]) {
+    OrderedOperationSequence(createMembers("verifyConcurrentOperations"), operations).run()
+  }
+
+  /**
+    * Verify that arbitrary operations run in some random sequence don't leave the coordinator
+    * in a bad state. Operations in the normal sequence should continue to work as expected.
+    */
+  def verifyConcurrentRandomSequences(createMembers: String => Set[M], operations: Seq[Operation]) {
+    EasyMock.reset(replicaManager)
+    for (i <- 0 to 10) {
+      // Run some random operations
+      RandomOperationSequence(createMembers(s"random$i"), operations).run()
+
+      // Check that proper sequences still work correctly
+      OrderedOperationSequence(createMembers(s"ordered$i"), operations).run()
+    }
+  }
+
+  def verifyConcurrentActions(actions: Set[Action]) {
+    val futures = actions.map(executor.submit)
+    futures.map(_.get)
+    enableCompletion()
+    actions.foreach(_.await())
+  }
+
+  def enableCompletion(): Unit = {
+    replicaManager.tryCompleteDelayedRequests()
+    scheduler.tick()
+  }
+
+  abstract class OperationSequence(members: Set[M], operations: Seq[Operation]) {
+    def actionSequence: Seq[Set[Action]]
+    def run(): Unit = {
+      actionSequence.foreach(verifyConcurrentActions)
+    }
+  }
+
+  case class OrderedOperationSequence(members: Set[M], operations: Seq[Operation])
+    extends OperationSequence(members, operations) {
+    override def actionSequence: Seq[Set[Action]] = {
+      operations.map { op =>
+        members.map(op.actionWithVerify)
+      }
+    }
+  }
+
+  case class RandomOperationSequence(members: Set[M], operations: Seq[Operation])
+    extends OperationSequence(members, operations) {
+    val opCount = operations.length
+    def actionSequence: Seq[Set[Action]] = {
+      (0 to opCount).map { _ =>
+        members.map { member =>
+          val op = operations(random.nextInt(opCount))
+          op.actionNoVerify(member) // Don't wait or verify since these operations may block
+        }
+      }
+    }
+  }
+
+  abstract class Operation {
+    def run(member: M): Unit
+    def awaitAndVerify(member: M): Unit
+    def actionWithVerify(member: M): Action = {
+      new Action() {
+        def run(): Unit = Operation.this.run(member)
+        def await(): Unit = awaitAndVerify(member)
+      }
+    }
+    def actionNoVerify(member: M): Action = {
+      new Action() {
+        def run(): Unit = Operation.this.run(member)
+        def await(): Unit = timer.advanceClock(100) // Don't wait since operation may block
+      }
+    }
+  }
+}
+
+object AbstractCoordinatorConcurrencyTest {
+
+  trait Action extends Runnable {
+    def await(): Unit
+  }
+
+  trait CoordinatorMember {
+  }
+
+  class TestReplicaManager extends ReplicaManager(
+    null, null, null, null, null, null, null, null, null, null, null, null, null, null, None) {
+
+    var producePurgatory: DelayedOperationPurgatory[DelayedProduce] = _
+    var watchKeys: mutable.Set[TopicPartitionOperationKey] = _
+    def createDelayedProducePurgatory(timer: MockTimer): Unit = {
+      producePurgatory = new DelayedOperationPurgatory[DelayedProduce]("Produce", timer, 1, reaperEnabled = false)
+      watchKeys = Collections.newSetFromMap(new ConcurrentHashMap[TopicPartitionOperationKey, java.lang.Boolean]()).asScala
+    }
+    def tryCompleteDelayedRequests(): Unit = {
+      watchKeys.map(producePurgatory.checkAndComplete)
+    }
+
+    override def appendRecords(timeout: Long,
+                               requiredAcks: Short,
+                               internalTopicsAllowed: Boolean,
+                               isFromClient: Boolean,
+                               entriesPerPartition: Map[TopicPartition, MemoryRecords],
+                               responseCallback: Map[TopicPartition, PartitionResponse] => Unit,
+                               delayedProduceLock: Option[Lock] = None,
+                               processingStatsCallback: Map[TopicPartition, RecordsProcessingStats] => Unit = _ => ()) {
+
+      if (entriesPerPartition.isEmpty)
+        return
+      val produceMetadata = ProduceMetadata(1, entriesPerPartition.map {
+        case (tp, _) =>
+          (tp, ProducePartitionStatus(0L, new PartitionResponse(Errors.NONE, 0L, RecordBatch.NO_TIMESTAMP, 0L)))
+      })
+      val delayedProduce = new DelayedProduce(5, produceMetadata, this, responseCallback, delayedProduceLock) {
+        // Complete produce requests after a few attempts to trigger delayed produce from different threads
+        val completeAttempts = new AtomicInteger
+        override def tryComplete(): Boolean = {
+          if (completeAttempts.incrementAndGet() >= 3)
+            forceComplete()
+          else
+            false
+        }
+        override def onComplete() {
+          responseCallback(entriesPerPartition.map {
+            case (tp, _) =>
+              (tp, new PartitionResponse(Errors.NONE, 0L, RecordBatch.NO_TIMESTAMP, 0L))
+          })
+        }
+      }
+      val producerRequestKeys = entriesPerPartition.keys.map(new TopicPartitionOperationKey(_)).toSeq
+      watchKeys ++= producerRequestKeys
+      producePurgatory.tryCompleteElseWatch(delayedProduce, producerRequestKeys)
+      tryCompleteDelayedRequests()
+    }
+    override def getMagic(topicPartition: TopicPartition): Option[Byte] = {
+      Some(RecordBatch.MAGIC_VALUE_V2)
+    }
+    @volatile var logs: mutable.Map[TopicPartition, (Log, Long)] = _
+    def getOrCreateLogs(): mutable.Map[TopicPartition, (Log, Long)] = {
+      if (logs == null)
+        logs = mutable.Map[TopicPartition, (Log, Long)]()
+      logs
+    }
+    def updateLog(topicPartition: TopicPartition, log: Log, endOffset: Long): Unit = {
+      getOrCreateLogs().put(topicPartition, (log, endOffset))
+    }
+    override def getLog(topicPartition: TopicPartition): Option[Log] =
+      getOrCreateLogs().get(topicPartition).map(l => l._1)
+    override def getLogEndOffset(topicPartition: TopicPartition): Option[Long] =
+      getOrCreateLogs().get(topicPartition).map(l => l._2)
+  }
+}
diff --git a/core/src/test/scala/unit/kafka/coordinator/group/GroupCoordinatorConcurrencyTest.scala b/core/src/test/scala/unit/kafka/coordinator/group/GroupCoordinatorConcurrencyTest.scala
new file mode 100644
index 00000000000..44e13560b00
--- /dev/null
+++ b/core/src/test/scala/unit/kafka/coordinator/group/GroupCoordinatorConcurrencyTest.scala
@@ -0,0 +1,310 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package kafka.coordinator.group
+
+import java.util.concurrent.{ ConcurrentHashMap, TimeUnit }
+
+import kafka.common.OffsetAndMetadata
+import kafka.coordinator.AbstractCoordinatorConcurrencyTest
+import kafka.coordinator.AbstractCoordinatorConcurrencyTest._
+import kafka.coordinator.group.GroupCoordinatorConcurrencyTest._
+import kafka.server.{ DelayedOperationPurgatory, KafkaConfig }
+import org.apache.kafka.common.TopicPartition
+import org.apache.kafka.common.internals.Topic
+import org.apache.kafka.common.protocol.Errors
+import org.apache.kafka.common.requests.{ JoinGroupRequest, TransactionResult }
+import org.easymock.EasyMock
+import org.junit.Assert._
+import org.junit.{ After, Before, Test }
+
+import scala.collection._
+import scala.concurrent.duration.Duration
+import scala.concurrent.{ Await, Future, Promise, TimeoutException }
+
+class GroupCoordinatorConcurrencyTest extends AbstractCoordinatorConcurrencyTest[GroupMember] {
+
+  private val protocolType = "consumer"
+  private val metadata = Array[Byte]()
+  private val protocols = List(("range", metadata))
+
+  private val nGroups = nThreads * 10
+  private val nMembersPerGroup = nThreads * 5
+  private val numPartitions = 2
+
+  private val allOperations = Seq(
+      new JoinGroupOperation,
+      new SyncGroupOperation,
+      new CommitOffsetsOperation,
+      new HeartbeatOperation,
+      new LeaveGroupOperation
+    )
+  private val allOperationsWithTxn = Seq(
+    new JoinGroupOperation,
+    new SyncGroupOperation,
+    new CommitTxnOffsetsOperation,
+    new CompleteTxnOperation,
+    new HeartbeatOperation,
+    new LeaveGroupOperation
+  )
+
+  var groupCoordinator: GroupCoordinator = _
+
+  @Before
+  override def setUp() {
+    super.setUp()
+
+    EasyMock.expect(zkClient.getTopicPartitionCount(Topic.GROUP_METADATA_TOPIC_NAME))
+      .andReturn(Some(numPartitions))
+      .anyTimes()
+    EasyMock.replay(zkClient)
+
+    serverProps.setProperty(KafkaConfig.GroupMinSessionTimeoutMsProp, ConsumerMinSessionTimeout.toString)
+    serverProps.setProperty(KafkaConfig.GroupMaxSessionTimeoutMsProp, ConsumerMaxSessionTimeout.toString)
+    serverProps.setProperty(KafkaConfig.GroupInitialRebalanceDelayMsProp, GroupInitialRebalanceDelay.toString)
+
+    val config = KafkaConfig.fromProps(serverProps)
+
+    val heartbeatPurgatory = new DelayedOperationPurgatory[DelayedHeartbeat]("Heartbeat", timer, config.brokerId, reaperEnabled = false)
+    val joinPurgatory = new DelayedOperationPurgatory[DelayedJoin]("Rebalance", timer, config.brokerId, reaperEnabled = false)
+
+    groupCoordinator = GroupCoordinator(config, zkClient, replicaManager, heartbeatPurgatory, joinPurgatory, timer.time)
+    groupCoordinator.startup(false)
+  }
+
+  @After
+  override def tearDown() {
+    try {
+      if (groupCoordinator != null)
+        groupCoordinator.shutdown()
+    } finally {
+      super.tearDown()
+    }
+  }
+
+  def createGroupMembers(groupPrefix: String): Set[GroupMember] = {
+    (0 until nGroups).flatMap { i =>
+      new Group(s"$groupPrefix$i", nMembersPerGroup, groupCoordinator, replicaManager).members
+    }.toSet
+  }
+
+  @Test
+  def testConcurrentGoodPathSequence() {
+    verifyConcurrentOperations(createGroupMembers, allOperations)
+  }
+
+  @Test
+  def testConcurrentTxnGoodPathSequence() {
+    verifyConcurrentOperations(createGroupMembers, allOperationsWithTxn)
+  }
+
+  @Test
+  def testConcurrentRandomSequence() {
+    verifyConcurrentRandomSequences(createGroupMembers, allOperationsWithTxn)
+  }
+
+
+  abstract class GroupOperation[R, C] extends Operation {
+    val responseFutures = new ConcurrentHashMap[GroupMember, Future[R]]()
+
+    def setUpCallback(member: GroupMember): C = {
+      val responsePromise = Promise[R]
+      val responseFuture = responsePromise.future
+      responseFutures.put(member, responseFuture)
+      responseCallback(responsePromise)
+    }
+    def responseCallback(responsePromise: Promise[R]): C
+
+    override def run(member: GroupMember): Unit = {
+      val responseCallback = setUpCallback(member)
+      runWithCallback(member, responseCallback)
+    }
+
+    def runWithCallback(member: GroupMember, responseCallback: C): Unit
+
+    def await(member: GroupMember, timeoutMs: Long): R = {
+      var retries = (timeoutMs + 10) / 10
+      val responseFuture = responseFutures.get(member)
+      while (retries > 0) {
+        timer.advanceClock(10)
+        try {
+          return Await.result(responseFuture, Duration(10, TimeUnit.MILLISECONDS))
+        } catch {
+          case _: TimeoutException =>
+        }
+        retries -= 1
+      }
+      throw new TimeoutException(s"Operation did not complete within $timeoutMs millis")
+    }
+  }
+
+
+  class JoinGroupOperation extends GroupOperation[JoinGroupResult, JoinGroupCallback] {
+    override def responseCallback(responsePromise: Promise[JoinGroupResult]): JoinGroupCallback = {
+      val callback: JoinGroupCallback = responsePromise.success(_)
+      callback
+    }
+    override def runWithCallback(member: GroupMember, responseCallback: JoinGroupCallback): Unit = {
+      groupCoordinator.handleJoinGroup(member.groupId, member.memberId, "clientId", "clientHost",
+       DefaultRebalanceTimeout, DefaultSessionTimeout,
+       protocolType, protocols, responseCallback)
+    }
+    override def awaitAndVerify(member: GroupMember): Unit = {
+       val joinGroupResult = await(member, DefaultRebalanceTimeout)
+       assertEquals(Errors.NONE, joinGroupResult.error)
+       member.memberId = joinGroupResult.memberId
+       member.generationId = joinGroupResult.generationId
+    }
+  }
+
+  class SyncGroupOperation extends GroupOperation[SyncGroupCallbackParams, SyncGroupCallback] {
+    override def responseCallback(responsePromise: Promise[SyncGroupCallbackParams]): SyncGroupCallback = {
+      val callback: SyncGroupCallback = (assignment, error) =>
+        responsePromise.success((assignment, error))
+      callback
+    }
+    override def runWithCallback(member: GroupMember, responseCallback: SyncGroupCallback): Unit = {
+      if (member.leader) {
+        groupCoordinator.handleSyncGroup(member.groupId, member.generationId, member.memberId,
+            member.group.assignment, responseCallback)
+      } else {
+         groupCoordinator.handleSyncGroup(member.groupId, member.generationId, member.memberId,
+             Map.empty[String, Array[Byte]], responseCallback)
+      }
+    }
+    override def awaitAndVerify(member: GroupMember): Unit = {
+       val result = await(member, DefaultSessionTimeout)
+       assertEquals(Errors.NONE, result._2)
+    }
+  }
+
+  class HeartbeatOperation extends GroupOperation[HeartbeatCallbackParams, HeartbeatCallback] {
+    override def responseCallback(responsePromise: Promise[HeartbeatCallbackParams]): HeartbeatCallback = {
+      val callback: HeartbeatCallback = error => responsePromise.success(error)
+      callback
+    }
+    override def runWithCallback(member: GroupMember, responseCallback: HeartbeatCallback): Unit = {
+      groupCoordinator.handleHeartbeat( member.groupId, member.memberId,  member.generationId, responseCallback)
+    }
+    override def awaitAndVerify(member: GroupMember): Unit = {
+       val error = await(member, DefaultSessionTimeout)
+       assertEquals(Errors.NONE, error)
+    }
+  }
+  class CommitOffsetsOperation extends GroupOperation[CommitOffsetCallbackParams, CommitOffsetCallback] {
+    override def responseCallback(responsePromise: Promise[CommitOffsetCallbackParams]): CommitOffsetCallback = {
+      val callback: CommitOffsetCallback = offsets => responsePromise.success(offsets)
+      callback
+    }
+    override def runWithCallback(member: GroupMember, responseCallback: CommitOffsetCallback): Unit = {
+      val tp = new TopicPartition("topic", 0)
+      val offsets = immutable.Map(tp -> OffsetAndMetadata(1))
+      groupCoordinator.handleCommitOffsets(member.groupId, member.memberId, member.generationId,
+          offsets, responseCallback)
+    }
+    override def awaitAndVerify(member: GroupMember): Unit = {
+       val offsets = await(member, 500)
+       offsets.foreach { case (_, error) => assertEquals(Errors.NONE, error) }
+    }
+  }
+
+  class CommitTxnOffsetsOperation extends CommitOffsetsOperation {
+    override def runWithCallback(member: GroupMember, responseCallback: CommitOffsetCallback): Unit = {
+      val tp = new TopicPartition("topic", 0)
+      val offsets = immutable.Map(tp -> OffsetAndMetadata(1))
+      val producerId = 1000L
+      val producerEpoch : Short = 2
+      groupCoordinator.handleTxnCommitOffsets(member.group.groupId,
+          producerId, producerEpoch, offsets, responseCallback)
+    }
+  }
+
+  class CompleteTxnOperation extends GroupOperation[CompleteTxnCallbackParams, CompleteTxnCallback] {
+    override def responseCallback(responsePromise: Promise[CompleteTxnCallbackParams]): CompleteTxnCallback = {
+      val callback: CompleteTxnCallback = error => responsePromise.success(error)
+      callback
+    }
+    override def runWithCallback(member: GroupMember, responseCallback: CompleteTxnCallback): Unit = {
+      val producerId = 1000L
+      val offsetsPartitions = (0 to numPartitions).map(new TopicPartition(Topic.GROUP_METADATA_TOPIC_NAME, _))
+      groupCoordinator.handleTxnCompletion(producerId, offsetsPartitions, transactionResult(member.group.groupId))
+      responseCallback(Errors.NONE)
+    }
+    override def awaitAndVerify(member: GroupMember): Unit = {
+      val error = await(member, 500)
+      assertEquals(Errors.NONE, error)
+    }
+    // Test both commit and abort. Group ids used in the test have the format <prefix><index>
+    // Use the last digit of the index to decide between commit and abort.
+    private def transactionResult(groupId: String): TransactionResult = {
+      val lastDigit = groupId(groupId.length - 1).toInt
+      if (lastDigit % 2 == 0) TransactionResult.COMMIT else TransactionResult.ABORT
+    }
+  }
+
+  class LeaveGroupOperation extends GroupOperation[LeaveGroupCallbackParams, LeaveGroupCallback] {
+    override def responseCallback(responsePromise: Promise[LeaveGroupCallbackParams]): LeaveGroupCallback = {
+      val callback: LeaveGroupCallback = error => responsePromise.success(error)
+      callback
+    }
+    override def runWithCallback(member: GroupMember, responseCallback: LeaveGroupCallback): Unit = {
+      groupCoordinator.handleLeaveGroup(member.group.groupId, member.memberId, responseCallback)
+    }
+    override def awaitAndVerify(member: GroupMember): Unit = {
+       val error = await(member, DefaultSessionTimeout)
+       assertEquals(Errors.NONE, error)
+    }
+  }
+}
+
+object GroupCoordinatorConcurrencyTest {
+
+
+  type JoinGroupCallback = JoinGroupResult => Unit
+  type SyncGroupCallbackParams = (Array[Byte], Errors)
+  type SyncGroupCallback = (Array[Byte], Errors) => Unit
+  type HeartbeatCallbackParams = Errors
+  type HeartbeatCallback = Errors => Unit
+  type CommitOffsetCallbackParams = Map[TopicPartition, Errors]
+  type CommitOffsetCallback = Map[TopicPartition, Errors] => Unit
+  type LeaveGroupCallbackParams = Errors
+  type LeaveGroupCallback = Errors => Unit
+  type CompleteTxnCallbackParams = Errors
+  type CompleteTxnCallback = Errors => Unit
+
+  private val ConsumerMinSessionTimeout = 10
+  private val ConsumerMaxSessionTimeout = 120 * 1000
+  private val DefaultRebalanceTimeout = 60 * 1000
+  private val DefaultSessionTimeout = 60 * 1000
+  private val GroupInitialRebalanceDelay = 50
+
+  class Group(val groupId: String, nMembers: Int,
+      groupCoordinator: GroupCoordinator, replicaManager: TestReplicaManager) {
+    val groupPartitionId = groupCoordinator.partitionFor(groupId)
+    groupCoordinator.groupManager.addPartitionOwnership(groupPartitionId)
+    val members = (0 until nMembers).map { i =>
+      new GroupMember(this, groupPartitionId, i == 0)
+    }
+    def assignment = members.map { m => (m.memberId, Array[Byte]()) }.toMap
+  }
+
+  class GroupMember(val group: Group, val groupPartitionId: Int, val leader: Boolean) extends CoordinatorMember {
+    @volatile var memberId: String = JoinGroupRequest.UNKNOWN_MEMBER_ID
+    @volatile var generationId: Int = -1
+    def groupId: String = group.groupId
+  }
+}
diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala
new file mode 100644
index 00000000000..046741afa1e
--- /dev/null
+++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala
@@ -0,0 +1,388 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package kafka.coordinator.transaction
+
+import java.nio.ByteBuffer
+
+import kafka.coordinator.AbstractCoordinatorConcurrencyTest
+import kafka.coordinator.AbstractCoordinatorConcurrencyTest._
+import kafka.coordinator.transaction.TransactionCoordinatorConcurrencyTest._
+import kafka.log.Log
+import kafka.server.{ DelayedOperationPurgatory, FetchDataInfo, KafkaConfig, LogOffsetMetadata, MetadataCache }
+import kafka.utils.timer.MockTimer
+import kafka.utils.{ Pool, TestUtils}
+
+import org.apache.kafka.clients.{ ClientResponse, NetworkClient }
+import org.apache.kafka.common.{ Node, TopicPartition }
+import org.apache.kafka.common.internals.Topic.TRANSACTION_STATE_TOPIC_NAME
+import org.apache.kafka.common.protocol.{ ApiKeys, Errors }
+import org.apache.kafka.common.record.{ CompressionType, FileRecords, MemoryRecords, SimpleRecord }
+import org.apache.kafka.common.requests._
+import org.apache.kafka.common.utils.{ LogContext, MockTime }
+
+import org.easymock.EasyMock
+import org.junit.Assert._
+import org.junit.{ After, Before, Test }
+
+import scala.collection.Map
+import scala.collection.mutable
+import scala.collection.JavaConverters._
+
+class TransactionCoordinatorConcurrencyTest extends AbstractCoordinatorConcurrencyTest[Transaction] {
+  private val nTransactions = nThreads * 10
+  private val coordinatorEpoch = 10
+  private val numPartitions = nThreads * 5
+
+  private val txnConfig = TransactionConfig()
+  private var transactionCoordinator: TransactionCoordinator = _
+  private var txnStateManager: TransactionStateManager = _
+  private var txnMarkerChannelManager: TransactionMarkerChannelManager = _
+
+  private val allOperations = Seq(
+      new InitProducerIdOperation,
+      new AddPartitionsToTxnOperation(Set(new TopicPartition("topic", 0))),
+      new EndTxnOperation)
+
+  private val allTransactions = mutable.Set[Transaction]()
+  private val txnRecordsByPartition: Map[Int, mutable.ArrayBuffer[SimpleRecord]] =
+    (0 until numPartitions).map { i => (i, mutable.ArrayBuffer[SimpleRecord]()) }.toMap
+
+  @Before
+  override def setUp() {
+    super.setUp()
+
+    EasyMock.expect(zkClient.getTopicPartitionCount(TRANSACTION_STATE_TOPIC_NAME))
+      .andReturn(Some(numPartitions))
+      .anyTimes()
+    EasyMock.replay(zkClient)
+
+    txnStateManager = new TransactionStateManager(0, zkClient, scheduler, replicaManager, txnConfig, time)
+    for (i <- 0 until numPartitions)
+      txnStateManager.addLoadedTransactionsToCache(i, coordinatorEpoch, new Pool[String, TransactionMetadata]())
+
+    val producerId = 11
+    val pidManager: ProducerIdManager = EasyMock.createNiceMock(classOf[ProducerIdManager])
+    EasyMock.expect(pidManager.generateProducerId())
+      .andReturn(producerId)
+      .anyTimes()
+    val txnMarkerPurgatory = new DelayedOperationPurgatory[DelayedTxnMarker]("txn-purgatory-name",
+      new MockTimer,
+      reaperEnabled = false)
+    val brokerNode = new Node(0, "host", 10)
+    val metadataCache = EasyMock.createNiceMock(classOf[MetadataCache])
+    EasyMock.expect(metadataCache.getPartitionLeaderEndpoint(
+      EasyMock.anyString(),
+      EasyMock.anyInt(),
+      EasyMock.anyObject())
+    ).andReturn(Some(brokerNode)).anyTimes()
+    val networkClient = EasyMock.createNiceMock(classOf[NetworkClient])
+    txnMarkerChannelManager = new TransactionMarkerChannelManager(
+      KafkaConfig.fromProps(serverProps),
+      metadataCache,
+      networkClient,
+      txnStateManager,
+      txnMarkerPurgatory,
+      time) {
+        override def shutdown(): Unit = {
+          txnMarkerPurgatory.shutdown()
+        }
+    }
+
+    transactionCoordinator = new TransactionCoordinator(brokerId = 0,
+      txnConfig,
+      scheduler,
+      pidManager,
+      txnStateManager,
+      txnMarkerChannelManager,
+      time,
+      new LogContext)
+    EasyMock.replay(pidManager)
+    EasyMock.replay(metadataCache)
+    EasyMock.replay(networkClient)
+  }
+
+  @After
+  override def tearDown() {
+    try {
+      EasyMock.reset(zkClient, replicaManager)
+      transactionCoordinator.shutdown()
+    } finally {
+      super.tearDown()
+    }
+  }
+
+  @Test
+  def testConcurrentGoodPathSequence(): Unit = {
+    verifyConcurrentOperations(createTransactions, allOperations)
+  }
+
+  @Test
+  def testConcurrentRandomSequences(): Unit = {
+    verifyConcurrentRandomSequences(createTransactions, allOperations)
+  }
+
+  /**
+    * Concurrently load one set of transaction state topic partitions and unload another
+    * set of partitions. This tests partition leader changes of transaction state topic
+    * that are handled by different threads concurrently. Verifies that the metadata of
+    * unloaded partitions are removed from the transaction manager and that the transactions
+    * from the newly loaded partitions are loaded correctly.
+    */
+  @Test
+  def testConcurrentLoadUnloadPartitions(): Unit = {
+    val partitionsToLoad = (0 until numPartitions / 2).toSet
+    val partitionsToUnload = (numPartitions / 2 until numPartitions).toSet
+    verifyConcurrentActions(loadUnloadActions(partitionsToLoad, partitionsToUnload))
+  }
+
+  /**
+    * Concurrently load one set of transaction state topic partitions, unload a second set
+    * of partitions and expire transactions on a third set of partitions. This tests partition
+    * leader changes of transaction state topic that are handled by different threads concurrently
+    * while expiry is performed on another thread. Verifies the state of transactions on all the partitions.
+    */
+  @Test
+  def testConcurrentTransactionExpiration(): Unit = {
+    val partitionsToLoad = (0 until numPartitions / 3).toSet
+    val partitionsToUnload = (numPartitions / 3 until numPartitions * 2 / 3).toSet
+    val partitionsWithExpiringTxn = (numPartitions * 2 / 3 until numPartitions).toSet
+    val expiringTransactions = allTransactions.filter { txn =>
+      partitionsWithExpiringTxn.contains(txnStateManager.partitionFor(txn.transactionalId))
+    }.toSet
+    val expireAction = new ExpireTransactionsAction(expiringTransactions)
+    verifyConcurrentActions(loadUnloadActions(partitionsToLoad, partitionsToUnload) + expireAction)
+  }
+
+  override def enableCompletion(): Unit = {
+    super.enableCompletion()
+
+    def createResponse(request: WriteTxnMarkersRequest): WriteTxnMarkersResponse  = {
+      val pidErrorMap = request.markers.asScala.map { marker =>
+        (marker.producerId.asInstanceOf[java.lang.Long], marker.partitions.asScala.map { tp => (tp, Errors.NONE) }.toMap.asJava)
+      }.toMap.asJava
+      new WriteTxnMarkersResponse(pidErrorMap)
+    }
+    synchronized {
+      txnMarkerChannelManager.generateRequests().foreach { requestAndHandler =>
+        val request = requestAndHandler.request.asInstanceOf[WriteTxnMarkersRequest.Builder].build()
+        val response = createResponse(request)
+        requestAndHandler.handler.onComplete(new ClientResponse(new RequestHeader(ApiKeys.PRODUCE, 0, "client", 1),
+          null, null, 0, 0, false, null, response))
+      }
+    }
+  }
+
+  /**
+    * Concurrently load `partitionsToLoad` and unload `partitionsToUnload`. Before the concurrent operations
+    * are run `partitionsToLoad` must be unloaded first since all partitions were loaded during setUp.
+    */
+  private def loadUnloadActions(partitionsToLoad: Set[Int], partitionsToUnload: Set[Int]): Set[Action] = {
+    val transactions = (1 to 10).flatMap(i => createTransactions(s"testConcurrentLoadUnloadPartitions$i-")).toSet
+    transactions.foreach(txn => prepareTransaction(txn))
+    val unload = partitionsToLoad.map(new UnloadTxnPartitionAction(_))
+    unload.foreach(_.run())
+    unload.foreach(_.await())
+    partitionsToLoad.map(new LoadTxnPartitionAction(_)) ++ partitionsToUnload.map(new UnloadTxnPartitionAction(_))
+  }
+
+  private def createTransactions(txnPrefix: String): Set[Transaction] = {
+    val transactions = (0 until nTransactions).map { i => new Transaction(s"$txnPrefix$i", i, time) }
+    allTransactions ++= transactions
+    transactions.toSet
+  }
+
+  private def verifyTransaction(txn: Transaction, expectedState: TransactionState): Unit = {
+    val (metadata, success) = TestUtils.computeUntilTrue({
+      enableCompletion()
+      transactionMetadata(txn)
+    })(metadata => metadata.nonEmpty && metadata.forall(m => m.state == expectedState && m.pendingState.isEmpty))
+    assertTrue(s"Invalid metadata state $metadata", success)
+  }
+
+  private def transactionMetadata(txn: Transaction): Option[TransactionMetadata] = {
+    txnStateManager.getTransactionState(txn.transactionalId) match {
+      case Left(error) =>
+        if (error == Errors.NOT_COORDINATOR)
+          None
+        else
+          throw new AssertionError(s"Unexpected transaction error $error for $txn")
+      case Right(Some(metadata)) =>
+        Some(metadata.transactionMetadata)
+      case Right(None) =>
+        None
+    }
+  }
+
+  private def prepareTransaction(txn: Transaction): Unit = {
+    val partitionId = txnStateManager.partitionFor(txn.transactionalId)
+    val txnRecords = txnRecordsByPartition(partitionId)
+    val initPidOp = new InitProducerIdOperation()
+    val addPartitionsOp = new AddPartitionsToTxnOperation(Set(new TopicPartition("topic", 0)))
+      initPidOp.run(txn)
+      initPidOp.awaitAndVerify(txn)
+      addPartitionsOp.run(txn)
+      addPartitionsOp.awaitAndVerify(txn)
+
+      val txnMetadata = transactionMetadata(txn).getOrElse(throw new IllegalStateException(s"Transaction not found $txn"))
+      txnRecords += new SimpleRecord(txn.txnMessageKeyBytes, TransactionLog.valueToBytes(txnMetadata.prepareNoTransit()))
+
+      txnMetadata.state = PrepareCommit
+      txnRecords += new SimpleRecord(txn.txnMessageKeyBytes, TransactionLog.valueToBytes(txnMetadata.prepareNoTransit()))
+
+      prepareTxnLog(partitionId)
+  }
+
+  private def prepareTxnLog(partitionId: Int): Unit = {
+
+    val logMock =  EasyMock.mock(classOf[Log])
+    val fileRecordsMock = EasyMock.mock(classOf[FileRecords])
+
+    val topicPartition = new TopicPartition(TRANSACTION_STATE_TOPIC_NAME, partitionId)
+    val startOffset = replicaManager.getLogEndOffset(topicPartition).getOrElse(20L)
+    val records = MemoryRecords.withRecords(startOffset, CompressionType.NONE, txnRecordsByPartition(partitionId): _*)
+    val endOffset = startOffset + records.records.asScala.size
+
+    EasyMock.expect(logMock.logStartOffset).andStubReturn(startOffset)
+    EasyMock.expect(logMock.read(EasyMock.eq(startOffset), EasyMock.anyInt(), EasyMock.eq(None),
+      EasyMock.eq(true), EasyMock.eq(IsolationLevel.READ_UNCOMMITTED)))
+      .andReturn(FetchDataInfo(LogOffsetMetadata(startOffset), fileRecordsMock))
+    EasyMock.expect(fileRecordsMock.readInto(EasyMock.anyObject(classOf[ByteBuffer]), EasyMock.anyInt()))
+      .andReturn(records.buffer)
+
+    EasyMock.replay(logMock, fileRecordsMock)
+    synchronized {
+      replicaManager.updateLog(topicPartition, logMock, endOffset)
+    }
+  }
+
+  abstract class TxnOperation[R] extends Operation {
+    @volatile var result: Option[R] = None
+    def resultCallback(r: R): Unit = this.result = Some(r)
+  }
+
+  class InitProducerIdOperation extends TxnOperation[InitProducerIdResult] {
+    override def run(txn: Transaction): Unit = {
+      transactionCoordinator.handleInitProducerId(txn.transactionalId, 60000, resultCallback)
+    }
+    override def awaitAndVerify(txn: Transaction): Unit = {
+      val initPidResult = result.getOrElse(throw new IllegalStateException("InitProducerId has not completed"))
+      assertEquals(Errors.NONE, initPidResult.error)
+      verifyTransaction(txn, Empty)
+    }
+  }
+
+  class AddPartitionsToTxnOperation(partitions: Set[TopicPartition]) extends TxnOperation[Errors] {
+    override def run(txn: Transaction): Unit = {
+      transactionMetadata(txn).foreach { txnMetadata =>
+        transactionCoordinator.handleAddPartitionsToTransaction(txn.transactionalId,
+            txnMetadata.producerId,
+            txnMetadata.producerEpoch,
+            partitions,
+            resultCallback)
+      }
+    }
+    override def awaitAndVerify(txn: Transaction): Unit = {
+      val error = result.getOrElse(throw new IllegalStateException("AddPartitionsToTransaction has not completed"))
+      assertEquals(Errors.NONE, error)
+      verifyTransaction(txn, Ongoing)
+    }
+  }
+
+  class EndTxnOperation extends TxnOperation[Errors] {
+    override def run(txn: Transaction): Unit = {
+      transactionMetadata(txn).foreach { txnMetadata =>
+        transactionCoordinator.handleEndTransaction(txn.transactionalId,
+          txnMetadata.producerId,
+          txnMetadata.producerEpoch,
+          transactionResult(txn),
+          resultCallback)
+      }
+    }
+    override def awaitAndVerify(txn: Transaction): Unit = {
+      val error = result.getOrElse(throw new IllegalStateException("EndTransaction has not completed"))
+      if (!txn.ended) {
+        txn.ended = true
+        assertEquals(Errors.NONE, error)
+        val expectedState = if (transactionResult(txn) == TransactionResult.COMMIT) CompleteCommit else CompleteAbort
+        verifyTransaction(txn, expectedState)
+      } else
+        assertEquals(Errors.INVALID_TXN_STATE, error)
+    }
+    // Test both commit and abort. Transactional ids used in the test have the format <prefix><index>
+    // Use the last digit of the index to decide between commit and abort.
+    private def transactionResult(txn: Transaction): TransactionResult = {
+      val txnId = txn.transactionalId
+      val lastDigit = txnId(txnId.length - 1).toInt
+      if (lastDigit % 2 == 0) TransactionResult.COMMIT else TransactionResult.ABORT
+    }
+  }
+
+  class LoadTxnPartitionAction(txnTopicPartitionId: Int) extends Action {
+    override def run(): Unit = {
+      transactionCoordinator.handleTxnImmigration(txnTopicPartitionId, coordinatorEpoch)
+    }
+    override def await(): Unit = {
+      allTransactions.foreach { txn =>
+        if (txnStateManager.partitionFor(txn.transactionalId) == txnTopicPartitionId) {
+          verifyTransaction(txn, CompleteCommit)
+        }
+      }
+    }
+  }
+
+  class UnloadTxnPartitionAction(txnTopicPartitionId: Int) extends Action {
+    val txnRecords: mutable.ArrayBuffer[SimpleRecord] = mutable.ArrayBuffer[SimpleRecord]()
+    override def run(): Unit = {
+      transactionCoordinator.handleTxnEmigration(txnTopicPartitionId, coordinatorEpoch)
+    }
+    override def await(): Unit = {
+      allTransactions.foreach { txn =>
+        if (txnStateManager.partitionFor(txn.transactionalId) == txnTopicPartitionId)
+          assertTrue("Transaction metadata not removed", transactionMetadata(txn).isEmpty)
+      }
+    }
+  }
+
+  class ExpireTransactionsAction(transactions: Set[Transaction]) extends Action {
+    override def run(): Unit = {
+      transactions.foreach { txn =>
+        transactionMetadata(txn).foreach { txnMetadata =>
+          txnMetadata.txnLastUpdateTimestamp = time.milliseconds() - txnConfig.transactionalIdExpirationMs
+        }
+      }
+      txnStateManager.enableTransactionalIdExpiration()
+      time.sleep(txnConfig.removeExpiredTransactionalIdsIntervalMs + 1)
+    }
+
+    override def await(): Unit = {
+      val (_, success) = TestUtils.computeUntilTrue({
+        replicaManager.tryCompleteDelayedRequests()
+        transactions.forall(txn => transactionMetadata(txn).isEmpty)
+      })(identity)
+      assertTrue("Transaction not expired", success)
+    }
+  }
+}
+
+object TransactionCoordinatorConcurrencyTest {
+
+  class Transaction(val transactionalId: String, producerId: Long, time: MockTime) extends CoordinatorMember {
+    val txnMessageKeyBytes: Array[Byte] = TransactionLog.keyToBytes(transactionalId)
+    @volatile var ended = false
+    override def toString: String = transactionalId
+  }
+}
diff --git a/core/src/test/scala/unit/kafka/utils/timer/MockTimer.scala b/core/src/test/scala/unit/kafka/utils/timer/MockTimer.scala
index e4ac4fa1ec7..17ee578f6a5 100644
--- a/core/src/test/scala/unit/kafka/utils/timer/MockTimer.scala
+++ b/core/src/test/scala/unit/kafka/utils/timer/MockTimer.scala
@@ -28,8 +28,11 @@ class MockTimer extends Timer {
   def add(timerTask: TimerTask) {
     if (timerTask.delayMs <= 0)
       timerTask.run()
-    else
-      taskQueue.enqueue(new TimerTaskEntry(timerTask, timerTask.delayMs + time.milliseconds))
+    else {
+      taskQueue synchronized {
+        taskQueue.enqueue(new TimerTaskEntry(timerTask, timerTask.delayMs + time.milliseconds))
+      }
+    }
   }
 
   def advanceClock(timeoutMs: Long): Boolean = {
@@ -38,15 +41,25 @@ class MockTimer extends Timer {
     var executed = false
     val now = time.milliseconds
 
-    while (taskQueue.nonEmpty && now > taskQueue.head.expirationMs) {
-      val taskEntry = taskQueue.dequeue()
-      if (!taskEntry.cancelled) {
-        val task = taskEntry.timerTask
-        task.run()
-        executed = true
+    var hasMore = true
+    while (hasMore) {
+      hasMore = false
+      val head = taskQueue synchronized {
+        if (taskQueue.nonEmpty && now > taskQueue.head.expirationMs) {
+          val entry = Some(taskQueue.dequeue())
+          hasMore = taskQueue.nonEmpty
+          entry
+        } else
+          None
+      }
+      head.foreach { taskEntry =>
+        if (!taskEntry.cancelled) {
+          val task = taskEntry.timerTask
+          task.run()
+          executed = true
+        }
       }
     }
-
     executed
   }
 


 

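A note for readers skimming the archived diff: the core pattern of the new harness is that each Operation is turned into Actions that are submitted to a fixed thread pool, the futures are drained, delayed operations are completed via enableCompletion(), and only then is each action's await() used to verify the outcome. The following is a minimal, self-contained sketch of that submit-then-verify pattern; the names ConcurrencyHarnessSketch and verifyConcurrently are illustrative only and do not appear in the Kafka code base.

import java.util.concurrent.{Executors, TimeUnit}
import java.util.concurrent.atomic.AtomicInteger

object ConcurrencyHarnessSketch {

  // An Action runs on a worker thread; await() verifies its outcome afterwards,
  // mirroring the Action trait in AbstractCoordinatorConcurrencyTest above.
  trait Action extends Runnable {
    def await(): Unit
  }

  def verifyConcurrently(actions: Set[Action], nThreads: Int = 5): Unit = {
    val executor = Executors.newFixedThreadPool(nThreads)
    try {
      // Submit all actions so they interleave on the thread pool ...
      val futures = actions.map(a => executor.submit(a))
      futures.foreach(_.get(10, TimeUnit.SECONDS))
      // ... and verify each outcome only after every submission has completed.
      actions.foreach(_.await())
    } finally {
      executor.shutdownNow()
    }
  }

  def main(args: Array[String]): Unit = {
    val counter = new AtomicInteger()
    val actions: Set[Action] = (1 to 10).map { _ =>
      (new Action {
        override def run(): Unit = counter.incrementAndGet()
        override def await(): Unit =
          assert(counter.get() == 10, s"expected 10 increments before verification, got ${counter.get()}")
      }): Action
    }.toSet
    verifyConcurrently(actions)
    println(s"verified ${actions.size} concurrent actions, counter = ${counter.get()}")
  }
}

This mirrors verifyConcurrentActions in the diff above, where enableCompletion() additionally drives the delayed-produce purgatory and the mock scheduler between submission and verification.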
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


> Add concurrent tests to exercise all paths in group/transaction managers
> ------------------------------------------------------------------------
>
>                 Key: KAFKA-6096
>                 URL: https://issues.apache.org/jira/browse/KAFKA-6096
>             Project: Kafka
>          Issue Type: Test
>          Components: core
>            Reporter: Rajini Sivaram
>            Assignee: Rajini Sivaram
>             Fix For: 1.1.0
>
>
> We don't have enough coverage of locking and deadlock scenarios in GroupMetadataManager and TransactionManager. Since we have hit several deadlocks (KAFKA-5970, KAFKA-6042, etc.) that were not detected during testing, we should add more mock-based tests with concurrency to verify the locking.
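As a rough illustration of the kind of test the issue asks for (hypothetical code, not taken from the Kafka code base): the point of running mocked operations from several threads with bounded waits is that a lock-ordering bug surfaces as a failed assertion or timeout instead of a hung build. The lock names and the DeadlockProbeSketch object below are invented for the example; the real tests exercise GroupCoordinator and TransactionCoordinator rather than bare locks.

import java.util.concurrent.{CountDownLatch, TimeUnit}
import java.util.concurrent.atomic.AtomicInteger
import java.util.concurrent.locks.ReentrantLock

object DeadlockProbeSketch {
  def main(args: Array[String]): Unit = {
    val lockA = new ReentrantLock()
    val lockB = new ReentrantLock()
    val start = new CountDownLatch(1)  // release both workers at once
    val successes = new AtomicInteger()

    // Each worker takes `first` then `second` with a bounded tryLock, so a
    // lock-ordering deadlock shows up as a missing success rather than a hang.
    def worker(first: ReentrantLock, second: ReentrantLock): Runnable = new Runnable {
      override def run(): Unit = {
        start.await()
        if (first.tryLock(1, TimeUnit.SECONDS)) {
          try {
            if (second.tryLock(1, TimeUnit.SECONDS)) {
              try successes.incrementAndGet() finally second.unlock()
            }
          } finally first.unlock()
        }
      }
    }

    // Both threads use the same lock order here, so no deadlock is expected;
    // flipping the order for one of them would make the assertion fail.
    val threads = Seq(new Thread(worker(lockA, lockB)), new Thread(worker(lockA, lockB)))
    threads.foreach(_.start())
    start.countDown()
    threads.foreach(_.join(5000))
    assert(successes.get() == 2, "lock acquisition timed out -- possible deadlock")
    println("no deadlock detected with consistent lock ordering")
  }
}

In the merged PR itself, the same effect is achieved with bounded retries: GroupOperation.await advances the MockTimer between attempts, and the transaction tests poll with TestUtils.computeUntilTrue, so a stuck operation fails the test instead of blocking it indefinitely.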



--
This message was sent by Atlassian JIRA
(v6.4.14#64029)