Posted to commits@kafka.apache.org by gu...@apache.org on 2016/08/17 18:50:08 UTC

[1/4] kafka git commit: KAFKA-3888: send consumer heartbeats from a background thread (KIP-62)

Repository: kafka
Updated Branches:
  refs/heads/trunk 19997ede0 -> 40b1dd3f4

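KIP-62 moves consumer heartbeating into a background thread and splits the old single session timeout into two settings: session.timeout.ms, which now only bounds missed heartbeats, and max.poll.interval.ms, which bounds the time between calls to poll(). That is why the diffs below thread a rebalanceTimeoutMs alongside sessionTimeoutMs through the JoinGroup path. As a rough sketch of the client-side configuration this implies (not part of the patch; the broker address and values are illustrative only), a consumer might now be set up as:

    import java.util.Properties
    import org.apache.kafka.clients.consumer.{ConsumerConfig, KafkaConsumer}
    import org.apache.kafka.common.serialization.ByteArrayDeserializer

    val props = new Properties()
    props.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092") // illustrative broker address
    props.put(ConsumerConfig.GROUP_ID_CONFIG, "my-test")
    // liveness of the process, enforced via heartbeats sent from the background thread
    props.put(ConsumerConfig.SESSION_TIMEOUT_MS_CONFIG, "10000")
    props.put(ConsumerConfig.HEARTBEAT_INTERVAL_MS_CONFIG, "3000")
    // progress of the poll loop; exceeding it makes the member fall out of the group
    props.put(ConsumerConfig.MAX_POLL_INTERVAL_MS_CONFIG, "300000")

    val consumer = new KafkaConsumer(props, new ByteArrayDeserializer(), new ByteArrayDeserializer())
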

http://git-wip-us.apache.org/repos/asf/kafka/blob/40b1dd3f/core/src/main/scala/kafka/coordinator/MemberMetadata.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/kafka/coordinator/MemberMetadata.scala b/core/src/main/scala/kafka/coordinator/MemberMetadata.scala
index 19c9e8e..6149276 100644
--- a/core/src/main/scala/kafka/coordinator/MemberMetadata.scala
+++ b/core/src/main/scala/kafka/coordinator/MemberMetadata.scala
@@ -55,6 +55,7 @@ private[coordinator] class MemberMetadata(val memberId: String,
                                           val groupId: String,
                                           val clientId: String,
                                           val clientHost: String,
+                                          val rebalanceTimeoutMs: Int,
                                           val sessionTimeoutMs: Int,
                                           val protocolType: String,
                                           var supportedProtocols: List[(String, Array[Byte])]) {

http://git-wip-us.apache.org/repos/asf/kafka/blob/40b1dd3f/core/src/main/scala/kafka/server/KafkaApis.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/kafka/server/KafkaApis.scala b/core/src/main/scala/kafka/server/KafkaApis.scala
index 6d38f85..bb219ca 100644
--- a/core/src/main/scala/kafka/server/KafkaApis.scala
+++ b/core/src/main/scala/kafka/server/KafkaApis.scala
@@ -890,8 +890,8 @@ class KafkaApis(val requestChannel: RequestChannel,
     // the callback for sending a join-group response
     def sendResponseCallback(joinResult: JoinGroupResult) {
       val members = joinResult.members map { case (memberId, metadataArray) => (memberId, ByteBuffer.wrap(metadataArray)) }
-      val responseBody = new JoinGroupResponse(joinResult.errorCode, joinResult.generationId, joinResult.subProtocol,
-        joinResult.memberId, joinResult.leaderId, members)
+      val responseBody = new JoinGroupResponse(request.header.apiVersion, joinResult.errorCode, joinResult.generationId,
+        joinResult.subProtocol, joinResult.memberId, joinResult.leaderId, members)
 
       trace("Sending join group response %s for correlation id %d to client %s."
         .format(responseBody, request.header.correlationId, request.header.clientId))
@@ -900,6 +900,7 @@ class KafkaApis(val requestChannel: RequestChannel,
 
     if (!authorize(request.session, Read, new Resource(Group, joinGroupRequest.groupId()))) {
       val responseBody = new JoinGroupResponse(
+        request.header.apiVersion,
         Errors.GROUP_AUTHORIZATION_FAILED.code,
         JoinGroupResponse.UNKNOWN_GENERATION_ID,
         JoinGroupResponse.UNKNOWN_PROTOCOL,
@@ -916,6 +917,7 @@ class KafkaApis(val requestChannel: RequestChannel,
         joinGroupRequest.memberId,
         request.header.clientId,
         request.session.clientAddress.toString,
+        joinGroupRequest.rebalanceTimeout,
         joinGroupRequest.sessionTimeout,
         joinGroupRequest.protocolType,
         protocols,

http://git-wip-us.apache.org/repos/asf/kafka/blob/40b1dd3f/core/src/test/scala/integration/kafka/api/AuthorizerIntegrationTest.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/integration/kafka/api/AuthorizerIntegrationTest.scala b/core/src/test/scala/integration/kafka/api/AuthorizerIntegrationTest.scala
index 817cdf7..1a5f187 100644
--- a/core/src/test/scala/integration/kafka/api/AuthorizerIntegrationTest.scala
+++ b/core/src/test/scala/integration/kafka/api/AuthorizerIntegrationTest.scala
@@ -199,7 +199,7 @@ class AuthorizerIntegrationTest extends BaseRequestTest {
   }
 
   private def createJoinGroupRequest = {
-    new JoinGroupRequest(group, 30000, "", "consumer",
+    new JoinGroupRequest(group, 10000, 60000, "", "consumer",
       List( new JoinGroupRequest.ProtocolMetadata("consumer-range",ByteBuffer.wrap("test".getBytes()))).asJava)
   }
 

http://git-wip-us.apache.org/repos/asf/kafka/blob/40b1dd3f/core/src/test/scala/integration/kafka/api/BaseConsumerTest.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/integration/kafka/api/BaseConsumerTest.scala b/core/src/test/scala/integration/kafka/api/BaseConsumerTest.scala
index f039750..c13bf58 100644
--- a/core/src/test/scala/integration/kafka/api/BaseConsumerTest.scala
+++ b/core/src/test/scala/integration/kafka/api/BaseConsumerTest.scala
@@ -13,6 +13,7 @@
 package kafka.api
 
 import java.util
+
 import org.apache.kafka.clients.consumer._
 import org.apache.kafka.clients.producer.{ProducerConfig, ProducerRecord}
 import org.apache.kafka.common.record.TimestampType
@@ -22,11 +23,14 @@ import kafka.utils.{TestUtils, Logging, ShutdownableThread}
 import kafka.common.Topic
 import kafka.server.KafkaConfig
 import java.util.ArrayList
+
 import org.junit.Assert._
 import org.junit.{Before, Test}
+
 import scala.collection.JavaConverters._
 import scala.collection.mutable.Buffer
 import org.apache.kafka.clients.producer.KafkaProducer
+import org.apache.kafka.common.errors.WakeupException
 
 /**
  * Integration tests for the new consumer that cover basic usage as well as server failures
@@ -82,112 +86,19 @@ abstract class BaseConsumerTest extends IntegrationTestHarness with Logging {
   }
 
   @Test
-  def testAutoCommitOnRebalance() {
-    val topic2 = "topic2"
-    TestUtils.createTopic(this.zkUtils, topic2, 2, serverCount, this.servers)
-
-    this.consumerConfig.setProperty(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "true")
-    val consumer0 = new KafkaConsumer(this.consumerConfig, new ByteArrayDeserializer(), new ByteArrayDeserializer())
-    consumers += consumer0
-
-    val numRecords = 10000
-    sendRecords(numRecords)
-
-    val rebalanceListener = new ConsumerRebalanceListener {
-      override def onPartitionsAssigned(partitions: util.Collection[TopicPartition]) = {
-        // keep partitions paused in this test so that we can verify the commits based on specific seeks
-        consumer0.pause(partitions)
-      }
-
-      override def onPartitionsRevoked(partitions: util.Collection[TopicPartition]) = {}
-    }
-
-    consumer0.subscribe(List(topic).asJava, rebalanceListener)
-
-    val assignment = Set(tp, tp2)
-    TestUtils.waitUntilTrue(() => {
-      consumer0.poll(50)
-      consumer0.assignment() == assignment.asJava
-    }, s"Expected partitions ${assignment.asJava} but actually got ${consumer0.assignment()}")
-
-    consumer0.seek(tp, 300)
-    consumer0.seek(tp2, 500)
-
-    // change subscription to trigger rebalance
-    consumer0.subscribe(List(topic, topic2).asJava, rebalanceListener)
-
-    val newAssignment = Set(tp, tp2, new TopicPartition(topic2, 0), new TopicPartition(topic2, 1))
-    TestUtils.waitUntilTrue(() => {
-      val records = consumer0.poll(50)
-      consumer0.assignment() == newAssignment.asJava
-    }, s"Expected partitions ${newAssignment.asJava} but actually got ${consumer0.assignment()}")
-
-    // after rebalancing, we should have reset to the committed positions
-    assertEquals(300, consumer0.committed(tp).offset)
-    assertEquals(500, consumer0.committed(tp2).offset)
-  }
-
-  @Test
-  def testCommitSpecifiedOffsets() {
-    sendRecords(5, tp)
-    sendRecords(7, tp2)
-
-    this.consumers.head.assign(List(tp, tp2).asJava)
-
-    // Need to poll to join the group
-    this.consumers.head.poll(50)
-    val pos1 = this.consumers.head.position(tp)
-    val pos2 = this.consumers.head.position(tp2)
-    this.consumers.head.commitSync(Map[TopicPartition, OffsetAndMetadata]((tp, new OffsetAndMetadata(3L))).asJava)
-    assertEquals(3, this.consumers.head.committed(tp).offset)
-    assertNull(this.consumers.head.committed(tp2))
-
-    // Positions should not change
-    assertEquals(pos1, this.consumers.head.position(tp))
-    assertEquals(pos2, this.consumers.head.position(tp2))
-    this.consumers.head.commitSync(Map[TopicPartition, OffsetAndMetadata]((tp2, new OffsetAndMetadata(5L))).asJava)
-    assertEquals(3, this.consumers.head.committed(tp).offset)
-    assertEquals(5, this.consumers.head.committed(tp2).offset)
-
-    // Using async should pick up the committed changes after commit completes
-    val commitCallback = new CountConsumerCommitCallback()
-    this.consumers.head.commitAsync(Map[TopicPartition, OffsetAndMetadata]((tp2, new OffsetAndMetadata(7L))).asJava, commitCallback)
-    awaitCommitCallback(this.consumers.head, commitCallback)
-    assertEquals(7, this.consumers.head.committed(tp2).offset)
-  }
-
-  @Test
-  def testListTopics() {
-    val numParts = 2
-    val topic1 = "part-test-topic-1"
-    val topic2 = "part-test-topic-2"
-    val topic3 = "part-test-topic-3"
-    TestUtils.createTopic(this.zkUtils, topic1, numParts, 1, this.servers)
-    TestUtils.createTopic(this.zkUtils, topic2, numParts, 1, this.servers)
-    TestUtils.createTopic(this.zkUtils, topic3, numParts, 1, this.servers)
-
-    val topics = this.consumers.head.listTopics()
-    assertNotNull(topics)
-    assertEquals(5, topics.size())
-    assertEquals(5, topics.keySet().size())
-    assertEquals(2, topics.get(topic1).size)
-    assertEquals(2, topics.get(topic2).size)
-    assertEquals(2, topics.get(topic3).size)
-  }
-
-  @Test
-  def testPartitionReassignmentCallback() {
+  def testCoordinatorFailover() {
     val listener = new TestConsumerReassignmentListener()
-    this.consumerConfig.setProperty(ConsumerConfig.SESSION_TIMEOUT_MS_CONFIG, "100") // timeout quickly to avoid slow test
-    this.consumerConfig.setProperty(ConsumerConfig.HEARTBEAT_INTERVAL_MS_CONFIG, "30")
+    this.consumerConfig.setProperty(ConsumerConfig.SESSION_TIMEOUT_MS_CONFIG, "5000")
+    this.consumerConfig.setProperty(ConsumerConfig.HEARTBEAT_INTERVAL_MS_CONFIG, "2000")
     val consumer0 = new KafkaConsumer(this.consumerConfig, new ByteArrayDeserializer(), new ByteArrayDeserializer())
     consumers += consumer0
 
     consumer0.subscribe(List(topic).asJava, listener)
 
     // the initial subscription should cause a callback execution
-    while (listener.callsToAssigned == 0)
-      consumer0.poll(50)
+    consumer0.poll(2000)
+
+    assertEquals(1, listener.callsToAssigned)
 
     // get metadata for the topic
     var parts: Seq[PartitionInfo] = null
@@ -200,54 +111,13 @@ abstract class BaseConsumerTest extends IntegrationTestHarness with Logging {
     val coordinator = parts.head.leader().id()
     this.servers(coordinator).shutdown()
 
-    // this should cause another callback execution
-    while (listener.callsToAssigned < 2)
-      consumer0.poll(50)
-
-    assertEquals(2, listener.callsToAssigned)
-
-    // only expect one revocation since revoke is not invoked on initial membership
-    assertEquals(2, listener.callsToRevoked)
-  }
-
-  @Test
-  def testUnsubscribeTopic() {
-
-    this.consumerConfig.setProperty(ConsumerConfig.SESSION_TIMEOUT_MS_CONFIG, "100") // timeout quickly to avoid slow test
-    this.consumerConfig.setProperty(ConsumerConfig.HEARTBEAT_INTERVAL_MS_CONFIG, "30")
-    val consumer0 = new KafkaConsumer(this.consumerConfig, new ByteArrayDeserializer(), new ByteArrayDeserializer())
-    consumers += consumer0
-
-    val listener = new TestConsumerReassignmentListener()
-    consumer0.subscribe(List(topic).asJava, listener)
-
-    // the initial subscription should cause a callback execution
-    while (listener.callsToAssigned == 0)
-      consumer0.poll(50)
+    consumer0.poll(5000)
 
-    consumer0.subscribe(List[String]().asJava)
-    assertEquals(0, consumer0.assignment.size())
+    // the failover should not cause a rebalance
+    assertEquals(1, listener.callsToAssigned)
+    assertEquals(1, listener.callsToRevoked)
   }
 
-  @Test
-  def testPauseStateNotPreservedByRebalance() {
-    this.consumerConfig.setProperty(ConsumerConfig.SESSION_TIMEOUT_MS_CONFIG, "100") // timeout quickly to avoid slow test
-    this.consumerConfig.setProperty(ConsumerConfig.HEARTBEAT_INTERVAL_MS_CONFIG, "30")
-    val consumer0 = new KafkaConsumer(this.consumerConfig, new ByteArrayDeserializer(), new ByteArrayDeserializer())
-    consumers += consumer0
-
-    sendRecords(5)
-    consumer0.subscribe(List(topic).asJava)
-    consumeAndVerifyRecords(consumer = consumer0, numRecords = 5, startingOffset = 0)
-    consumer0.pause(List(tp).asJava)
-
-    // subscribe to a new topic to trigger a rebalance
-    consumer0.subscribe(List("topic2").asJava)
-
-    // after rebalance, our position should be reset and our pause state lost,
-    // so we should be able to consume from the beginning
-    consumeAndVerifyRecords(consumer = consumer0, numRecords = 0, startingOffset = 5)
-  }
 
   protected class TestConsumerReassignmentListener extends ConsumerRebalanceListener {
     var callsToAssigned = 0
@@ -394,12 +264,22 @@ abstract class BaseConsumerTest extends IntegrationTestHarness with Logging {
       !subscriptionChanged
     }
 
+    override def initiateShutdown(): Boolean = {
+      val res = super.initiateShutdown()
+      consumer.wakeup()
+      res
+    }
+
     override def doWork(): Unit = {
       if (subscriptionChanged) {
         consumer.subscribe(topicsSubscription.asJava, rebalanceListener)
         subscriptionChanged = false
       }
-      consumer.poll(50)
+      try {
+        consumer.poll(50)
+      } catch {
+        case e: WakeupException => // ignore for shutdown
+      }
     }
   }
 

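The poller change above uses the standard wakeup pattern: the shutdown path calls consumer.wakeup() so that a blocking poll() raises WakeupException in the polling thread, which the loop then swallows and exits. A minimal standalone sketch of that pattern, outside the test harness (class name and structure are illustrative, not from the patch):

    import java.util.concurrent.atomic.AtomicBoolean
    import org.apache.kafka.clients.consumer.KafkaConsumer
    import org.apache.kafka.common.errors.WakeupException

    class PollingThread(consumer: KafkaConsumer[Array[Byte], Array[Byte]]) extends Thread {
      private val running = new AtomicBoolean(true)

      // called from another thread; wakeup() interrupts a poll() blocked on the network
      def initiateShutdown(): Unit = {
        running.set(false)
        consumer.wakeup()
      }

      override def run(): Unit = {
        try {
          while (running.get) {
            try consumer.poll(50)
            catch { case _: WakeupException => } // expected during shutdown, safe to ignore
          }
        } finally consumer.close()
      }
    }
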
http://git-wip-us.apache.org/repos/asf/kafka/blob/40b1dd3f/core/src/test/scala/integration/kafka/api/ConsumerBounceTest.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/integration/kafka/api/ConsumerBounceTest.scala b/core/src/test/scala/integration/kafka/api/ConsumerBounceTest.scala
index 7064052..0900d43 100644
--- a/core/src/test/scala/integration/kafka/api/ConsumerBounceTest.scala
+++ b/core/src/test/scala/integration/kafka/api/ConsumerBounceTest.scala
@@ -49,8 +49,8 @@ class ConsumerBounceTest extends IntegrationTestHarness with Logging {
   this.producerConfig.setProperty(ProducerConfig.ACKS_CONFIG, "all")
   this.consumerConfig.setProperty(ConsumerConfig.GROUP_ID_CONFIG, "my-test")
   this.consumerConfig.setProperty(ConsumerConfig.MAX_PARTITION_FETCH_BYTES_CONFIG, 4096.toString)
-  this.consumerConfig.setProperty(ConsumerConfig.SESSION_TIMEOUT_MS_CONFIG, "100")
-  this.consumerConfig.setProperty(ConsumerConfig.HEARTBEAT_INTERVAL_MS_CONFIG, "30")
+  this.consumerConfig.setProperty(ConsumerConfig.SESSION_TIMEOUT_MS_CONFIG, "10000")
+  this.consumerConfig.setProperty(ConsumerConfig.HEARTBEAT_INTERVAL_MS_CONFIG, "3000")
   this.consumerConfig.setProperty(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest")
 
   override def generateConfigs() = {
@@ -81,14 +81,7 @@ class ConsumerBounceTest extends IntegrationTestHarness with Logging {
     var consumed = 0L
     val consumer = this.consumers.head
 
-    consumer.subscribe(List(topic), new ConsumerRebalanceListener {
-      override def onPartitionsAssigned(partitions: util.Collection[TopicPartition]) {
-        // TODO: until KAFKA-2017 is merged, we have to handle the case in which
-        // the commit fails on prior to rebalancing on coordinator fail-over.
-        consumer.seek(tp, consumed)
-      }
-      override def onPartitionsRevoked(partitions: util.Collection[TopicPartition]) {}
-    })
+    consumer.subscribe(List(topic))
 
     val scheduler = new BounceBrokerScheduler(numIters)
     scheduler.start()

http://git-wip-us.apache.org/repos/asf/kafka/blob/40b1dd3f/core/src/test/scala/integration/kafka/api/PlaintextConsumerTest.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/integration/kafka/api/PlaintextConsumerTest.scala b/core/src/test/scala/integration/kafka/api/PlaintextConsumerTest.scala
index b1e9676..243f913 100644
--- a/core/src/test/scala/integration/kafka/api/PlaintextConsumerTest.scala
+++ b/core/src/test/scala/integration/kafka/api/PlaintextConsumerTest.scala
@@ -58,6 +58,31 @@ class PlaintextConsumerTest extends BaseConsumerTest {
   }
 
   @Test
+  def testMaxPollIntervalMs() {
+    this.consumerConfig.setProperty(ConsumerConfig.MAX_POLL_INTERVAL_MS_CONFIG, 3000.toString)
+    this.consumerConfig.setProperty(ConsumerConfig.HEARTBEAT_INTERVAL_MS_CONFIG, 500.toString)
+    this.consumerConfig.setProperty(ConsumerConfig.SESSION_TIMEOUT_MS_CONFIG, 2000.toString)
+
+    val consumer0 = new KafkaConsumer(this.consumerConfig, new ByteArrayDeserializer(), new ByteArrayDeserializer())
+    consumers += consumer0
+
+    val listener = new TestConsumerReassignmentListener()
+    consumer0.subscribe(List(topic).asJava, listener)
+
+    // poll once to get the initial assignment
+    consumer0.poll(0)
+    assertEquals(1, listener.callsToAssigned)
+    assertEquals(1, listener.callsToRevoked)
+
+    Thread.sleep(3500)
+
+    // we should fall out of the group and need to rebalance
+    consumer0.poll(0)
+    assertEquals(2, listener.callsToAssigned)
+    assertEquals(2, listener.callsToRevoked)
+  }
+
+  @Test
   def testAutoCommitOnClose() {
     this.consumerConfig.setProperty(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "true")
     val consumer0 = new KafkaConsumer(this.consumerConfig, new ByteArrayDeserializer(), new ByteArrayDeserializer())
@@ -593,16 +618,14 @@ class PlaintextConsumerTest extends BaseConsumerTest {
     // create a group of consumers, subscribe the consumers to all the topics and start polling
     // for the topic partition assignment
     val (rrConsumers, consumerPollers) = createConsumerGroupAndWaitForAssignment(10, List(topic1, topic2), subscriptions)
+    try {
+      validateGroupAssignment(consumerPollers, subscriptions, s"Did not get valid initial assignment for partitions ${subscriptions.asJava}")
 
-    // add one more consumer and validate re-assignment
-    addConsumersToGroupAndWaitForGroupAssignment(1, consumers, consumerPollers, List(topic1, topic2), subscriptions)
-
-    // done with pollers and consumers
-    for (poller <- consumerPollers)
-      poller.shutdown()
-
-    for (consumer <- consumers)
-      consumer.unsubscribe()
+      // add one more consumer and validate re-assignment
+      addConsumersToGroupAndWaitForGroupAssignment(1, consumers, consumerPollers, List(topic1, topic2), subscriptions)
+    } finally {
+      consumerPollers.foreach(_.shutdown())
+    }
   }
 
   /**
@@ -618,25 +641,25 @@ class PlaintextConsumerTest extends BaseConsumerTest {
     val subscriptions = Set(tp, tp2) ++ createTopicAndSendRecords(topic1, 5, 100)
 
     // subscribe all consumers to all topics and validate the assignment
-    val consumerPollers = subscribeConsumersAndWaitForAssignment(consumers, List(topic, topic1), subscriptions)
+    val consumerPollers = subscribeConsumers(consumers, List(topic, topic1))
 
-    // add 2 more consumers and validate re-assignment
-    addConsumersToGroupAndWaitForGroupAssignment(2, consumers, consumerPollers, List(topic, topic1), subscriptions)
+    try {
+      validateGroupAssignment(consumerPollers, subscriptions, s"Did not get valid initial assignment for partitions ${subscriptions.asJava}")
 
-    // add one more topic and validate partition re-assignment
-    val topic2 = "topic2"
-    val expandedSubscriptions = subscriptions ++ createTopicAndSendRecords(topic2, 3, 100)
-    changeConsumerGroupSubscriptionAndValidateAssignment(consumerPollers, List(topic, topic1, topic2), expandedSubscriptions)
+      // add 2 more consumers and validate re-assignment
+      addConsumersToGroupAndWaitForGroupAssignment(2, consumers, consumerPollers, List(topic, topic1), subscriptions)
 
-    // remove the topic we just added and validate re-assignment
-    changeConsumerGroupSubscriptionAndValidateAssignment(consumerPollers, List(topic, topic1), subscriptions)
+      // add one more topic and validate partition re-assignment
+      val topic2 = "topic2"
+      val expandedSubscriptions = subscriptions ++ createTopicAndSendRecords(topic2, 3, 100)
+      changeConsumerGroupSubscriptionAndValidateAssignment(consumerPollers, List(topic, topic1, topic2), expandedSubscriptions)
 
-    // done with pollers and consumers
-    for (poller <- consumerPollers)
-      poller.shutdown()
+      // remove the topic we just added and validate re-assignment
+      changeConsumerGroupSubscriptionAndValidateAssignment(consumerPollers, List(topic, topic1), subscriptions)
 
-    for (consumer <- consumers)
-      consumer.unsubscribe()
+    } finally {
+      consumerPollers.foreach(_.shutdown())
+    }
   }
 
   @Test
@@ -830,6 +853,138 @@ class PlaintextConsumerTest extends BaseConsumerTest {
       startingTimestamp = startTime, timestampType = TimestampType.LOG_APPEND_TIME)
   }
 
+  @Test
+  def testListTopics() {
+    val numParts = 2
+    val topic1 = "part-test-topic-1"
+    val topic2 = "part-test-topic-2"
+    val topic3 = "part-test-topic-3"
+    TestUtils.createTopic(this.zkUtils, topic1, numParts, 1, this.servers)
+    TestUtils.createTopic(this.zkUtils, topic2, numParts, 1, this.servers)
+    TestUtils.createTopic(this.zkUtils, topic3, numParts, 1, this.servers)
+
+    val topics = this.consumers.head.listTopics()
+    assertNotNull(topics)
+    assertEquals(5, topics.size())
+    assertEquals(5, topics.keySet().size())
+    assertEquals(2, topics.get(topic1).size)
+    assertEquals(2, topics.get(topic2).size)
+    assertEquals(2, topics.get(topic3).size)
+  }
+
+  @Test
+  def testUnsubscribeTopic() {
+    this.consumerConfig.setProperty(ConsumerConfig.SESSION_TIMEOUT_MS_CONFIG, "100") // timeout quickly to avoid slow test
+    this.consumerConfig.setProperty(ConsumerConfig.HEARTBEAT_INTERVAL_MS_CONFIG, "30")
+    val consumer0 = new KafkaConsumer(this.consumerConfig, new ByteArrayDeserializer(), new ByteArrayDeserializer())
+    consumers += consumer0
+
+    val listener = new TestConsumerReassignmentListener()
+    consumer0.subscribe(List(topic).asJava, listener)
+
+    // the initial subscription should cause a callback execution
+    while (listener.callsToAssigned == 0)
+      consumer0.poll(50)
+
+    consumer0.subscribe(List[String]().asJava)
+    assertEquals(0, consumer0.assignment.size())
+  }
+
+  @Test
+  def testPauseStateNotPreservedByRebalance() {
+    this.consumerConfig.setProperty(ConsumerConfig.SESSION_TIMEOUT_MS_CONFIG, "100") // timeout quickly to avoid slow test
+    this.consumerConfig.setProperty(ConsumerConfig.HEARTBEAT_INTERVAL_MS_CONFIG, "30")
+    val consumer0 = new KafkaConsumer(this.consumerConfig, new ByteArrayDeserializer(), new ByteArrayDeserializer())
+    consumers += consumer0
+
+    sendRecords(5)
+    consumer0.subscribe(List(topic).asJava)
+    consumeAndVerifyRecords(consumer = consumer0, numRecords = 5, startingOffset = 0)
+    consumer0.pause(List(tp).asJava)
+
+    // subscribe to a new topic to trigger a rebalance
+    consumer0.subscribe(List("topic2").asJava)
+
+    // after rebalance, our position should be reset and our pause state lost,
+    // so we should be able to consume from the beginning
+    consumeAndVerifyRecords(consumer = consumer0, numRecords = 0, startingOffset = 5)
+  }
+
+  @Test
+  def testCommitSpecifiedOffsets() {
+    sendRecords(5, tp)
+    sendRecords(7, tp2)
+
+    this.consumers.head.assign(List(tp, tp2).asJava)
+
+    // Need to poll to join the group
+    this.consumers.head.poll(50)
+    val pos1 = this.consumers.head.position(tp)
+    val pos2 = this.consumers.head.position(tp2)
+    this.consumers.head.commitSync(Map[TopicPartition, OffsetAndMetadata]((tp, new OffsetAndMetadata(3L))).asJava)
+    assertEquals(3, this.consumers.head.committed(tp).offset)
+    assertNull(this.consumers.head.committed(tp2))
+
+    // Positions should not change
+    assertEquals(pos1, this.consumers.head.position(tp))
+    assertEquals(pos2, this.consumers.head.position(tp2))
+    this.consumers.head.commitSync(Map[TopicPartition, OffsetAndMetadata]((tp2, new OffsetAndMetadata(5L))).asJava)
+    assertEquals(3, this.consumers.head.committed(tp).offset)
+    assertEquals(5, this.consumers.head.committed(tp2).offset)
+
+    // Using async should pick up the committed changes after commit completes
+    val commitCallback = new CountConsumerCommitCallback()
+    this.consumers.head.commitAsync(Map[TopicPartition, OffsetAndMetadata]((tp2, new OffsetAndMetadata(7L))).asJava, commitCallback)
+    awaitCommitCallback(this.consumers.head, commitCallback)
+    assertEquals(7, this.consumers.head.committed(tp2).offset)
+  }
+
+  @Test
+  def testAutoCommitOnRebalance() {
+    val topic2 = "topic2"
+    TestUtils.createTopic(this.zkUtils, topic2, 2, serverCount, this.servers)
+
+    this.consumerConfig.setProperty(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "true")
+    val consumer0 = new KafkaConsumer(this.consumerConfig, new ByteArrayDeserializer(), new ByteArrayDeserializer())
+    consumers += consumer0
+
+    val numRecords = 10000
+    sendRecords(numRecords)
+
+    val rebalanceListener = new ConsumerRebalanceListener {
+      override def onPartitionsAssigned(partitions: util.Collection[TopicPartition]) = {
+        // keep partitions paused in this test so that we can verify the commits based on specific seeks
+        consumer0.pause(partitions)
+      }
+
+      override def onPartitionsRevoked(partitions: util.Collection[TopicPartition]) = {}
+    }
+
+    consumer0.subscribe(List(topic).asJava, rebalanceListener)
+
+    val assignment = Set(tp, tp2)
+    TestUtils.waitUntilTrue(() => {
+      consumer0.poll(50)
+      consumer0.assignment() == assignment.asJava
+    }, s"Expected partitions ${assignment.asJava} but actually got ${consumer0.assignment()}")
+
+    consumer0.seek(tp, 300)
+    consumer0.seek(tp2, 500)
+
+    // change subscription to trigger rebalance
+    consumer0.subscribe(List(topic, topic2).asJava, rebalanceListener)
+
+    val newAssignment = Set(tp, tp2, new TopicPartition(topic2, 0), new TopicPartition(topic2, 1))
+    TestUtils.waitUntilTrue(() => {
+      val records = consumer0.poll(50)
+      consumer0.assignment() == newAssignment.asJava
+    }, s"Expected partitions ${newAssignment.asJava} but actually got ${consumer0.assignment()}")
+
+    // after rebalancing, we should have reset to the committed positions
+    assertEquals(300, consumer0.committed(tp).offset)
+    assertEquals(500, consumer0.committed(tp2).offset)
+  }
+
   def runMultiConsumerSessionTimeoutTest(closeConsumer: Boolean): Unit = {
     // use consumers defined in this class plus one additional consumer
     // Use topic defined in this class + one additional topic
@@ -887,7 +1042,8 @@ class PlaintextConsumerTest extends BaseConsumerTest {
    * Subscribes consumer 'consumer' to a given list of topics 'topicsToSubscribe', creates
    * consumer poller and starts polling.
    * Assumes that the consumer is not subscribed to any topics yet
-   * @param consumer consumer
+    *
+    * @param consumer consumer
    * @param topicsToSubscribe topics that this consumer will subscribe to
    * @return consumer poller for the given consumer
    */
@@ -901,34 +1057,25 @@ class PlaintextConsumerTest extends BaseConsumerTest {
 
   /**
    * Creates consumer pollers corresponding to a given consumer group, one per consumer; subscribes consumers to
-   * 'topicsToSubscribe' topics, waits until consumers get topics assignment, and validates the assignment
-   * Currently, assignment validation requires that total number of partitions is greater or equal to
-   * number of consumers (i.e. subscriptions.size >= consumerGroup.size)
-   * Assumes that topics are already created with partitions corresponding to a given set of topic partitions ('subscriptions')
+   * 'topicsToSubscribe' topics, waits until consumers get topics assignment.
    *
    * When the function returns, consumer pollers will continue to poll until shutdown is called on every poller.
    *
    * @param consumerGroup consumer group
    * @param topicsToSubscribe topics to which consumers will subscribe to
-   * @param subscriptions set of all topic partitions
    * @return collection of consumer pollers
    */
-  def subscribeConsumersAndWaitForAssignment(consumerGroup: Buffer[KafkaConsumer[Array[Byte], Array[Byte]]],
-                                             topicsToSubscribe: List[String],
-                                             subscriptions: Set[TopicPartition]): Buffer[ConsumerAssignmentPoller] = {
+  def subscribeConsumers(consumerGroup: Buffer[KafkaConsumer[Array[Byte], Array[Byte]]],
+                         topicsToSubscribe: List[String]): Buffer[ConsumerAssignmentPoller] = {
     val consumerPollers = Buffer[ConsumerAssignmentPoller]()
     for (consumer <- consumerGroup)
       consumerPollers += subscribeConsumerAndStartPolling(consumer, topicsToSubscribe)
-    validateGroupAssignment(consumerPollers, subscriptions, s"Did not get valid initial assignment for partitions ${subscriptions.asJava}")
     consumerPollers
   }
 
   /**
    * Creates 'consumerCount' consumers and consumer pollers, one per consumer; subscribes consumers to
-   * 'topicsToSubscribe' topics, waits until consumers get topics assignment, and validates the assignment
-   * Currently, assignment validation requires that total number of partitions is greater or equal to
-   * number of consumers (i.e. subscriptions.size >= consumerCount)
-   * Assumes that topics are already created with partitions corresponding to a given set of topic partitions ('subscriptions')
+   * 'topicsToSubscribe' topics, waits until consumers get topics assignment.
    *
    * When the function returns, consumer pollers will continue to poll until shutdown is called on every poller.
    *
@@ -947,7 +1094,7 @@ class PlaintextConsumerTest extends BaseConsumerTest {
     consumers ++= consumerGroup
 
     // create consumer pollers, wait for assignment and validate it
-    val consumerPollers = subscribeConsumersAndWaitForAssignment(consumerGroup, topicsToSubscribe, subscriptions)
+    val consumerPollers = subscribeConsumers(consumerGroup, topicsToSubscribe)
 
     (consumerGroup, consumerPollers)
   }

http://git-wip-us.apache.org/repos/asf/kafka/blob/40b1dd3f/core/src/test/scala/integration/kafka/api/SaslPlainSslEndToEndAuthorizationTest.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/integration/kafka/api/SaslPlainSslEndToEndAuthorizationTest.scala b/core/src/test/scala/integration/kafka/api/SaslPlainSslEndToEndAuthorizationTest.scala
index 63636c0..591479e 100644
--- a/core/src/test/scala/integration/kafka/api/SaslPlainSslEndToEndAuthorizationTest.scala
+++ b/core/src/test/scala/integration/kafka/api/SaslPlainSslEndToEndAuthorizationTest.scala
@@ -16,7 +16,6 @@
   */
 package kafka.api
 
-import kafka.server.KafkaConfig
 import org.apache.kafka.common.protocol.SecurityProtocol
 
 class SaslPlainSslEndToEndAuthorizationTest extends EndToEndAuthorizationTest {

http://git-wip-us.apache.org/repos/asf/kafka/blob/40b1dd3f/core/src/test/scala/unit/kafka/coordinator/GroupCoordinatorResponseTest.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/unit/kafka/coordinator/GroupCoordinatorResponseTest.scala b/core/src/test/scala/unit/kafka/coordinator/GroupCoordinatorResponseTest.scala
index c917ca4..a981e68 100644
--- a/core/src/test/scala/unit/kafka/coordinator/GroupCoordinatorResponseTest.scala
+++ b/core/src/test/scala/unit/kafka/coordinator/GroupCoordinatorResponseTest.scala
@@ -54,6 +54,7 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
   val ClientHost = "localhost"
   val ConsumerMinSessionTimeout = 10
   val ConsumerMaxSessionTimeout = 1000
+  val DefaultRebalanceTimeout = 500
   val DefaultSessionTimeout = 500
   var timer: MockTimer = null
   var groupCoordinator: GroupCoordinator = null
@@ -113,7 +114,7 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
   def testJoinGroupWrongCoordinator() {
     val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID
 
-    val joinGroupResult = joinGroup(otherGroupId, memberId, DefaultSessionTimeout, protocolType, protocols)
+    val joinGroupResult = joinGroup(otherGroupId, memberId, protocolType, protocols)
     val joinGroupErrorCode = joinGroupResult.errorCode
     assertEquals(Errors.NOT_COORDINATOR_FOR_GROUP.code, joinGroupErrorCode)
   }
@@ -122,7 +123,7 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
   def testJoinGroupSessionTimeoutTooSmall() {
     val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID
 
-    val joinGroupResult = joinGroup(groupId, memberId, ConsumerMinSessionTimeout - 1, protocolType, protocols)
+    val joinGroupResult = joinGroup(groupId, memberId, protocolType, protocols, sessionTimeout = ConsumerMinSessionTimeout - 1)
     val joinGroupErrorCode = joinGroupResult.errorCode
     assertEquals(Errors.INVALID_SESSION_TIMEOUT.code, joinGroupErrorCode)
   }
@@ -131,14 +132,14 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
   def testJoinGroupSessionTimeoutTooLarge() {
     val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID
 
-    val joinGroupResult = joinGroup(groupId, memberId, ConsumerMaxSessionTimeout + 1, protocolType, protocols)
+    val joinGroupResult = joinGroup(groupId, memberId, protocolType, protocols, sessionTimeout = ConsumerMaxSessionTimeout + 1)
     val joinGroupErrorCode = joinGroupResult.errorCode
     assertEquals(Errors.INVALID_SESSION_TIMEOUT.code, joinGroupErrorCode)
   }
 
   @Test
   def testJoinGroupUnknownConsumerNewGroup() {
-    val joinGroupResult = joinGroup(groupId, memberId, DefaultSessionTimeout, protocolType, protocols)
+    val joinGroupResult = joinGroup(groupId, memberId, protocolType, protocols)
     val joinGroupErrorCode = joinGroupResult.errorCode
     assertEquals(Errors.UNKNOWN_MEMBER_ID.code, joinGroupErrorCode)
   }
@@ -148,7 +149,7 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
     val groupId = ""
     val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID
 
-    val joinGroupResult = joinGroup(groupId, memberId, DefaultSessionTimeout, protocolType, protocols)
+    val joinGroupResult = joinGroup(groupId, memberId, protocolType, protocols)
     assertEquals(Errors.INVALID_GROUP_ID.code, joinGroupResult.errorCode)
   }
 
@@ -156,8 +157,7 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
   def testValidJoinGroup() {
     val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID
 
-    val joinGroupResult = joinGroup(groupId, memberId, DefaultSessionTimeout, protocolType,
-      protocols)
+    val joinGroupResult = joinGroup(groupId, memberId, protocolType, protocols)
     val joinGroupErrorCode = joinGroupResult.errorCode
     assertEquals(Errors.NONE.code, joinGroupErrorCode)
   }
@@ -167,12 +167,11 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
     val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID
     val otherMemberId = JoinGroupRequest.UNKNOWN_MEMBER_ID
 
-    val joinGroupResult = joinGroup(groupId, memberId, DefaultSessionTimeout, protocolType,
-      protocols)
+    val joinGroupResult = joinGroup(groupId, memberId, protocolType, protocols)
     assertEquals(Errors.NONE.code, joinGroupResult.errorCode)
 
     EasyMock.reset(replicaManager)
-    val otherJoinGroupResult = joinGroup(groupId, otherMemberId, DefaultSessionTimeout, "connect", protocols)
+    val otherJoinGroupResult = joinGroup(groupId, otherMemberId, "connect", protocols)
     assertEquals(Errors.INCONSISTENT_GROUP_PROTOCOL.code, otherJoinGroupResult.errorCode)
   }
 
@@ -182,12 +181,11 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
 
     val otherMemberId = JoinGroupRequest.UNKNOWN_MEMBER_ID
 
-    val joinGroupResult = joinGroup(groupId, memberId, DefaultSessionTimeout, protocolType, List(("range", metadata)))
+    val joinGroupResult = joinGroup(groupId, memberId, protocolType, List(("range", metadata)))
     assertEquals(Errors.NONE.code, joinGroupResult.errorCode)
 
     EasyMock.reset(replicaManager)
-    val otherJoinGroupResult = joinGroup(groupId, otherMemberId, DefaultSessionTimeout, protocolType,
-      List(("roundrobin", metadata)))
+    val otherJoinGroupResult = joinGroup(groupId, otherMemberId, protocolType, List(("roundrobin", metadata)))
     assertEquals(Errors.INCONSISTENT_GROUP_PROTOCOL.code, otherJoinGroupResult.errorCode)
   }
 
@@ -196,11 +194,11 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
     val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID
     val otherMemberId = "memberId"
 
-    val joinGroupResult = joinGroup(groupId, memberId, DefaultSessionTimeout, protocolType, protocols)
+    val joinGroupResult = joinGroup(groupId, memberId, protocolType, protocols)
     assertEquals(Errors.NONE.code, joinGroupResult.errorCode)
 
     EasyMock.reset(replicaManager)
-    val otherJoinGroupResult = joinGroup(groupId, otherMemberId, DefaultSessionTimeout, protocolType, protocols)
+    val otherJoinGroupResult = joinGroup(groupId, otherMemberId, protocolType, protocols)
     assertEquals(Errors.UNKNOWN_MEMBER_ID.code, otherJoinGroupResult.errorCode)
   }
 
@@ -223,7 +221,7 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
     val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID
     val otherMemberId = "memberId"
 
-    val joinGroupResult = joinGroup(groupId, memberId, DefaultSessionTimeout, protocolType, protocols)
+    val joinGroupResult = joinGroup(groupId, memberId, protocolType, protocols)
     val assignedMemberId = joinGroupResult.memberId
     val joinGroupErrorCode = joinGroupResult.errorCode
     assertEquals(Errors.NONE.code, joinGroupErrorCode)
@@ -242,7 +240,7 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
   def testHeartbeatRebalanceInProgress() {
     val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID
 
-    val joinGroupResult = joinGroup(groupId, memberId, DefaultSessionTimeout, protocolType, protocols)
+    val joinGroupResult = joinGroup(groupId, memberId, protocolType, protocols)
     val assignedMemberId = joinGroupResult.memberId
     val joinGroupErrorCode = joinGroupResult.errorCode
     assertEquals(Errors.NONE.code, joinGroupErrorCode)
@@ -256,7 +254,7 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
   def testHeartbeatIllegalGeneration() {
     val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID
 
-    val joinGroupResult = joinGroup(groupId, memberId, DefaultSessionTimeout, protocolType, protocols)
+    val joinGroupResult = joinGroup(groupId, memberId, protocolType, protocols)
     val assignedMemberId = joinGroupResult.memberId
     val joinGroupErrorCode = joinGroupResult.errorCode
     assertEquals(Errors.NONE.code, joinGroupErrorCode)
@@ -275,7 +273,7 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
   def testValidHeartbeat() {
     val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID
 
-    val joinGroupResult = joinGroup(groupId, memberId, DefaultSessionTimeout, protocolType, protocols)
+    val joinGroupResult = joinGroup(groupId, memberId, protocolType, protocols)
     val assignedConsumerId = joinGroupResult.memberId
     val generationId = joinGroupResult.generationId
     val joinGroupErrorCode = joinGroupResult.errorCode
@@ -295,7 +293,7 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
   def testSessionTimeout() {
     val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID
 
-    val joinGroupResult = joinGroup(groupId, memberId, DefaultSessionTimeout, protocolType, protocols)
+    val joinGroupResult = joinGroup(groupId, memberId, protocolType, protocols)
     val assignedConsumerId = joinGroupResult.memberId
     val generationId = joinGroupResult.generationId
     val joinGroupErrorCode = joinGroupResult.errorCode
@@ -322,7 +320,8 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
     val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID
     val sessionTimeout = 1000
 
-    val joinGroupResult = joinGroup(groupId, memberId, sessionTimeout, protocolType, protocols)
+    val joinGroupResult = joinGroup(groupId, memberId, protocolType, protocols,
+      rebalanceTimeout = sessionTimeout, sessionTimeout = sessionTimeout)
     val assignedConsumerId = joinGroupResult.memberId
     val generationId = joinGroupResult.generationId
     val joinGroupErrorCode = joinGroupResult.errorCode
@@ -352,7 +351,8 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
     val tp = new TopicPartition("topic", 0)
     val offset = OffsetAndMetadata(0)
 
-    val joinGroupResult = joinGroup(groupId, memberId, sessionTimeout, protocolType, protocols)
+    val joinGroupResult = joinGroup(groupId, memberId, protocolType, protocols,
+      rebalanceTimeout = sessionTimeout, sessionTimeout = sessionTimeout)
     val assignedConsumerId = joinGroupResult.memberId
     val generationId = joinGroupResult.generationId
     val joinGroupErrorCode = joinGroupResult.errorCode
@@ -376,10 +376,82 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
   }
 
   @Test
+  def testSessionTimeoutDuringRebalance() {
+    // create a group with a single member
+    val firstJoinResult = joinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols,
+      rebalanceTimeout = 2000, sessionTimeout = 1000)
+    val firstMemberId = firstJoinResult.memberId
+    val firstGenerationId = firstJoinResult.generationId
+    assertEquals(firstMemberId, firstJoinResult.leaderId)
+    assertEquals(Errors.NONE.code, firstJoinResult.errorCode)
+
+    EasyMock.reset(replicaManager)
+    val firstSyncResult = syncGroupLeader(groupId, firstGenerationId, firstMemberId, Map(firstMemberId -> Array[Byte]()))
+    assertEquals(Errors.NONE.code, firstSyncResult._2)
+
+    // now have a new member join to trigger a rebalance
+    EasyMock.reset(replicaManager)
+    val otherJoinFuture = sendJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols)
+
+    timer.advanceClock(500)
+
+    EasyMock.reset(replicaManager)
+    var heartbeatResult = heartbeat(groupId, firstMemberId, firstGenerationId)
+    assertEquals(Errors.REBALANCE_IN_PROGRESS.code, heartbeatResult)
+
+    // letting the session expire should make the member fall out of the group
+    timer.advanceClock(1100)
+
+    EasyMock.reset(replicaManager)
+    heartbeatResult = heartbeat(groupId, firstMemberId, firstGenerationId)
+    assertEquals(Errors.UNKNOWN_MEMBER_ID.code, heartbeatResult)
+
+    // and the rebalance should complete with only the new member
+    val otherJoinResult = await(otherJoinFuture, DefaultSessionTimeout+100)
+    assertEquals(Errors.NONE.code, otherJoinResult.errorCode)
+  }
+
+  @Test
+  def testRebalanceCompletesBeforeMemberJoins() {
+    // create a group with a single member
+    val firstJoinResult = joinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols,
+      rebalanceTimeout = 1200, sessionTimeout = 1000)
+    val firstMemberId = firstJoinResult.memberId
+    val firstGenerationId = firstJoinResult.generationId
+    assertEquals(firstMemberId, firstJoinResult.leaderId)
+    assertEquals(Errors.NONE.code, firstJoinResult.errorCode)
+
+    EasyMock.reset(replicaManager)
+    val firstSyncResult = syncGroupLeader(groupId, firstGenerationId, firstMemberId, Map(firstMemberId -> Array[Byte]()))
+    assertEquals(Errors.NONE.code, firstSyncResult._2)
+
+    // now have a new member join to trigger a rebalance
+    EasyMock.reset(replicaManager)
+    val otherJoinFuture = sendJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols)
+
+    // send a couple heartbeats to keep the member alive while the rebalance finishes
+    timer.advanceClock(500)
+    EasyMock.reset(replicaManager)
+    var heartbeatResult = heartbeat(groupId, firstMemberId, firstGenerationId)
+    assertEquals(Errors.REBALANCE_IN_PROGRESS.code, heartbeatResult)
+
+    timer.advanceClock(500)
+    EasyMock.reset(replicaManager)
+    heartbeatResult = heartbeat(groupId, firstMemberId, firstGenerationId)
+    assertEquals(Errors.REBALANCE_IN_PROGRESS.code, heartbeatResult)
+
+    // now timeout the rebalance, which should kick the unjoined member out of the group
+    // and let the rebalance finish with only the new member
+    timer.advanceClock(500)
+    val otherJoinResult = await(otherJoinFuture, DefaultSessionTimeout+100)
+    assertEquals(Errors.NONE.code, otherJoinResult.errorCode)
+  }
+
+  @Test
   def testSyncGroupEmptyAssignment() {
     val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID
 
-    val joinGroupResult = joinGroup(groupId, memberId, DefaultSessionTimeout, protocolType, protocols)
+    val joinGroupResult = joinGroup(groupId, memberId, protocolType, protocols)
     val assignedConsumerId = joinGroupResult.memberId
     val generationId = joinGroupResult.generationId
     val joinGroupErrorCode = joinGroupResult.errorCode
@@ -416,7 +488,7 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
   def testSyncGroupFromUnknownMember() {
     val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID
 
-    val joinGroupResult = joinGroup(groupId, memberId, DefaultSessionTimeout, protocolType, protocols)
+    val joinGroupResult = joinGroup(groupId, memberId, protocolType, protocols)
     val assignedConsumerId = joinGroupResult.memberId
     val generationId = joinGroupResult.generationId
     assertEquals(Errors.NONE.code, joinGroupResult.errorCode)
@@ -436,7 +508,7 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
   def testSyncGroupFromIllegalGeneration() {
     val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID
 
-    val joinGroupResult = joinGroup(groupId, memberId, DefaultSessionTimeout, protocolType, protocols)
+    val joinGroupResult = joinGroup(groupId, memberId, protocolType, protocols)
     val assignedConsumerId = joinGroupResult.memberId
     val generationId = joinGroupResult.generationId
     assertEquals(Errors.NONE.code, joinGroupResult.errorCode)
@@ -453,8 +525,7 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
     // 1. join and sync with a single member (because we can't immediately join with two members)
     // 2. join and sync with the first member and a new member
 
-    val firstJoinResult = joinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, DefaultSessionTimeout,
-      protocolType, protocols)
+    val firstJoinResult = joinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols)
     val firstMemberId = firstJoinResult.memberId
     val firstGenerationId = firstJoinResult.generationId
     assertEquals(firstMemberId, firstJoinResult.leaderId)
@@ -465,11 +536,10 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
     assertEquals(Errors.NONE.code, firstSyncResult._2)
 
     EasyMock.reset(replicaManager)
-    val otherJoinFuture = sendJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, DefaultSessionTimeout,
-      protocolType, protocols)
+    val otherJoinFuture = sendJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols)
 
     EasyMock.reset(replicaManager)
-    val joinFuture = sendJoinGroup(groupId, firstMemberId, DefaultSessionTimeout, protocolType, protocols)
+    val joinFuture = sendJoinGroup(groupId, firstMemberId, protocolType, protocols)
 
     val joinResult = await(joinFuture, DefaultSessionTimeout+100)
     val otherJoinResult = await(otherJoinFuture, DefaultSessionTimeout+100)
@@ -484,7 +554,7 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
 
     // this shouldn't cause a rebalance since protocol information hasn't changed
     EasyMock.reset(replicaManager)
-    val followerJoinResult = joinGroup(groupId, otherJoinResult.memberId, DefaultSessionTimeout, protocolType, protocols)
+    val followerJoinResult = joinGroup(groupId, otherJoinResult.memberId, protocolType, protocols)
 
     assertEquals(Errors.NONE.code, followerJoinResult.errorCode)
     assertEquals(nextGenerationId, followerJoinResult.generationId)
@@ -492,8 +562,7 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
 
   @Test
   def testJoinGroupFromUnchangedLeaderShouldRebalance() {
-    val firstJoinResult = joinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, DefaultSessionTimeout,
-      protocolType, protocols)
+    val firstJoinResult = joinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols)
     val firstMemberId = firstJoinResult.memberId
     val firstGenerationId = firstJoinResult.generationId
     assertEquals(firstMemberId, firstJoinResult.leaderId)
@@ -507,7 +576,7 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
     // leader to push new assignments when local metadata changes
 
     EasyMock.reset(replicaManager)
-    val secondJoinResult = joinGroup(groupId, firstMemberId, DefaultSessionTimeout, protocolType, protocols)
+    val secondJoinResult = joinGroup(groupId, firstMemberId, protocolType, protocols)
 
     assertEquals(Errors.NONE.code, secondJoinResult.errorCode)
     assertNotEquals(firstGenerationId, secondJoinResult.generationId)
@@ -519,8 +588,7 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
     // 1. join and sync with a single member (because we can't immediately join with two members)
     // 2. join and sync with the first member and a new member
 
-    val firstJoinResult = joinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, DefaultSessionTimeout,
-      protocolType, protocols)
+    val firstJoinResult = joinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols)
     val firstMemberId = firstJoinResult.memberId
     val firstGenerationId = firstJoinResult.generationId
     assertEquals(firstMemberId, firstJoinResult.leaderId)
@@ -531,11 +599,10 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
     assertEquals(Errors.NONE.code, firstSyncResult._2)
 
     EasyMock.reset(replicaManager)
-    val otherJoinFuture = sendJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, DefaultSessionTimeout,
-      protocolType, protocols)
+    val otherJoinFuture = sendJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols)
 
     EasyMock.reset(replicaManager)
-    val joinFuture = sendJoinGroup(groupId, firstMemberId, DefaultSessionTimeout, protocolType, protocols)
+    val joinFuture = sendJoinGroup(groupId, firstMemberId, protocolType, protocols)
 
     val joinResult = await(joinFuture, DefaultSessionTimeout+100)
     val otherJoinResult = await(otherJoinFuture, DefaultSessionTimeout+100)
@@ -565,8 +632,7 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
     // 1. join and sync with a single member (because we can't immediately join with two members)
     // 2. join and sync with the first member and a new member
 
-    val firstJoinResult = joinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, DefaultSessionTimeout,
-      protocolType, protocols)
+    val firstJoinResult = joinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols)
     val firstMemberId = firstJoinResult.memberId
     val firstGenerationId = firstJoinResult.generationId
     assertEquals(firstMemberId, firstJoinResult.leaderId)
@@ -577,11 +643,10 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
     assertEquals(Errors.NONE.code, firstSyncResult._2)
 
     EasyMock.reset(replicaManager)
-    val otherJoinFuture = sendJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, DefaultSessionTimeout,
-      protocolType, protocols)
+    val otherJoinFuture = sendJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols)
 
     EasyMock.reset(replicaManager)
-    val joinFuture = sendJoinGroup(groupId, firstMemberId, DefaultSessionTimeout, protocolType, protocols)
+    val joinFuture = sendJoinGroup(groupId, firstMemberId, protocolType, protocols)
 
     val joinResult = await(joinFuture, DefaultSessionTimeout+100)
     val otherJoinResult = await(otherJoinFuture, DefaultSessionTimeout+100)
@@ -616,8 +681,7 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
     // 1. join and sync with a single member (because we can't immediately join with two members)
     // 2. join and sync with the first member and a new member
 
-    val joinGroupResult = joinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, DefaultSessionTimeout,
-      protocolType, protocols)
+    val joinGroupResult = joinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols)
     val firstMemberId = joinGroupResult.memberId
     val firstGenerationId = joinGroupResult.generationId
     assertEquals(firstMemberId, joinGroupResult.leaderId)
@@ -629,11 +693,10 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
     assertEquals(Errors.NONE.code, syncGroupErrorCode)
 
     EasyMock.reset(replicaManager)
-    val otherJoinFuture = sendJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, DefaultSessionTimeout,
-      protocolType, protocols)
+    val otherJoinFuture = sendJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols)
 
     EasyMock.reset(replicaManager)
-    val joinFuture = sendJoinGroup(groupId, firstMemberId, DefaultSessionTimeout, protocolType, protocols)
+    val joinFuture = sendJoinGroup(groupId, firstMemberId, protocolType, protocols)
 
     val joinResult = await(joinFuture, DefaultSessionTimeout+100)
     val otherJoinResult = await(otherJoinFuture, DefaultSessionTimeout+100)
@@ -690,7 +753,7 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
     val tp = new TopicPartition("topic", 0)
     val offset = OffsetAndMetadata(0)
 
-    val joinGroupResult = joinGroup(groupId, memberId, DefaultSessionTimeout, protocolType, protocols)
+    val joinGroupResult = joinGroup(groupId, memberId, protocolType, protocols)
     val assignedMemberId = joinGroupResult.memberId
     val generationId = joinGroupResult.generationId
     val joinGroupErrorCode = joinGroupResult.errorCode
@@ -704,8 +767,7 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
   @Test
   def testHeartbeatDuringRebalanceCausesRebalanceInProgress() {
     // First start up a group (with a slightly larger timeout to give us time to heartbeat when the rebalance starts)
-    val joinGroupResult = joinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, DefaultSessionTimeout,
-      protocolType, protocols)
+    val joinGroupResult = joinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols)
     val assignedConsumerId = joinGroupResult.memberId
     val initialGenerationId = joinGroupResult.generationId
     val joinGroupErrorCode = joinGroupResult.errorCode
@@ -713,7 +775,7 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
 
     // Then join with a new consumer to trigger a rebalance
     EasyMock.reset(replicaManager)
-    sendJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, DefaultSessionTimeout, protocolType, protocols)
+    sendJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols)
 
     // We should be in the middle of a rebalance, so the heartbeat should return rebalance in progress
     EasyMock.reset(replicaManager)
@@ -723,7 +785,7 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
 
   @Test
   def testGenerationIdIncrementsOnRebalance() {
-    val joinGroupResult = joinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, DefaultSessionTimeout, protocolType, protocols)
+    val joinGroupResult = joinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols)
     val initialGenerationId = joinGroupResult.generationId
     val joinGroupErrorCode = joinGroupResult.errorCode
     val memberId = joinGroupResult.memberId
@@ -736,7 +798,7 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
     assertEquals(Errors.NONE.code, syncGroupErrorCode)
 
     EasyMock.reset(replicaManager)
-    val otherJoinGroupResult = joinGroup(groupId, memberId, DefaultSessionTimeout, protocolType, protocols)
+    val otherJoinGroupResult = joinGroup(groupId, memberId, protocolType, protocols)
     val nextGenerationId = otherJoinGroupResult.generationId
     val otherJoinGroupErrorCode = otherJoinGroupResult.errorCode
     assertEquals(2, nextGenerationId)
@@ -763,7 +825,7 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
     val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID
     val otherMemberId = "consumerId"
 
-    val joinGroupResult = joinGroup(groupId, memberId, DefaultSessionTimeout, protocolType, protocols)
+    val joinGroupResult = joinGroup(groupId, memberId, protocolType, protocols)
     val joinGroupErrorCode = joinGroupResult.errorCode
     assertEquals(Errors.NONE.code, joinGroupErrorCode)
 
@@ -776,7 +838,7 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
   def testValidLeaveGroup() {
     val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID
 
-    val joinGroupResult = joinGroup(groupId, memberId, DefaultSessionTimeout, protocolType, protocols)
+    val joinGroupResult = joinGroup(groupId, memberId, protocolType, protocols)
     val assignedMemberId = joinGroupResult.memberId
     val joinGroupErrorCode = joinGroupResult.errorCode
     assertEquals(Errors.NONE.code, joinGroupErrorCode)
@@ -789,7 +851,7 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
   @Test
   def testListGroupsIncludesStableGroups() {
     val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID
-    val joinGroupResult = joinGroup(groupId, memberId, DefaultSessionTimeout, protocolType, protocols)
+    val joinGroupResult = joinGroup(groupId, memberId, protocolType, protocols)
     val assignedMemberId = joinGroupResult.memberId
     val generationId = joinGroupResult.generationId
     assertEquals(Errors.NONE.code, joinGroupResult.errorCode)
@@ -808,7 +870,7 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
   @Test
   def testListGroupsIncludesRebalancingGroups() {
     val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID
-    val joinGroupResult = joinGroup(groupId, memberId, DefaultSessionTimeout, protocolType, protocols)
+    val joinGroupResult = joinGroup(groupId, memberId, protocolType, protocols)
     assertEquals(Errors.NONE.code, joinGroupResult.errorCode)
 
     val (error, groups) = groupCoordinator.handleListGroups()
@@ -835,14 +897,15 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
   @Test
   def testDescribeGroupStable() {
     val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID
-    val joinGroupResult = joinGroup(groupId, memberId, DefaultSessionTimeout, protocolType, protocols)
+    val joinGroupResult = joinGroup(groupId, memberId, protocolType, protocols)
     val assignedMemberId = joinGroupResult.memberId
     val generationId = joinGroupResult.generationId
     val joinGroupErrorCode = joinGroupResult.errorCode
     assertEquals(Errors.NONE.code, joinGroupErrorCode)
 
     EasyMock.reset(replicaManager)
-    val syncGroupResult = syncGroupLeader(groupId, generationId, assignedMemberId,  Map(assignedMemberId -> Array[Byte]()))
+    val syncGroupResult = syncGroupLeader(groupId, generationId, assignedMemberId, Map(assignedMemberId -> Array[Byte]()))
+
     val syncGroupErrorCode = syncGroupResult._2
     assertEquals(Errors.NONE.code, syncGroupErrorCode)
 
@@ -857,7 +920,7 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
   @Test
   def testDescribeGroupRebalancing() {
     val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID
-    val joinGroupResult = joinGroup(groupId, memberId, DefaultSessionTimeout, protocolType, protocols)
+    val joinGroupResult = joinGroup(groupId, memberId, protocolType, protocols)
     val joinGroupErrorCode = joinGroupResult.errorCode
     assertEquals(Errors.NONE.code, joinGroupErrorCode)
 
@@ -903,14 +966,15 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
 
   private def sendJoinGroup(groupId: String,
                             memberId: String,
-                            sessionTimeout: Int,
                             protocolType: String,
-                            protocols: List[(String, Array[Byte])]): Future[JoinGroupResult] = {
+                            protocols: List[(String, Array[Byte])],
+                            rebalanceTimeout: Int = DefaultRebalanceTimeout,
+                            sessionTimeout: Int = DefaultSessionTimeout): Future[JoinGroupResult] = {
     val (responseFuture, responseCallback) = setupJoinGroupCallback
 
     EasyMock.replay(replicaManager)
 
-    groupCoordinator.handleJoinGroup(groupId, memberId, "clientId", "clientHost", sessionTimeout,
+    groupCoordinator.handleJoinGroup(groupId, memberId, "clientId", "clientHost", rebalanceTimeout, sessionTimeout,
       protocolType, protocols, responseCallback)
     responseFuture
   }
@@ -954,29 +1018,32 @@ class GroupCoordinatorResponseTest extends JUnitSuite {
 
   private def joinGroup(groupId: String,
                         memberId: String,
-                        sessionTimeout: Int,
                         protocolType: String,
-                        protocols: List[(String, Array[Byte])]): JoinGroupResult = {
-    val responseFuture = sendJoinGroup(groupId, memberId, sessionTimeout, protocolType, protocols)
+                        protocols: List[(String, Array[Byte])],
+                        sessionTimeout: Int = DefaultSessionTimeout,
+                        rebalanceTimeout: Int = DefaultRebalanceTimeout): JoinGroupResult = {
+    val responseFuture = sendJoinGroup(groupId, memberId, protocolType, protocols, rebalanceTimeout, sessionTimeout)
     timer.advanceClock(10)
     // should only have to wait as long as session timeout, but allow some extra time in case of an unexpected delay
-    Await.result(responseFuture, Duration(sessionTimeout+100, TimeUnit.MILLISECONDS))
+    Await.result(responseFuture, Duration(rebalanceTimeout + 100, TimeUnit.MILLISECONDS))
   }
 
 
   private def syncGroupFollower(groupId: String,
                                 generationId: Int,
-                                memberId: String): SyncGroupCallbackParams = {
+                                memberId: String,
+                                sessionTimeout: Int = DefaultSessionTimeout): SyncGroupCallbackParams = {
     val responseFuture = sendSyncGroupFollower(groupId, generationId, memberId)
-    Await.result(responseFuture, Duration(DefaultSessionTimeout+100, TimeUnit.MILLISECONDS))
+    Await.result(responseFuture, Duration(sessionTimeout + 100, TimeUnit.MILLISECONDS))
   }
 
   private def syncGroupLeader(groupId: String,
                               generationId: Int,
                               memberId: String,
-                              assignment: Map[String, Array[Byte]]): SyncGroupCallbackParams = {
+                              assignment: Map[String, Array[Byte]],
+                              sessionTimeout: Int = DefaultSessionTimeout): SyncGroupCallbackParams = {
     val responseFuture = sendSyncGroupLeader(groupId, generationId, memberId, assignment)
-    Await.result(responseFuture, Duration(DefaultSessionTimeout+100, TimeUnit.MILLISECONDS))
+    Await.result(responseFuture, Duration(sessionTimeout + 100, TimeUnit.MILLISECONDS))
   }
 
   private def heartbeat(groupId: String,

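Editor's note: the call-site churn in the GroupCoordinatorResponseTest hunks above all comes from one refactor: the join helpers move the timeouts to trailing parameters with defaults (DefaultRebalanceTimeout, DefaultSessionTimeout), so the common case passes no timeout at all, and the await bound becomes rebalanceTimeout + 100 because a join may now block for the full rebalance timeout. A minimal, self-contained sketch of that parameter shape, using placeholder values rather than the test's real constants:

    object JoinHelperSketch {
      // Placeholder defaults for illustration only; the real constants are
      // defined in GroupCoordinatorResponseTest.
      val DefaultRebalanceTimeout = 60000
      val DefaultSessionTimeout = 10000

      def sendJoinGroup(groupId: String,
                        memberId: String,
                        protocolType: String,
                        rebalanceTimeout: Int = DefaultRebalanceTimeout,
                        sessionTimeout: Int = DefaultSessionTimeout): String =
        s"join($groupId, $memberId, $protocolType, rebalance=$rebalanceTimeout, session=$sessionTimeout)"

      def main(args: Array[String]): Unit = {
        // Most call sites: only group/member/protocol, defaults fill in the rest.
        println(sendJoinGroup("groupId", "memberId", "consumer"))
        // Tests that need a different timeout pass it by name.
        println(sendJoinGroup("groupId", "memberId", "consumer", sessionTimeout = 500))
      }
    }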
http://git-wip-us.apache.org/repos/asf/kafka/blob/40b1dd3f/core/src/test/scala/unit/kafka/coordinator/GroupMetadataManagerTest.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/unit/kafka/coordinator/GroupMetadataManagerTest.scala b/core/src/test/scala/unit/kafka/coordinator/GroupMetadataManagerTest.scala
index b9569ca..b4f9ba3 100644
--- a/core/src/test/scala/unit/kafka/coordinator/GroupMetadataManagerTest.scala
+++ b/core/src/test/scala/unit/kafka/coordinator/GroupMetadataManagerTest.scala
@@ -17,6 +17,7 @@
 
 package kafka.coordinator
 
+import kafka.api.ApiVersion
 import kafka.cluster.Partition
 import kafka.common.{OffsetAndMetadata, Topic}
 import kafka.log.LogAppendInfo
@@ -46,7 +47,8 @@ class GroupMetadataManagerTest {
   val groupId = "foo"
   val groupPartitionId = 0
   val protocolType = "protocolType"
-  val sessionTimeout = 30000
+  val rebalanceTimeout = 60000
+  val sessionTimeout = 10000
 
 
   @Before
@@ -74,9 +76,8 @@ class GroupMetadataManagerTest {
 
     time = new MockTime
     replicaManager = EasyMock.createNiceMock(classOf[ReplicaManager])
-    groupMetadataManager = new GroupMetadataManager(0, offsetConfig, replicaManager, zkUtils, time)
+    groupMetadataManager = new GroupMetadataManager(0, ApiVersion.latestVersion, offsetConfig, replicaManager, zkUtils, time)
     partition = EasyMock.niceMock(classOf[Partition])
-
   }
 
   @After
@@ -119,7 +120,7 @@ class GroupMetadataManagerTest {
     val group = new GroupMetadata(groupId)
     groupMetadataManager.addGroup(group)
 
-    val member = new MemberMetadata(memberId, groupId, clientId, clientHost, sessionTimeout,
+    val member = new MemberMetadata(memberId, groupId, clientId, clientHost, rebalanceTimeout, sessionTimeout,
       protocolType, List(("protocol", Array[Byte]())))
     member.awaitingJoinCallback = (joinGroupResult: JoinGroupResult) => {}
     group.add(memberId, member)
@@ -337,7 +338,7 @@ class GroupMetadataManagerTest {
     val group = new GroupMetadata(groupId)
     groupMetadataManager.addGroup(group)
 
-    val member = new MemberMetadata(memberId, groupId, clientId, clientHost, sessionTimeout,
+    val member = new MemberMetadata(memberId, groupId, clientId, clientHost, rebalanceTimeout, sessionTimeout,
       protocolType, List(("protocol", Array[Byte]())))
     member.awaitingJoinCallback = (joinGroupResult: JoinGroupResult) => {}
     group.add(memberId, member)

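Editor's note: the constructor updates in this and the following test files hinge on one detail of the MemberMetadata change shown earlier: rebalanceTimeoutMs is inserted immediately before sessionTimeoutMs, and both are Ints, so an accidentally swapped call would still compile. A tiny stand-in (not the real class) showing how named arguments make the order explicit:

    // Hypothetical stand-in used only to illustrate the argument order.
    case class MemberTimeouts(rebalanceTimeoutMs: Int, sessionTimeoutMs: Int)

    // Passing by name documents which Int is which and guards against a swap.
    val timeouts = MemberTimeouts(rebalanceTimeoutMs = 60000, sessionTimeoutMs = 10000)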
http://git-wip-us.apache.org/repos/asf/kafka/blob/40b1dd3f/core/src/test/scala/unit/kafka/coordinator/GroupMetadataTest.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/unit/kafka/coordinator/GroupMetadataTest.scala b/core/src/test/scala/unit/kafka/coordinator/GroupMetadataTest.scala
index 18dd143..8539340 100644
--- a/core/src/test/scala/unit/kafka/coordinator/GroupMetadataTest.scala
+++ b/core/src/test/scala/unit/kafka/coordinator/GroupMetadataTest.scala
@@ -27,7 +27,14 @@ import org.scalatest.junit.JUnitSuite
  * Test group state transitions and other GroupMetadata functionality
  */
 class GroupMetadataTest extends JUnitSuite {
-  var group: GroupMetadata = null
+  private val protocolType = "consumer"
+  private val groupId = "groupId"
+  private val clientId = "clientId"
+  private val clientHost = "clientHost"
+  private val rebalanceTimeoutMs = 60000
+  private val sessionTimeoutMs = 10000
+
+  private var group: GroupMetadata = null
 
   @Before
   def setUp() {
@@ -169,30 +176,24 @@ class GroupMetadataTest extends JUnitSuite {
 
   @Test
   def testSelectProtocol() {
-    val protocolType = "consumer"
-    val groupId = "groupId"
-    val clientId = "clientId"
-    val clientHost = "clientHost"
-    val sessionTimeoutMs = 10000
-
     val memberId = "memberId"
-    val member = new MemberMetadata(memberId, groupId, clientId, clientHost, sessionTimeoutMs,
+    val member = new MemberMetadata(memberId, groupId, clientId, clientHost, rebalanceTimeoutMs, sessionTimeoutMs,
       protocolType, List(("range", Array.empty[Byte]), ("roundrobin", Array.empty[Byte])))
 
     group.add(memberId, member)
     assertEquals("range", group.selectProtocol)
 
     val otherMemberId = "otherMemberId"
-    val otherMember = new MemberMetadata(otherMemberId, groupId, clientId, clientHost, sessionTimeoutMs,
-      protocolType, List(("roundrobin", Array.empty[Byte]), ("range", Array.empty[Byte])))
+    val otherMember = new MemberMetadata(otherMemberId, groupId, clientId, clientHost, rebalanceTimeoutMs,
+      sessionTimeoutMs, protocolType, List(("roundrobin", Array.empty[Byte]), ("range", Array.empty[Byte])))
 
     group.add(otherMemberId, otherMember)
     // now could be either range or robin since there is no majority preference
     assertTrue(Set("range", "roundrobin")(group.selectProtocol))
 
     val lastMemberId = "lastMemberId"
-    val lastMember = new MemberMetadata(lastMemberId, groupId, clientId, clientHost, sessionTimeoutMs,
-      protocolType, List(("roundrobin", Array.empty[Byte]), ("range", Array.empty[Byte])))
+    val lastMember = new MemberMetadata(lastMemberId, groupId, clientId, clientHost, rebalanceTimeoutMs,
+      sessionTimeoutMs, protocolType, List(("roundrobin", Array.empty[Byte]), ("range", Array.empty[Byte])))
 
     group.add(lastMemberId, lastMember)
     // now we should prefer 'roundrobin'
@@ -207,19 +208,13 @@ class GroupMetadataTest extends JUnitSuite {
 
   @Test
   def testSelectProtocolChoosesCompatibleProtocol() {
-    val protocolType = "consumer"
-    val groupId = "groupId"
-    val clientId = "clientId"
-    val clientHost = "clientHost"
-    val sessionTimeoutMs = 10000
-
     val memberId = "memberId"
-    val member = new MemberMetadata(memberId, groupId, clientId, clientHost, sessionTimeoutMs,
+    val member = new MemberMetadata(memberId, groupId, clientId, clientHost, rebalanceTimeoutMs, sessionTimeoutMs,
       protocolType, List(("range", Array.empty[Byte]), ("roundrobin", Array.empty[Byte])))
 
     val otherMemberId = "otherMemberId"
-    val otherMember = new MemberMetadata(otherMemberId, groupId, clientId, clientHost, sessionTimeoutMs,
-      protocolType, List(("roundrobin", Array.empty[Byte]), ("blah", Array.empty[Byte])))
+    val otherMember = new MemberMetadata(otherMemberId, groupId, clientId, clientHost, rebalanceTimeoutMs,
+      sessionTimeoutMs, protocolType, List(("roundrobin", Array.empty[Byte]), ("blah", Array.empty[Byte])))
 
     group.add(memberId, member)
     group.add(otherMemberId, otherMember)
@@ -228,18 +223,12 @@ class GroupMetadataTest extends JUnitSuite {
 
   @Test
   def testSupportsProtocols() {
-    val protocolType = "consumer"
-    val groupId = "groupId"
-    val clientId = "clientId"
-    val clientHost = "clientHost"
-    val sessionTimeoutMs = 10000
-
     // by default, the group supports everything
     assertTrue(group.supportsProtocols(Set("roundrobin", "range")))
 
     val memberId = "memberId"
-    val member = new MemberMetadata(memberId, groupId, clientId, clientHost, sessionTimeoutMs,
-      protocolType, List(("range", Array.empty[Byte]), ("roundrobin", Array.empty[Byte])))
+    val member = new MemberMetadata(memberId, groupId, clientId, clientHost, rebalanceTimeoutMs,
+      sessionTimeoutMs, protocolType, List(("range", Array.empty[Byte]), ("roundrobin", Array.empty[Byte])))
 
     group.add(memberId, member)
     assertTrue(group.supportsProtocols(Set("roundrobin", "foo")))
@@ -247,8 +236,8 @@ class GroupMetadataTest extends JUnitSuite {
     assertFalse(group.supportsProtocols(Set("foo", "bar")))
 
     val otherMemberId = "otherMemberId"
-    val otherMember = new MemberMetadata(otherMemberId, groupId, clientId, clientHost, sessionTimeoutMs,
-      protocolType, List(("roundrobin", Array.empty[Byte]), ("blah", Array.empty[Byte])))
+    val otherMember = new MemberMetadata(otherMemberId, groupId, clientId, clientHost, rebalanceTimeoutMs,
+      sessionTimeoutMs, protocolType, List(("roundrobin", Array.empty[Byte]), ("blah", Array.empty[Byte])))
 
     group.add(otherMemberId, otherMember)
 
@@ -258,14 +247,8 @@ class GroupMetadataTest extends JUnitSuite {
 
   @Test
   def testInitNextGeneration() {
-    val protocolType = "consumer"
-    val groupId = "groupId"
-    val clientId = "clientId"
-    val clientHost = "clientHost"
-    val sessionTimeoutMs = 10000
     val memberId = "memberId"
-
-    val member = new MemberMetadata(memberId, groupId, clientId, clientHost, sessionTimeoutMs,
+    val member = new MemberMetadata(memberId, groupId, clientId, clientHost, rebalanceTimeoutMs, sessionTimeoutMs,
       protocolType, List(("roundrobin", Array.empty[Byte])))
 
     group.transitionTo(PreparingRebalance)

http://git-wip-us.apache.org/repos/asf/kafka/blob/40b1dd3f/core/src/test/scala/unit/kafka/coordinator/MemberMetadataTest.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/unit/kafka/coordinator/MemberMetadataTest.scala b/core/src/test/scala/unit/kafka/coordinator/MemberMetadataTest.scala
index 0688424..257dde7 100644
--- a/core/src/test/scala/unit/kafka/coordinator/MemberMetadataTest.scala
+++ b/core/src/test/scala/unit/kafka/coordinator/MemberMetadataTest.scala
@@ -28,6 +28,7 @@ class MemberMetadataTest extends JUnitSuite {
   val clientHost = "clientHost"
   val memberId = "memberId"
   val protocolType = "consumer"
+  val rebalanceTimeoutMs = 60000
   val sessionTimeoutMs = 10000
 
 
@@ -35,7 +36,8 @@ class MemberMetadataTest extends JUnitSuite {
   def testMatchesSupportedProtocols {
     val protocols = List(("range", Array.empty[Byte]))
 
-    val member = new MemberMetadata(memberId, groupId, clientId, clientHost, sessionTimeoutMs, protocolType, protocols)
+    val member = new MemberMetadata(memberId, groupId, clientId, clientHost, rebalanceTimeoutMs, sessionTimeoutMs,
+      protocolType, protocols)
     assertTrue(member.matches(protocols))
     assertFalse(member.matches(List(("range", Array[Byte](0)))))
     assertFalse(member.matches(List(("roundrobin", Array.empty[Byte]))))
@@ -46,7 +48,8 @@ class MemberMetadataTest extends JUnitSuite {
   def testVoteForPreferredProtocol {
     val protocols = List(("range", Array.empty[Byte]), ("roundrobin", Array.empty[Byte]))
 
-    val member = new MemberMetadata(memberId, groupId, clientId, clientHost, sessionTimeoutMs, protocolType, protocols)
+    val member = new MemberMetadata(memberId, groupId, clientId, clientHost, rebalanceTimeoutMs, sessionTimeoutMs,
+      protocolType, protocols)
     assertEquals("range", member.vote(Set("range", "roundrobin")))
     assertEquals("roundrobin", member.vote(Set("blah", "roundrobin")))
   }
@@ -55,7 +58,8 @@ class MemberMetadataTest extends JUnitSuite {
   def testMetadata {
     val protocols = List(("range", Array[Byte](0)), ("roundrobin", Array[Byte](1)))
 
-    val member = new MemberMetadata(memberId, groupId, clientId, clientHost, sessionTimeoutMs, protocolType, protocols)
+    val member = new MemberMetadata(memberId, groupId, clientId, clientHost, rebalanceTimeoutMs, sessionTimeoutMs,
+      protocolType, protocols)
     assertTrue(util.Arrays.equals(Array[Byte](0), member.metadata("range")))
     assertTrue(util.Arrays.equals(Array[Byte](1), member.metadata("roundrobin")))
   }
@@ -64,7 +68,8 @@ class MemberMetadataTest extends JUnitSuite {
   def testMetadataRaisesOnUnsupportedProtocol {
     val protocols = List(("range", Array.empty[Byte]), ("roundrobin", Array.empty[Byte]))
 
-    val member = new MemberMetadata(memberId, groupId, clientId, clientHost, sessionTimeoutMs, protocolType, protocols)
+    val member = new MemberMetadata(memberId, groupId, clientId, clientHost, rebalanceTimeoutMs, sessionTimeoutMs,
+      protocolType, protocols)
     member.metadata("blah")
     fail()
   }
@@ -73,7 +78,8 @@ class MemberMetadataTest extends JUnitSuite {
   def testVoteRaisesOnNoSupportedProtocols {
     val protocols = List(("range", Array.empty[Byte]), ("roundrobin", Array.empty[Byte]))
 
-    val member = new MemberMetadata(memberId, groupId, clientId, clientHost, sessionTimeoutMs, protocolType, protocols)
+    val member = new MemberMetadata(memberId, groupId, clientId, clientHost, rebalanceTimeoutMs, sessionTimeoutMs,
+      protocolType, protocols)
     member.vote(Set("blah"))
     fail()
   }

http://git-wip-us.apache.org/repos/asf/kafka/blob/40b1dd3f/core/src/test/scala/unit/kafka/utils/timer/MockTimer.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/unit/kafka/utils/timer/MockTimer.scala b/core/src/test/scala/unit/kafka/utils/timer/MockTimer.scala
index d18a060..e4ac4fa 100644
--- a/core/src/test/scala/unit/kafka/utils/timer/MockTimer.scala
+++ b/core/src/test/scala/unit/kafka/utils/timer/MockTimer.scala
@@ -23,7 +23,7 @@ import scala.collection.mutable
 class MockTimer extends Timer {
 
   val time = new MockTime
-  private val taskQueue = mutable.PriorityQueue[TimerTaskEntry]()
+  private val taskQueue = mutable.PriorityQueue[TimerTaskEntry]()(Ordering[TimerTaskEntry].reverse)
 
   def add(timerTask: TimerTask) {
     if (timerTask.delayMs <= 0)

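Editor's note: the one-line MockTimer change is easy to miss but matters for every timeout test above. scala.collection.mutable.PriorityQueue dequeues its largest element first, so without the reversed Ordering the mock timer would pop the latest-expiring task first; reversing it puts the earliest expiration at the head, which is what advancing the clock needs (assuming TimerTaskEntry orders by expiration time, as in kafka.utils.timer). A quick sketch of the behaviour with plain Ints:

    import scala.collection.mutable

    // Default PriorityQueue: max-heap, dequeues the largest element first.
    val maxFirst = mutable.PriorityQueue(3, 1, 2)
    assert(maxFirst.dequeue() == 3)

    // With the ordering reversed, the smallest value (i.e. the earliest
    // expiration for a timer entry) comes out first, matching real timer semantics.
    val minFirst = mutable.PriorityQueue(3, 1, 2)(Ordering[Int].reverse)
    assert(minFirst.dequeue() == 1)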

[2/4] kafka git commit: KAFKA-3888: send consumer heartbeats from a background thread (KIP-62)

Posted by gu...@apache.org.
http://git-wip-us.apache.org/repos/asf/kafka/blob/40b1dd3f/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinatorTest.java
----------------------------------------------------------------------
diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinatorTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinatorTest.java
index 176571c..8ec8b75 100644
--- a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinatorTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinatorTest.java
@@ -26,13 +26,13 @@ import org.apache.kafka.clients.consumer.OffsetAndMetadata;
 import org.apache.kafka.clients.consumer.OffsetCommitCallback;
 import org.apache.kafka.clients.consumer.OffsetResetStrategy;
 import org.apache.kafka.clients.consumer.RangeAssignor;
+import org.apache.kafka.clients.consumer.RetriableCommitFailedException;
 import org.apache.kafka.clients.consumer.RoundRobinAssignor;
 import org.apache.kafka.common.Cluster;
 import org.apache.kafka.common.KafkaException;
 import org.apache.kafka.common.Node;
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.errors.ApiException;
-import org.apache.kafka.clients.consumer.RetriableCommitFailedException;
 import org.apache.kafka.common.errors.DisconnectException;
 import org.apache.kafka.common.errors.GroupAuthorizationException;
 import org.apache.kafka.common.errors.OffsetMetadataTooLarge;
@@ -79,11 +79,12 @@ public class ConsumerCoordinatorTest {
     private String topicName = "test";
     private String groupId = "test-group";
     private TopicPartition tp = new TopicPartition(topicName, 0);
+    private int rebalanceTimeoutMs = 60000;
     private int sessionTimeoutMs = 10000;
     private int heartbeatIntervalMs = 5000;
     private long retryBackoffMs = 100;
     private boolean autoCommitEnabled = false;
-    private long autoCommitIntervalMs = 2000;
+    private int autoCommitIntervalMs = 2000;
     private MockPartitionAssignor partitionAssignor = new MockPartitionAssignor();
     private List<PartitionAssignor> assignors = Collections.<PartitionAssignor>singletonList(partitionAssignor);
     private MockTime time;
@@ -123,7 +124,7 @@ public class ConsumerCoordinatorTest {
 
     @Test
     public void testNormalHeartbeat() {
-        client.prepareResponse(consumerMetadataResponse(node, Errors.NONE.code()));
+        client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE.code()));
         coordinator.ensureCoordinatorReady();
 
         // normal heartbeat
@@ -141,7 +142,7 @@ public class ConsumerCoordinatorTest {
 
     @Test(expected = GroupAuthorizationException.class)
     public void testGroupDescribeUnauthorized() {
-        client.prepareResponse(consumerMetadataResponse(node, Errors.GROUP_AUTHORIZATION_FAILED.code()));
+        client.prepareResponse(groupCoordinatorResponse(node, Errors.GROUP_AUTHORIZATION_FAILED.code()));
         coordinator.ensureCoordinatorReady();
     }
 
@@ -149,17 +150,17 @@ public class ConsumerCoordinatorTest {
     public void testGroupReadUnauthorized() {
         subscriptions.subscribe(Arrays.asList(topicName), rebalanceListener);
 
-        client.prepareResponse(consumerMetadataResponse(node, Errors.NONE.code()));
+        client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE.code()));
         coordinator.ensureCoordinatorReady();
 
         client.prepareResponse(joinGroupLeaderResponse(0, "memberId", Collections.<String, List<String>>emptyMap(),
                 Errors.GROUP_AUTHORIZATION_FAILED.code()));
-        coordinator.ensurePartitionAssignment();
+        coordinator.poll(time.milliseconds());
     }
 
     @Test
     public void testCoordinatorNotAvailable() {
-        client.prepareResponse(consumerMetadataResponse(node, Errors.NONE.code()));
+        client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE.code()));
         coordinator.ensureCoordinatorReady();
 
         // GROUP_COORDINATOR_NOT_AVAILABLE will mark coordinator as unknown
@@ -180,7 +181,7 @@ public class ConsumerCoordinatorTest {
 
     @Test
     public void testNotCoordinator() {
-        client.prepareResponse(consumerMetadataResponse(node, Errors.NONE.code()));
+        client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE.code()));
         coordinator.ensureCoordinatorReady();
 
         // not_coordinator will mark coordinator as unknown
@@ -201,7 +202,7 @@ public class ConsumerCoordinatorTest {
 
     @Test
     public void testIllegalGeneration() {
-        client.prepareResponse(consumerMetadataResponse(node, Errors.NONE.code()));
+        client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE.code()));
         coordinator.ensureCoordinatorReady();
 
         // illegal_generation will cause re-partition
@@ -225,7 +226,7 @@ public class ConsumerCoordinatorTest {
 
     @Test
     public void testUnknownConsumerId() {
-        client.prepareResponse(consumerMetadataResponse(node, Errors.NONE.code()));
+        client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE.code()));
         coordinator.ensureCoordinatorReady();
 
         // illegal_generation will cause re-partition
@@ -249,7 +250,7 @@ public class ConsumerCoordinatorTest {
 
     @Test
     public void testCoordinatorDisconnect() {
-        client.prepareResponse(consumerMetadataResponse(node, Errors.NONE.code()));
+        client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE.code()));
         coordinator.ensureCoordinatorReady();
 
         // coordinator disconnect will mark coordinator as unknown
@@ -279,12 +280,12 @@ public class ConsumerCoordinatorTest {
         metadata.setTopics(Arrays.asList(topicName));
         metadata.update(cluster, time.milliseconds());
 
-        client.prepareResponse(consumerMetadataResponse(node, Errors.NONE.code()));
+        client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE.code()));
         coordinator.ensureCoordinatorReady();
 
         client.prepareResponse(joinGroupLeaderResponse(0, consumerId, Collections.<String, List<String>>emptyMap(),
                 Errors.INVALID_GROUP_ID.code()));
-        coordinator.ensurePartitionAssignment();
+        coordinator.poll(time.milliseconds());
     }
 
     @Test
@@ -298,7 +299,7 @@ public class ConsumerCoordinatorTest {
         metadata.setTopics(Arrays.asList(topicName));
         metadata.update(cluster, time.milliseconds());
 
-        client.prepareResponse(consumerMetadataResponse(node, Errors.NONE.code()));
+        client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE.code()));
         coordinator.ensureCoordinatorReady();
 
         // normal join group
@@ -315,7 +316,7 @@ public class ConsumerCoordinatorTest {
                         sync.groupAssignment().containsKey(consumerId);
             }
         }, syncGroupResponse(Arrays.asList(tp), Errors.NONE.code()));
-        coordinator.ensurePartitionAssignment();
+        coordinator.poll(time.milliseconds());
 
         assertFalse(subscriptions.partitionAssignmentNeeded());
         assertEquals(Collections.singleton(tp), subscriptions.assignedPartitions());
@@ -336,7 +337,7 @@ public class ConsumerCoordinatorTest {
         metadata.setTopics(Arrays.asList(topicName));
         metadata.update(cluster, time.milliseconds());
 
-        client.prepareResponse(consumerMetadataResponse(node, Errors.NONE.code()));
+        client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE.code()));
         coordinator.ensureCoordinatorReady();
 
         Map<String, List<String>> memberSubscriptions = Collections.singletonMap(consumerId, Arrays.asList(topicName));
@@ -347,14 +348,14 @@ public class ConsumerCoordinatorTest {
         consumerClient.wakeup();
 
         try {
-            coordinator.ensurePartitionAssignment();
+            coordinator.poll(time.milliseconds());
         } catch (WakeupException e) {
             // ignore
         }
 
         // now complete the second half
         client.prepareResponse(syncGroupResponse(Arrays.asList(tp), Errors.NONE.code()));
-        coordinator.ensurePartitionAssignment();
+        coordinator.poll(time.milliseconds());
 
         assertFalse(subscriptions.partitionAssignmentNeeded());
         assertEquals(Collections.singleton(tp), subscriptions.assignedPartitions());
@@ -371,7 +372,7 @@ public class ConsumerCoordinatorTest {
         subscriptions.subscribe(Arrays.asList(topicName), rebalanceListener);
         subscriptions.needReassignment();
 
-        client.prepareResponse(consumerMetadataResponse(node, Errors.NONE.code()));
+        client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE.code()));
         coordinator.ensureCoordinatorReady();
 
         // normal join group
@@ -386,7 +387,7 @@ public class ConsumerCoordinatorTest {
             }
         }, syncGroupResponse(Arrays.asList(tp), Errors.NONE.code()));
 
-        coordinator.ensurePartitionAssignment();
+        coordinator.poll(time.milliseconds());
 
         assertFalse(subscriptions.partitionAssignmentNeeded());
         assertEquals(Collections.singleton(tp), subscriptions.assignedPartitions());
@@ -402,12 +403,12 @@ public class ConsumerCoordinatorTest {
         subscriptions.subscribe(Arrays.asList(topicName), rebalanceListener);
         subscriptions.needReassignment();
 
-        client.prepareResponse(consumerMetadataResponse(node, Errors.NONE.code()));
+        client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE.code()));
         coordinator.ensureCoordinatorReady();
 
         client.prepareResponse(joinGroupFollowerResponse(1, consumerId, "leader", Errors.NONE.code()));
         client.prepareResponse(syncGroupResponse(Arrays.asList(tp), Errors.NONE.code()));
-        coordinator.ensurePartitionAssignment();
+        coordinator.poll(time.milliseconds());
 
         final AtomicBoolean received = new AtomicBoolean(false);
         client.prepareResponse(new MockClient.RequestMatcher() {
@@ -430,12 +431,12 @@ public class ConsumerCoordinatorTest {
         subscriptions.subscribe(Arrays.asList(topicName), rebalanceListener);
         subscriptions.needReassignment();
 
-        client.prepareResponse(consumerMetadataResponse(node, Errors.NONE.code()));
+        client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE.code()));
         coordinator.ensureCoordinatorReady();
 
         client.prepareResponse(joinGroupFollowerResponse(1, consumerId, "leader", Errors.NONE.code()));
         client.prepareResponse(syncGroupResponse(Arrays.asList(tp), Errors.NONE.code()));
-        coordinator.ensurePartitionAssignment();
+        coordinator.poll(time.milliseconds());
 
         final AtomicBoolean received = new AtomicBoolean(false);
         client.prepareResponse(new MockClient.RequestMatcher() {
@@ -449,8 +450,9 @@ public class ConsumerCoordinatorTest {
         }, new LeaveGroupResponse(Errors.NONE.code()).toStruct());
         coordinator.maybeLeaveGroup();
         assertTrue(received.get());
-        assertEquals(JoinGroupRequest.UNKNOWN_MEMBER_ID, coordinator.memberId);
-        assertEquals(OffsetCommitRequest.DEFAULT_GENERATION_ID, coordinator.generation);
+
+        AbstractCoordinator.Generation generation = coordinator.generation();
+        assertNull(generation);
     }
 
     @Test(expected = KafkaException.class)
@@ -460,13 +462,13 @@ public class ConsumerCoordinatorTest {
         subscriptions.subscribe(Arrays.asList(topicName), rebalanceListener);
         subscriptions.needReassignment();
 
-        client.prepareResponse(consumerMetadataResponse(node, Errors.NONE.code()));
+        client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE.code()));
         coordinator.ensureCoordinatorReady();
 
         // join initially, but let coordinator rebalance on sync
         client.prepareResponse(joinGroupFollowerResponse(1, consumerId, "leader", Errors.NONE.code()));
         client.prepareResponse(syncGroupResponse(Collections.<TopicPartition>emptyList(), Errors.UNKNOWN.code()));
-        coordinator.ensurePartitionAssignment();
+        coordinator.poll(time.milliseconds());
     }
 
     @Test
@@ -476,7 +478,7 @@ public class ConsumerCoordinatorTest {
         subscriptions.subscribe(Arrays.asList(topicName), rebalanceListener);
         subscriptions.needReassignment();
 
-        client.prepareResponse(consumerMetadataResponse(node, Errors.NONE.code()));
+        client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE.code()));
         coordinator.ensureCoordinatorReady();
 
         // join initially, but let coordinator returns unknown member id
@@ -493,7 +495,7 @@ public class ConsumerCoordinatorTest {
         }, joinGroupFollowerResponse(2, consumerId, "leader", Errors.NONE.code()));
         client.prepareResponse(syncGroupResponse(Arrays.asList(tp), Errors.NONE.code()));
 
-        coordinator.ensurePartitionAssignment();
+        coordinator.poll(time.milliseconds());
 
         assertFalse(subscriptions.partitionAssignmentNeeded());
         assertEquals(Collections.singleton(tp), subscriptions.assignedPartitions());
@@ -506,7 +508,7 @@ public class ConsumerCoordinatorTest {
         subscriptions.subscribe(Arrays.asList(topicName), rebalanceListener);
         subscriptions.needReassignment();
 
-        client.prepareResponse(consumerMetadataResponse(node, Errors.NONE.code()));
+        client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE.code()));
         coordinator.ensureCoordinatorReady();
 
         // join initially, but let coordinator rebalance on sync
@@ -517,7 +519,7 @@ public class ConsumerCoordinatorTest {
         client.prepareResponse(joinGroupFollowerResponse(2, consumerId, "leader", Errors.NONE.code()));
         client.prepareResponse(syncGroupResponse(Arrays.asList(tp), Errors.NONE.code()));
 
-        coordinator.ensurePartitionAssignment();
+        coordinator.poll(time.milliseconds());
 
         assertFalse(subscriptions.partitionAssignmentNeeded());
         assertEquals(Collections.singleton(tp), subscriptions.assignedPartitions());
@@ -530,7 +532,7 @@ public class ConsumerCoordinatorTest {
         subscriptions.subscribe(Arrays.asList(topicName), rebalanceListener);
         subscriptions.needReassignment();
 
-        client.prepareResponse(consumerMetadataResponse(node, Errors.NONE.code()));
+        client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE.code()));
         coordinator.ensureCoordinatorReady();
 
         // join initially, but let coordinator rebalance on sync
@@ -547,7 +549,7 @@ public class ConsumerCoordinatorTest {
         }, joinGroupFollowerResponse(2, consumerId, "leader", Errors.NONE.code()));
         client.prepareResponse(syncGroupResponse(Arrays.asList(tp), Errors.NONE.code()));
 
-        coordinator.ensurePartitionAssignment();
+        coordinator.poll(time.milliseconds());
 
         assertFalse(subscriptions.partitionAssignmentNeeded());
         assertEquals(Collections.singleton(tp), subscriptions.assignedPartitions());
@@ -560,13 +562,13 @@ public class ConsumerCoordinatorTest {
         subscriptions.subscribe(Arrays.asList(topicName), rebalanceListener);
         subscriptions.needReassignment();
 
-        client.prepareResponse(consumerMetadataResponse(node, Errors.NONE.code()));
+        client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE.code()));
         coordinator.ensureCoordinatorReady();
 
         client.prepareResponse(joinGroupFollowerResponse(1, consumerId, "leader", Errors.NONE.code()));
         client.prepareResponse(syncGroupResponse(Arrays.asList(tp), Errors.NONE.code()));
 
-        coordinator.ensurePartitionAssignment();
+        coordinator.poll(time.milliseconds());
 
         assertFalse(subscriptions.partitionAssignmentNeeded());
 
@@ -595,7 +597,7 @@ public class ConsumerCoordinatorTest {
         // we only have metadata for one topic initially
         metadata.update(TestUtils.singletonCluster(topic1, 1), time.milliseconds());
 
-        client.prepareResponse(consumerMetadataResponse(node, Errors.NONE.code()));
+        client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE.code()));
         coordinator.ensureCoordinatorReady();
 
         // prepare initial rebalance
@@ -625,7 +627,7 @@ public class ConsumerCoordinatorTest {
         client.prepareResponse(joinGroupLeaderResponse(2, consumerId, memberSubscriptions, Errors.NONE.code()));
         client.prepareResponse(syncGroupResponse(Arrays.asList(tp1, tp2), Errors.NONE.code()));
 
-        coordinator.ensurePartitionAssignment();
+        coordinator.poll(time.milliseconds());
 
         assertFalse(subscriptions.partitionAssignmentNeeded());
         assertEquals(new HashSet<>(Arrays.asList(tp1, tp2)), subscriptions.assignedPartitions());
@@ -656,13 +658,13 @@ public class ConsumerCoordinatorTest {
         subscriptions.subscribe(Arrays.asList(topicName), rebalanceListener);
         subscriptions.needReassignment();
 
-        client.prepareResponse(consumerMetadataResponse(node, Errors.NONE.code()));
+        client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE.code()));
         coordinator.ensureCoordinatorReady();
 
         // join the group once
         client.prepareResponse(joinGroupFollowerResponse(1, "consumer", "leader", Errors.NONE.code()));
         client.prepareResponse(syncGroupResponse(Arrays.asList(tp), Errors.NONE.code()));
-        coordinator.ensurePartitionAssignment();
+        coordinator.poll(time.milliseconds());
 
         assertEquals(1, rebalanceListener.revokedCount);
         assertEquals(1, rebalanceListener.assignedCount);
@@ -671,7 +673,7 @@ public class ConsumerCoordinatorTest {
         subscriptions.needReassignment();
         client.prepareResponse(joinGroupFollowerResponse(2, "consumer", "leader", Errors.NONE.code()));
         client.prepareResponse(syncGroupResponse(Arrays.asList(tp), Errors.NONE.code()));
-        coordinator.ensurePartitionAssignment();
+        coordinator.poll(time.milliseconds());
 
         assertEquals(2, rebalanceListener.revokedCount);
         assertEquals(Collections.singleton(tp), rebalanceListener.revoked);
@@ -684,15 +686,15 @@ public class ConsumerCoordinatorTest {
         subscriptions.subscribe(Arrays.asList(topicName), rebalanceListener);
         subscriptions.needReassignment();
 
-        client.prepareResponse(consumerMetadataResponse(node, Errors.NONE.code()));
+        client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE.code()));
         coordinator.ensureCoordinatorReady();
 
         // disconnected from original coordinator will cause re-discover and join again
         client.prepareResponse(joinGroupFollowerResponse(1, "consumer", "leader", Errors.NONE.code()), true);
-        client.prepareResponse(consumerMetadataResponse(node, Errors.NONE.code()));
+        client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE.code()));
         client.prepareResponse(joinGroupFollowerResponse(1, "consumer", "leader", Errors.NONE.code()));
         client.prepareResponse(syncGroupResponse(Arrays.asList(tp), Errors.NONE.code()));
-        coordinator.ensurePartitionAssignment();
+        coordinator.poll(time.milliseconds());
         assertFalse(subscriptions.partitionAssignmentNeeded());
         assertEquals(Collections.singleton(tp), subscriptions.assignedPartitions());
         assertEquals(1, rebalanceListener.revokedCount);
@@ -705,25 +707,26 @@ public class ConsumerCoordinatorTest {
         subscriptions.subscribe(Arrays.asList(topicName), rebalanceListener);
         subscriptions.needReassignment();
 
-        client.prepareResponse(consumerMetadataResponse(node, Errors.NONE.code()));
+        client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE.code()));
         coordinator.ensureCoordinatorReady();
 
         // coordinator doesn't like the session timeout
         client.prepareResponse(joinGroupFollowerResponse(0, "consumer", "", Errors.INVALID_SESSION_TIMEOUT.code()));
-        coordinator.ensurePartitionAssignment();
+        coordinator.poll(time.milliseconds());
     }
 
     @Test
     public void testCommitOffsetOnly() {
         subscriptions.assignFromUser(Arrays.asList(tp));
 
-        client.prepareResponse(consumerMetadataResponse(node, Errors.NONE.code()));
+        client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE.code()));
         coordinator.ensureCoordinatorReady();
 
         client.prepareResponse(offsetCommitResponse(Collections.singletonMap(tp, Errors.NONE.code())));
 
         AtomicBoolean success = new AtomicBoolean(false);
         coordinator.commitOffsetsAsync(Collections.singletonMap(tp, new OffsetAndMetadata(100L)), callback(success));
+        coordinator.invokeCompletedOffsetCommitCallbacks();
         assertTrue(success.get());
 
         assertEquals(100L, subscriptions.committed(tp).offset());
@@ -739,18 +742,18 @@ public class ConsumerCoordinatorTest {
         subscriptions.subscribe(Arrays.asList(topicName), rebalanceListener);
         subscriptions.needReassignment();
 
-        client.prepareResponse(consumerMetadataResponse(node, Errors.NONE.code()));
+        client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE.code()));
         coordinator.ensureCoordinatorReady();
 
         client.prepareResponse(joinGroupFollowerResponse(1, consumerId, "leader", Errors.NONE.code()));
         client.prepareResponse(syncGroupResponse(Arrays.asList(tp), Errors.NONE.code()));
-        coordinator.ensurePartitionAssignment();
+        coordinator.poll(time.milliseconds());
 
         subscriptions.seek(tp, 100);
 
         client.prepareResponse(offsetCommitResponse(Collections.singletonMap(tp, Errors.NONE.code())));
         time.sleep(autoCommitIntervalMs);
-        consumerClient.poll(0);
+        coordinator.poll(time.milliseconds());
 
         assertEquals(100L, subscriptions.committed(tp).offset());
     }
@@ -765,7 +768,7 @@ public class ConsumerCoordinatorTest {
         subscriptions.subscribe(Arrays.asList(topicName), rebalanceListener);
         subscriptions.needReassignment();
 
-        client.prepareResponse(consumerMetadataResponse(node, Errors.NONE.code()));
+        client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE.code()));
         coordinator.ensureCoordinatorReady();
 
         // haven't joined, so should not cause a commit
@@ -774,13 +777,13 @@ public class ConsumerCoordinatorTest {
 
         client.prepareResponse(joinGroupFollowerResponse(1, consumerId, "leader", Errors.NONE.code()));
         client.prepareResponse(syncGroupResponse(Arrays.asList(tp), Errors.NONE.code()));
-        coordinator.ensurePartitionAssignment();
+        coordinator.poll(time.milliseconds());
 
         subscriptions.seek(tp, 100);
 
         client.prepareResponse(offsetCommitResponse(Collections.singletonMap(tp, Errors.NONE.code())));
         time.sleep(autoCommitIntervalMs);
-        consumerClient.poll(0);
+        coordinator.poll(time.milliseconds());
 
         assertEquals(100L, subscriptions.committed(tp).offset());
     }
@@ -793,12 +796,12 @@ public class ConsumerCoordinatorTest {
         subscriptions.assignFromUser(Arrays.asList(tp));
         subscriptions.seek(tp, 100);
 
-        client.prepareResponse(consumerMetadataResponse(node, Errors.NONE.code()));
+        client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE.code()));
         coordinator.ensureCoordinatorReady();
 
         client.prepareResponse(offsetCommitResponse(Collections.singletonMap(tp, Errors.NONE.code())));
         time.sleep(autoCommitIntervalMs);
-        consumerClient.poll(0);
+        coordinator.poll(time.milliseconds());
 
         assertEquals(100L, subscriptions.committed(tp).offset());
     }
@@ -819,13 +822,13 @@ public class ConsumerCoordinatorTest {
         assertNull(subscriptions.committed(tp));
 
         // now find the coordinator
-        client.prepareResponse(consumerMetadataResponse(node, Errors.NONE.code()));
+        client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE.code()));
         coordinator.ensureCoordinatorReady();
 
         // sleep only for the retry backoff
         time.sleep(retryBackoffMs);
         client.prepareResponse(offsetCommitResponse(Collections.singletonMap(tp, Errors.NONE.code())));
-        consumerClient.poll(0);
+        coordinator.poll(time.milliseconds());
 
         assertEquals(100L, subscriptions.committed(tp).offset());
     }
@@ -834,13 +837,14 @@ public class ConsumerCoordinatorTest {
     public void testCommitOffsetMetadata() {
         subscriptions.assignFromUser(Arrays.asList(tp));
 
-        client.prepareResponse(consumerMetadataResponse(node, Errors.NONE.code()));
+        client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE.code()));
         coordinator.ensureCoordinatorReady();
 
         client.prepareResponse(offsetCommitResponse(Collections.singletonMap(tp, Errors.NONE.code())));
 
         AtomicBoolean success = new AtomicBoolean(false);
         coordinator.commitOffsetsAsync(Collections.singletonMap(tp, new OffsetAndMetadata(100L, "hello")), callback(success));
+        coordinator.invokeCompletedOffsetCommitCallbacks();
         assertTrue(success.get());
 
         assertEquals(100L, subscriptions.committed(tp).offset());
@@ -850,10 +854,11 @@ public class ConsumerCoordinatorTest {
     @Test
     public void testCommitOffsetAsyncWithDefaultCallback() {
         int invokedBeforeTest = defaultOffsetCommitCallback.invoked;
-        client.prepareResponse(consumerMetadataResponse(node, Errors.NONE.code()));
+        client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE.code()));
         coordinator.ensureCoordinatorReady();
         client.prepareResponse(offsetCommitResponse(Collections.singletonMap(tp, Errors.NONE.code())));
         coordinator.commitOffsetsAsync(Collections.singletonMap(tp, new OffsetAndMetadata(100L)), null);
+        coordinator.invokeCompletedOffsetCommitCallbacks();
         assertEquals(invokedBeforeTest + 1, defaultOffsetCommitCallback.invoked);
         assertNull(defaultOffsetCommitCallback.exception);
     }
@@ -863,12 +868,12 @@ public class ConsumerCoordinatorTest {
         // enable auto-assignment
         subscriptions.subscribe(Arrays.asList(topicName), rebalanceListener);
 
-        client.prepareResponse(consumerMetadataResponse(node, Errors.NONE.code()));
+        client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE.code()));
         coordinator.ensureCoordinatorReady();
 
         client.prepareResponse(joinGroupFollowerResponse(1, "consumer", "leader", Errors.NONE.code()));
         client.prepareResponse(syncGroupResponse(Arrays.asList(tp), Errors.NONE.code()));
-        coordinator.ensurePartitionAssignment();
+        coordinator.poll(time.milliseconds());
 
         // now switch to manual assignment
         client.prepareResponse(new LeaveGroupResponse(Errors.NONE.code()).toStruct());
@@ -888,29 +893,32 @@ public class ConsumerCoordinatorTest {
 
         AtomicBoolean success = new AtomicBoolean(false);
         coordinator.commitOffsetsAsync(Collections.singletonMap(tp, new OffsetAndMetadata(100L)), callback(success));
+        coordinator.invokeCompletedOffsetCommitCallbacks();
         assertTrue(success.get());
     }
 
     @Test
     public void testCommitOffsetAsyncFailedWithDefaultCallback() {
         int invokedBeforeTest = defaultOffsetCommitCallback.invoked;
-        client.prepareResponse(consumerMetadataResponse(node, Errors.NONE.code()));
+        client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE.code()));
         coordinator.ensureCoordinatorReady();
         client.prepareResponse(offsetCommitResponse(Collections.singletonMap(tp, Errors.GROUP_COORDINATOR_NOT_AVAILABLE.code())));
         coordinator.commitOffsetsAsync(Collections.singletonMap(tp, new OffsetAndMetadata(100L)), null);
+        coordinator.invokeCompletedOffsetCommitCallbacks();
         assertEquals(invokedBeforeTest + 1, defaultOffsetCommitCallback.invoked);
         assertTrue(defaultOffsetCommitCallback.exception instanceof RetriableCommitFailedException);
     }
 
     @Test
     public void testCommitOffsetAsyncCoordinatorNotAvailable() {
-        client.prepareResponse(consumerMetadataResponse(node, Errors.NONE.code()));
+        client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE.code()));
         coordinator.ensureCoordinatorReady();
 
         // async commit with coordinator not available
         MockCommitCallback cb = new MockCommitCallback();
         client.prepareResponse(offsetCommitResponse(Collections.singletonMap(tp, Errors.GROUP_COORDINATOR_NOT_AVAILABLE.code())));
         coordinator.commitOffsetsAsync(Collections.singletonMap(tp, new OffsetAndMetadata(100L)), cb);
+        coordinator.invokeCompletedOffsetCommitCallbacks();
 
         assertTrue(coordinator.coordinatorUnknown());
         assertEquals(1, cb.invoked);
@@ -919,13 +927,14 @@ public class ConsumerCoordinatorTest {
 
     @Test
     public void testCommitOffsetAsyncNotCoordinator() {
-        client.prepareResponse(consumerMetadataResponse(node, Errors.NONE.code()));
+        client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE.code()));
         coordinator.ensureCoordinatorReady();
 
         // async commit with not coordinator
         MockCommitCallback cb = new MockCommitCallback();
         client.prepareResponse(offsetCommitResponse(Collections.singletonMap(tp, Errors.NOT_COORDINATOR_FOR_GROUP.code())));
         coordinator.commitOffsetsAsync(Collections.singletonMap(tp, new OffsetAndMetadata(100L)), cb);
+        coordinator.invokeCompletedOffsetCommitCallbacks();
 
         assertTrue(coordinator.coordinatorUnknown());
         assertEquals(1, cb.invoked);
@@ -934,13 +943,14 @@ public class ConsumerCoordinatorTest {
 
     @Test
     public void testCommitOffsetAsyncDisconnected() {
-        client.prepareResponse(consumerMetadataResponse(node, Errors.NONE.code()));
+        client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE.code()));
         coordinator.ensureCoordinatorReady();
 
         // async commit with coordinator disconnected
         MockCommitCallback cb = new MockCommitCallback();
         client.prepareResponse(offsetCommitResponse(Collections.singletonMap(tp, Errors.NONE.code())), true);
         coordinator.commitOffsetsAsync(Collections.singletonMap(tp, new OffsetAndMetadata(100L)), cb);
+        coordinator.invokeCompletedOffsetCommitCallbacks();
 
         assertTrue(coordinator.coordinatorUnknown());
         assertEquals(1, cb.invoked);
@@ -949,36 +959,36 @@ public class ConsumerCoordinatorTest {
 
     @Test
     public void testCommitOffsetSyncNotCoordinator() {
-        client.prepareResponse(consumerMetadataResponse(node, Errors.NONE.code()));
+        client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE.code()));
         coordinator.ensureCoordinatorReady();
 
         // sync commit with coordinator disconnected (should connect, get metadata, and then submit the commit request)
         client.prepareResponse(offsetCommitResponse(Collections.singletonMap(tp, Errors.NOT_COORDINATOR_FOR_GROUP.code())));
-        client.prepareResponse(consumerMetadataResponse(node, Errors.NONE.code()));
+        client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE.code()));
         client.prepareResponse(offsetCommitResponse(Collections.singletonMap(tp, Errors.NONE.code())));
         coordinator.commitOffsetsSync(Collections.singletonMap(tp, new OffsetAndMetadata(100L)));
     }
 
     @Test
     public void testCommitOffsetSyncCoordinatorNotAvailable() {
-        client.prepareResponse(consumerMetadataResponse(node, Errors.NONE.code()));
+        client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE.code()));
         coordinator.ensureCoordinatorReady();
 
         // sync commit with coordinator disconnected (should connect, get metadata, and then submit the commit request)
         client.prepareResponse(offsetCommitResponse(Collections.singletonMap(tp, Errors.GROUP_COORDINATOR_NOT_AVAILABLE.code())));
-        client.prepareResponse(consumerMetadataResponse(node, Errors.NONE.code()));
+        client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE.code()));
         client.prepareResponse(offsetCommitResponse(Collections.singletonMap(tp, Errors.NONE.code())));
         coordinator.commitOffsetsSync(Collections.singletonMap(tp, new OffsetAndMetadata(100L)));
     }
 
     @Test
     public void testCommitOffsetSyncCoordinatorDisconnected() {
-        client.prepareResponse(consumerMetadataResponse(node, Errors.NONE.code()));
+        client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE.code()));
         coordinator.ensureCoordinatorReady();
 
         // sync commit with coordinator disconnected (should connect, get metadata, and then submit the commit request)
         client.prepareResponse(offsetCommitResponse(Collections.singletonMap(tp, Errors.NONE.code())), true);
-        client.prepareResponse(consumerMetadataResponse(node, Errors.NONE.code()));
+        client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE.code()));
         client.prepareResponse(offsetCommitResponse(Collections.singletonMap(tp, Errors.NONE.code())));
         coordinator.commitOffsetsSync(Collections.singletonMap(tp, new OffsetAndMetadata(100L)));
     }
@@ -986,7 +996,7 @@ public class ConsumerCoordinatorTest {
     @Test(expected = OffsetMetadataTooLarge.class)
     public void testCommitOffsetMetadataTooLarge() {
         // since offset metadata is provided by the user, we have to propagate the exception so they can handle it
-        client.prepareResponse(consumerMetadataResponse(node, Errors.NONE.code()));
+        client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE.code()));
         coordinator.ensureCoordinatorReady();
 
         client.prepareResponse(offsetCommitResponse(Collections.singletonMap(tp, Errors.OFFSET_METADATA_TOO_LARGE.code())));
@@ -996,7 +1006,7 @@ public class ConsumerCoordinatorTest {
     @Test(expected = CommitFailedException.class)
     public void testCommitOffsetIllegalGeneration() {
         // we cannot retry if a rebalance occurs before the commit completed
-        client.prepareResponse(consumerMetadataResponse(node, Errors.NONE.code()));
+        client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE.code()));
         coordinator.ensureCoordinatorReady();
 
         client.prepareResponse(offsetCommitResponse(Collections.singletonMap(tp, Errors.ILLEGAL_GENERATION.code())));
@@ -1006,7 +1016,7 @@ public class ConsumerCoordinatorTest {
     @Test(expected = CommitFailedException.class)
     public void testCommitOffsetUnknownMemberId() {
         // we cannot retry if a rebalance occurs before the commit completed
-        client.prepareResponse(consumerMetadataResponse(node, Errors.NONE.code()));
+        client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE.code()));
         coordinator.ensureCoordinatorReady();
 
         client.prepareResponse(offsetCommitResponse(Collections.singletonMap(tp, Errors.UNKNOWN_MEMBER_ID.code())));
@@ -1016,7 +1026,7 @@ public class ConsumerCoordinatorTest {
     @Test(expected = CommitFailedException.class)
     public void testCommitOffsetRebalanceInProgress() {
         // we cannot retry if a rebalance occurs before the commit completed
-        client.prepareResponse(consumerMetadataResponse(node, Errors.NONE.code()));
+        client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE.code()));
         coordinator.ensureCoordinatorReady();
 
         client.prepareResponse(offsetCommitResponse(Collections.singletonMap(tp, Errors.REBALANCE_IN_PROGRESS.code())));
@@ -1025,7 +1035,7 @@ public class ConsumerCoordinatorTest {
 
     @Test(expected = KafkaException.class)
     public void testCommitOffsetSyncCallbackWithNonRetriableException() {
-        client.prepareResponse(consumerMetadataResponse(node, Errors.NONE.code()));
+        client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE.code()));
         coordinator.ensureCoordinatorReady();
 
         // sync commit with invalid partitions should throw if we have no callback
@@ -1035,7 +1045,7 @@ public class ConsumerCoordinatorTest {
 
     @Test
     public void testRefreshOffset() {
-        client.prepareResponse(consumerMetadataResponse(node, Errors.NONE.code()));
+        client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE.code()));
         coordinator.ensureCoordinatorReady();
 
         subscriptions.assignFromUser(Arrays.asList(tp));
@@ -1048,7 +1058,7 @@ public class ConsumerCoordinatorTest {
 
     @Test
     public void testRefreshOffsetLoadInProgress() {
-        client.prepareResponse(consumerMetadataResponse(node, Errors.NONE.code()));
+        client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE.code()));
         coordinator.ensureCoordinatorReady();
 
         subscriptions.assignFromUser(Arrays.asList(tp));
@@ -1062,13 +1072,13 @@ public class ConsumerCoordinatorTest {
 
     @Test
     public void testRefreshOffsetNotCoordinatorForConsumer() {
-        client.prepareResponse(consumerMetadataResponse(node, Errors.NONE.code()));
+        client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE.code()));
         coordinator.ensureCoordinatorReady();
 
         subscriptions.assignFromUser(Arrays.asList(tp));
         subscriptions.needRefreshCommits();
         client.prepareResponse(offsetFetchResponse(tp, Errors.NOT_COORDINATOR_FOR_GROUP.code(), "", 100L));
-        client.prepareResponse(consumerMetadataResponse(node, Errors.NONE.code()));
+        client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE.code()));
         client.prepareResponse(offsetFetchResponse(tp, Errors.NONE.code(), "", 100L));
         coordinator.refreshCommittedOffsetsIfNeeded();
         assertFalse(subscriptions.refreshCommitsNeeded());
@@ -1077,7 +1087,7 @@ public class ConsumerCoordinatorTest {
 
     @Test
     public void testRefreshOffsetWithNoFetchableOffsets() {
-        client.prepareResponse(consumerMetadataResponse(node, Errors.NONE.code()));
+        client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE.code()));
         coordinator.ensureCoordinatorReady();
 
         subscriptions.assignFromUser(Arrays.asList(tp));
@@ -1122,12 +1132,12 @@ public class ConsumerCoordinatorTest {
         metadata.setTopics(topics);
         subscriptions.needReassignment();
 
-        client.prepareResponse(consumerMetadataResponse(node, Errors.NONE.code()));
+        client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE.code()));
         coordinator.ensureCoordinatorReady();
 
         client.prepareResponse(joinGroupFollowerResponse(1, consumerId, "leader", Errors.NONE.code()));
         client.prepareResponse(syncGroupResponse(Arrays.asList(tp), Errors.NONE.code()));
-        coordinator.ensurePartitionAssignment();
+        coordinator.poll(time.milliseconds());
 
         metadata.update(TestUtils.singletonCluster(topicName, 2), time.milliseconds());
         assertTrue("Topic not found in metadata", metadata.containsTopic(topicName));
@@ -1150,6 +1160,7 @@ public class ConsumerCoordinatorTest {
         return new ConsumerCoordinator(
                 consumerClient,
                 groupId,
+                rebalanceTimeoutMs,
                 sessionTimeoutMs,
                 heartbeatIntervalMs,
                 assignors,
@@ -1166,7 +1177,7 @@ public class ConsumerCoordinatorTest {
                 excludeInternalTopics);
     }
 
-    private Struct consumerMetadataResponse(Node node, short error) {
+    private Struct groupCoordinatorResponse(Node node, short error) {
         GroupCoordinatorResponse response = new GroupCoordinatorResponse(error, node);
         return response.toStruct();
     }

http://git-wip-us.apache.org/repos/asf/kafka/blob/40b1dd3f/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerNetworkClientTest.java
----------------------------------------------------------------------
diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerNetworkClientTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerNetworkClientTest.java
index f0f2a97..8dcbde2 100644
--- a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerNetworkClientTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerNetworkClientTest.java
@@ -15,9 +15,9 @@ package org.apache.kafka.clients.consumer.internals;
 import org.apache.kafka.clients.ClientResponse;
 import org.apache.kafka.clients.Metadata;
 import org.apache.kafka.clients.MockClient;
-import org.apache.kafka.common.errors.WakeupException;
 import org.apache.kafka.common.Cluster;
 import org.apache.kafka.common.Node;
+import org.apache.kafka.common.errors.WakeupException;
 import org.apache.kafka.common.protocol.ApiKeys;
 import org.apache.kafka.common.protocol.Errors;
 import org.apache.kafka.common.protocol.types.Struct;
@@ -76,22 +76,6 @@ public class ConsumerNetworkClientTest {
     }
 
     @Test
-    public void schedule() {
-        TestDelayedTask task = new TestDelayedTask();
-        consumerClient.schedule(task, time.milliseconds());
-        consumerClient.poll(0);
-        assertEquals(1, task.executions);
-
-        consumerClient.schedule(task, time.milliseconds() + 100);
-        consumerClient.poll(0);
-        assertEquals(1, task.executions);
-
-        time.sleep(100);
-        consumerClient.poll(0);
-        assertEquals(2, task.executions);
-    }
-
-    @Test
     public void wakeup() {
         RequestFuture<ClientResponse> future = consumerClient.send(node, ApiKeys.METADATA, heartbeatRequest());
         consumerClient.wakeup();
@@ -175,12 +159,4 @@ public class ConsumerNetworkClientTest {
         return response.toStruct();
     }
 
-    private static class TestDelayedTask implements DelayedTask {
-        int executions = 0;
-        @Override
-        public void run(long now) {
-            executions++;
-        }
-    }
-
 }

http://git-wip-us.apache.org/repos/asf/kafka/blob/40b1dd3f/clients/src/test/java/org/apache/kafka/clients/consumer/internals/DelayedTaskQueueTest.java
----------------------------------------------------------------------
diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/DelayedTaskQueueTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/DelayedTaskQueueTest.java
deleted file mode 100644
index db87b66..0000000
--- a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/DelayedTaskQueueTest.java
+++ /dev/null
@@ -1,89 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE
- * file distributed with this work for additional information regarding copyright ownership. The ASF licenses this file
- * to You under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the
- * License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
- * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations under the License.
- */
-package org.apache.kafka.clients.consumer.internals;
-
-import org.junit.Test;
-
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.Collections;
-
-import static org.junit.Assert.assertEquals;
-
-public class DelayedTaskQueueTest {
-    private DelayedTaskQueue scheduler = new DelayedTaskQueue();
-    private ArrayList<DelayedTask> executed = new ArrayList<DelayedTask>();
-
-    @Test
-    public void testScheduling() {
-        // Empty scheduler
-        assertEquals(Long.MAX_VALUE, scheduler.nextTimeout(0));
-        scheduler.poll(0);
-        assertEquals(Collections.emptyList(), executed);
-
-        TestTask task1 = new TestTask();
-        TestTask task2 = new TestTask();
-        TestTask task3 = new TestTask();
-        scheduler.add(task1, 20);
-        assertEquals(20, scheduler.nextTimeout(0));
-        scheduler.add(task2, 10);
-        assertEquals(10, scheduler.nextTimeout(0));
-        scheduler.add(task3, 30);
-        assertEquals(10, scheduler.nextTimeout(0));
-
-        scheduler.poll(5);
-        assertEquals(Collections.emptyList(), executed);
-        assertEquals(5, scheduler.nextTimeout(5));
-
-        scheduler.poll(10);
-        assertEquals(Arrays.asList(task2), executed);
-        assertEquals(10, scheduler.nextTimeout(10));
-
-        scheduler.poll(20);
-        assertEquals(Arrays.asList(task2, task1), executed);
-        assertEquals(20, scheduler.nextTimeout(10));
-
-        scheduler.poll(30);
-        assertEquals(Arrays.asList(task2, task1, task3), executed);
-        assertEquals(Long.MAX_VALUE, scheduler.nextTimeout(30));
-    }
-
-    @Test
-    public void testRemove() {
-        TestTask task1 = new TestTask();
-        TestTask task2 = new TestTask();
-        TestTask task3 = new TestTask();
-        scheduler.add(task1, 20);
-        scheduler.add(task2, 10);
-        scheduler.add(task3, 30);
-        scheduler.add(task1, 40);
-        assertEquals(10, scheduler.nextTimeout(0));
-
-        scheduler.remove(task2);
-        assertEquals(20, scheduler.nextTimeout(0));
-
-        scheduler.remove(task1);
-        assertEquals(30, scheduler.nextTimeout(0));
-
-        scheduler.remove(task3);
-        assertEquals(Long.MAX_VALUE, scheduler.nextTimeout(0));
-    }
-
-    private class TestTask implements DelayedTask {
-        @Override
-        public void run(long now) {
-            executed.add(this);
-        }
-    }
-
-}

http://git-wip-us.apache.org/repos/asf/kafka/blob/40b1dd3f/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java
----------------------------------------------------------------------
diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java
index ba04cb5..5186618 100644
--- a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java
@@ -438,6 +438,7 @@ public class FetcherTest {
         fetcherNoAutoReset.sendFetches();
         client.prepareResponse(fetchResponse(this.records.buffer(), Errors.OFFSET_OUT_OF_RANGE.code(), 100L, 0));
         consumerClient.poll(0);
+
         assertFalse(subscriptionsNoAutoReset.isOffsetResetNeeded(tp));
         try {
             fetcherNoAutoReset.fetchedRecords();

http://git-wip-us.apache.org/repos/asf/kafka/blob/40b1dd3f/clients/src/test/java/org/apache/kafka/clients/consumer/internals/HeartbeatTest.java
----------------------------------------------------------------------
diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/HeartbeatTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/HeartbeatTest.java
index 75e68cc..0177c79 100644
--- a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/HeartbeatTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/HeartbeatTest.java
@@ -28,8 +28,10 @@ public class HeartbeatTest {
 
     private long timeout = 300L;
     private long interval = 100L;
+    private long maxPollInterval = 900L;
+    private long retryBackoff = 10L;
     private MockTime time = new MockTime();
-    private Heartbeat heartbeat = new Heartbeat(timeout, interval, -1L);
+    private Heartbeat heartbeat = new Heartbeat(timeout, interval, maxPollInterval, retryBackoff);
 
     @Test
     public void testShouldHeartbeat() {
@@ -64,7 +66,7 @@ public class HeartbeatTest {
     public void testResetSession() {
         heartbeat.sentHeartbeat(time.milliseconds());
         time.sleep(305);
-        heartbeat.resetSessionTimeout(time.milliseconds());
+        heartbeat.resetTimeouts(time.milliseconds());
         assertFalse(heartbeat.sessionTimeoutExpired(time.milliseconds()));
     }
 }

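For readers following the Heartbeat change above: a minimal sketch (not part of this patch) of how the new four-argument constructor might be exercised. It assumes the same MockTime/JUnit scaffolding already used in HeartbeatTest, and the shouldHeartbeat call simply mirrors the testShouldHeartbeat case; the timing values are arbitrary.

    // sketch only: session timeout 300ms, heartbeat interval 100ms,
    // max poll interval 900ms, retry backoff 10ms (same values as the test above)
    MockTime time = new MockTime();
    Heartbeat heartbeat = new Heartbeat(300L, 100L, 900L, 10L);

    heartbeat.sentHeartbeat(time.milliseconds());
    time.sleep(100);
    // one heartbeat interval has elapsed, so another heartbeat is due,
    // but the session has not yet expired
    assertTrue(heartbeat.shouldHeartbeat(time.milliseconds()));
    assertFalse(heartbeat.sessionTimeoutExpired(time.milliseconds()));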
http://git-wip-us.apache.org/repos/asf/kafka/blob/40b1dd3f/clients/src/test/java/org/apache/kafka/common/requests/RequestResponseTest.java
----------------------------------------------------------------------
diff --git a/clients/src/test/java/org/apache/kafka/common/requests/RequestResponseTest.java b/clients/src/test/java/org/apache/kafka/common/requests/RequestResponseTest.java
index be7f974..766c745 100644
--- a/clients/src/test/java/org/apache/kafka/common/requests/RequestResponseTest.java
+++ b/clients/src/test/java/org/apache/kafka/common/requests/RequestResponseTest.java
@@ -55,8 +55,9 @@ public class RequestResponseTest {
                 createHeartBeatRequest(),
                 createHeartBeatRequest().getErrorResponse(0, new UnknownServerException()),
                 createHeartBeatResponse(),
-                createJoinGroupRequest(),
-                createJoinGroupRequest().getErrorResponse(0, new UnknownServerException()),
+                createJoinGroupRequest(1),
+                createJoinGroupRequest(0).getErrorResponse(0, new UnknownServerException()),
+                createJoinGroupRequest(1).getErrorResponse(1, new UnknownServerException()),
                 createJoinGroupResponse(),
                 createLeaveGroupRequest(),
                 createLeaveGroupRequest().getErrorResponse(0, new UnknownServerException()),
@@ -118,6 +119,7 @@ public class RequestResponseTest {
         checkSerialization(createOffsetCommitRequest(0).getErrorResponse(0, new UnknownServerException()), 0);
         checkSerialization(createOffsetCommitRequest(1), 1);
         checkSerialization(createOffsetCommitRequest(1).getErrorResponse(1, new UnknownServerException()), 1);
+        checkSerialization(createJoinGroupRequest(0), 0);
         checkSerialization(createUpdateMetadataRequest(0, null), 0);
         checkSerialization(createUpdateMetadataRequest(0, null).getErrorResponse(0, new UnknownServerException()), 0);
         checkSerialization(createUpdateMetadataRequest(1, null), 1);
@@ -236,11 +238,15 @@ public class RequestResponseTest {
         return new HeartbeatResponse(Errors.NONE.code());
     }
 
-    private AbstractRequest createJoinGroupRequest() {
+    private AbstractRequest createJoinGroupRequest(int version) {
         ByteBuffer metadata = ByteBuffer.wrap(new byte[] {});
         List<JoinGroupRequest.ProtocolMetadata> protocols = new ArrayList<>();
         protocols.add(new JoinGroupRequest.ProtocolMetadata("consumer-range", metadata));
-        return new JoinGroupRequest("group1", 30000, "consumer1", "consumer", protocols);
+        if (version == 0) {
+            return new JoinGroupRequest("group1", 30000, "consumer1", "consumer", protocols);
+        } else {
+            return new JoinGroupRequest("group1", 10000, 60000, "consumer1", "consumer", protocols);
+        }
     }
 
     private AbstractRequestResponse createJoinGroupResponse() {

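Annotating the two constructor calls above (the argument names are a reading of this diff, not something it states, and the order is presumed): the v0 request carries only a session timeout, while the v1 request adds the separate rebalance timeout introduced by KIP-62. Reusing the locals from the helper above:

    // v0: the rebalance timeout is implicitly the session timeout
    new JoinGroupRequest("group1", 30000 /* session timeout */, "consumer1", "consumer", protocols);

    // v1: session timeout and rebalance timeout are decoupled (presumed argument order)
    new JoinGroupRequest("group1", 10000 /* session timeout */, 60000 /* rebalance timeout */,
            "consumer1", "consumer", protocols);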
http://git-wip-us.apache.org/repos/asf/kafka/blob/40b1dd3f/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/DistributedConfig.java
----------------------------------------------------------------------
diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/DistributedConfig.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/DistributedConfig.java
index f5aa8ae..6e9d7b4 100644
--- a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/DistributedConfig.java
+++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/DistributedConfig.java
@@ -43,13 +43,31 @@ public class DistributedConfig extends WorkerConfig {
      * <code>session.timeout.ms</code>
      */
     public static final String SESSION_TIMEOUT_MS_CONFIG = "session.timeout.ms";
-    private static final String SESSION_TIMEOUT_MS_DOC = "The timeout used to detect failures when using Kafka's group management facilities.";
+    private static final String SESSION_TIMEOUT_MS_DOC = "The timeout used to detect worker failures." +
+            "The worker sends periodic heartbeats to indicate its liveness to the broker. If no heartbeats are " +
+            "received by the broker before the expiration of this session timeout, then the broker will remove the " +
+            "worker from the group and initiate a rebalance. Note that the value must be in the allowable range as " +
+            "configured in the broker configuration by <code>group.min.session.timeout.ms</code> " +
+            "and <code>group.max.session.timeout.ms</code>.";
 
     /**
      * <code>heartbeat.interval.ms</code>
      */
     public static final String HEARTBEAT_INTERVAL_MS_CONFIG = "heartbeat.interval.ms";
-    private static final String HEARTBEAT_INTERVAL_MS_DOC = "The expected time between heartbeats to the group coordinator when using Kafka's group management facilities. Heartbeats are used to ensure that the worker's session stays active and to facilitate rebalancing when new members join or leave the group. The value must be set lower than <code>session.timeout.ms</code>, but typically should be set no higher than 1/3 of that value. It can be adjusted even lower to control the expected time for normal rebalances.";
+    private static final String HEARTBEAT_INTERVAL_MS_DOC = "The expected time between heartbeats to the group " +
+            "coordinator when using Kafka's group management facilities. Heartbeats are used to ensure that the " +
+            "worker's session stays active and to facilitate rebalancing when new members join or leave the group. " +
+            "The value must be set lower than <code>session.timeout.ms</code>, but typically should be set no higher " +
+            "than 1/3 of that value. It can be adjusted even lower to control the expected time for normal rebalances.";
+
+    /**
+     * <code>rebalance.timeout.ms</code>
+     */
+    public static final String REBALANCE_TIMEOUT_MS_CONFIG = "rebalance.timeout.ms";
+    private static final String REBALANCE_TIMEOUT_MS_DOC = "The maximum allowed time for each worker to join the group " +
+            "once a rebalance has begun. This is basically a limit on the amount of time needed for all tasks to " +
+            "flush any pending data and commit offsets. If the timeout is exceeded, then the worker will be removed " +
+            "from the group, which will cause offset commit failures.";
 
     /**
      * <code>worker.sync.timeout.ms</code>
@@ -90,9 +108,14 @@ public class DistributedConfig extends WorkerConfig {
                 .define(GROUP_ID_CONFIG, ConfigDef.Type.STRING, ConfigDef.Importance.HIGH, GROUP_ID_DOC)
                 .define(SESSION_TIMEOUT_MS_CONFIG,
                         ConfigDef.Type.INT,
-                        30000,
+                        10000,
                         ConfigDef.Importance.HIGH,
                         SESSION_TIMEOUT_MS_DOC)
+                .define(REBALANCE_TIMEOUT_MS_CONFIG,
+                        ConfigDef.Type.INT,
+                        60000,
+                        ConfigDef.Importance.HIGH,
+                        REBALANCE_TIMEOUT_MS_DOC)
                 .define(HEARTBEAT_INTERVAL_MS_CONFIG,
                         ConfigDef.Type.INT,
                         3000,

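For illustration only (not part of the patch), a distributed worker picking up the new rebalance timeout alongside the existing group timeouts might be configured roughly as follows; all other required worker settings (group.id, bootstrap servers, converters, storage topics, ...) are omitted:

    Map<String, String> workerProps = new HashMap<>();
    workerProps.put(DistributedConfig.SESSION_TIMEOUT_MS_CONFIG, "10000");   // matches the new default
    workerProps.put(DistributedConfig.REBALANCE_TIMEOUT_MS_CONFIG, "60000"); // config added in this patch
    workerProps.put(DistributedConfig.HEARTBEAT_INTERVAL_MS_CONFIG, "3000");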
http://git-wip-us.apache.org/repos/asf/kafka/blob/40b1dd3f/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/WorkerCoordinator.java
----------------------------------------------------------------------
diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/WorkerCoordinator.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/WorkerCoordinator.java
index 9c74960..9114555 100644
--- a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/WorkerCoordinator.java
+++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/WorkerCoordinator.java
@@ -21,6 +21,7 @@ import org.apache.kafka.clients.consumer.internals.ConsumerNetworkClient;
 import org.apache.kafka.common.metrics.Measurable;
 import org.apache.kafka.common.metrics.MetricConfig;
 import org.apache.kafka.common.metrics.Metrics;
+import org.apache.kafka.common.requests.JoinGroupRequest;
 import org.apache.kafka.common.requests.JoinGroupRequest.ProtocolMetadata;
 import org.apache.kafka.common.utils.CircularIterator;
 import org.apache.kafka.common.utils.Time;
@@ -63,6 +64,7 @@ public final class WorkerCoordinator extends AbstractCoordinator implements Clos
      */
     public WorkerCoordinator(ConsumerNetworkClient client,
                              String groupId,
+                             int rebalanceTimeoutMs,
                              int sessionTimeoutMs,
                              int heartbeatIntervalMs,
                              Metrics metrics,
@@ -74,6 +76,7 @@ public final class WorkerCoordinator extends AbstractCoordinator implements Clos
                              WorkerRebalanceListener listener) {
         super(client,
                 groupId,
+                rebalanceTimeoutMs,
                 sessionTimeoutMs,
                 heartbeatIntervalMs,
                 metrics,
@@ -97,6 +100,32 @@ public final class WorkerCoordinator extends AbstractCoordinator implements Clos
         return "connect";
     }
 
+    public void poll(long timeout) {
+        // poll for io until the timeout expires
+        long now = time.milliseconds();
+        long deadline = now + timeout;
+
+        while (now <= deadline) {
+            if (coordinatorUnknown()) {
+                ensureCoordinatorReady();
+                now = time.milliseconds();
+            }
+
+            if (needRejoin()) {
+                ensureActiveGroup();
+                now = time.milliseconds();
+            }
+
+            pollHeartbeat(now);
+
+            // Note that because the network client is shared with the background heartbeat thread,
+            // we do not want to block in poll longer than the time to the next heartbeat.
+            long remaining = Math.max(0, deadline - now);
+            client.poll(Math.min(remaining, timeToNextHeartbeat(now)));
+            now = time.milliseconds();
+        }
+    }
+
     @Override
     public List<ProtocolMetadata> metadata() {
         configSnapshot = configStorage.snapshot();
@@ -238,12 +267,15 @@ public final class WorkerCoordinator extends AbstractCoordinator implements Clos
     }
 
     @Override
-    public boolean needRejoin() {
+    protected boolean needRejoin() {
         return super.needRejoin() || (assignmentSnapshot == null || assignmentSnapshot.failed()) || rejoinRequested;
     }
 
     public String memberId() {
-        return this.memberId;
+        Generation generation = generation();
+        if (generation != null)
+            return generation.memberId;
+        return JoinGroupRequest.UNKNOWN_MEMBER_ID;
     }
 
     @Override
@@ -252,7 +284,7 @@ public final class WorkerCoordinator extends AbstractCoordinator implements Clos
     }
 
     private boolean isLeader() {
-        return assignmentSnapshot != null && memberId.equals(assignmentSnapshot.leader());
+        return assignmentSnapshot != null && memberId().equals(assignmentSnapshot.leader());
     }
 
     public String ownerUrl(String connector) {

http://git-wip-us.apache.org/repos/asf/kafka/blob/40b1dd3f/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/WorkerGroupMember.java
----------------------------------------------------------------------
diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/WorkerGroupMember.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/WorkerGroupMember.java
index c21b9bf..a5213db 100644
--- a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/WorkerGroupMember.java
+++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/WorkerGroupMember.java
@@ -104,6 +104,7 @@ public class WorkerGroupMember {
                     config.getInt(CommonClientConfigs.REQUEST_TIMEOUT_MS_CONFIG));
             this.coordinator = new WorkerCoordinator(this.client,
                     config.getString(DistributedConfig.GROUP_ID_CONFIG),
+                    config.getInt(DistributedConfig.REBALANCE_TIMEOUT_MS_CONFIG),
                     config.getInt(DistributedConfig.SESSION_TIMEOUT_MS_CONFIG),
                     config.getInt(DistributedConfig.HEARTBEAT_INTERVAL_MS_CONFIG),
                     metrics,
@@ -131,23 +132,13 @@ public class WorkerGroupMember {
     }
 
     public void ensureActive() {
-        coordinator.ensureCoordinatorReady();
-        coordinator.ensureActiveGroup();
+        coordinator.poll(0);
     }
 
     public void poll(long timeout) {
         if (timeout < 0)
             throw new IllegalArgumentException("Timeout must not be negative");
-
-        // poll for new data until the timeout expires
-        long remaining = timeout;
-        while (remaining >= 0) {
-            long start = time.milliseconds();
-            coordinator.ensureCoordinatorReady();
-            coordinator.ensureActiveGroup();
-            client.poll(remaining);
-            remaining -= time.milliseconds() - start;
-        }
+        coordinator.poll(timeout);
     }
 
     /**

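With the join/heartbeat logic consolidated in WorkerCoordinator.poll(), callers of WorkerGroupMember are expected to drive it from a simple loop. A rough sketch of that calling pattern (names such as member, stopping and pollTimeoutMs are placeholders, not from this patch):

    while (!stopping) {
        // ensure the coordinator is known and the group is joined; this also
        // keeps heartbeat state up to date via WorkerCoordinator.poll(0)
        member.ensureActive();
        // ... react to any rebalance callbacks triggered above ...
        // then block for I/O, again delegating to WorkerCoordinator.poll(timeout)
        member.poll(pollTimeoutMs);
    }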
http://git-wip-us.apache.org/repos/asf/kafka/blob/40b1dd3f/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/distributed/WorkerCoordinatorTest.java
----------------------------------------------------------------------
diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/distributed/WorkerCoordinatorTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/distributed/WorkerCoordinatorTest.java
index 4c2ac40..3bfa83f 100644
--- a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/distributed/WorkerCoordinatorTest.java
+++ b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/distributed/WorkerCoordinatorTest.java
@@ -67,6 +67,7 @@ public class WorkerCoordinatorTest {
 
     private String groupId = "test-group";
     private int sessionTimeoutMs = 10;
+    private int rebalanceTimeoutMs = 60;
     private int heartbeatIntervalMs = 2;
     private long retryBackoffMs = 100;
     private MockTime time;
@@ -98,6 +99,7 @@ public class WorkerCoordinatorTest {
 
         this.coordinator = new WorkerCoordinator(consumerClient,
                 groupId,
+                rebalanceTimeoutMs,
                 sessionTimeoutMs,
                 heartbeatIntervalMs,
                 metrics,

http://git-wip-us.apache.org/repos/asf/kafka/blob/40b1dd3f/core/src/main/scala/kafka/api/ApiVersion.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/kafka/api/ApiVersion.scala b/core/src/main/scala/kafka/api/ApiVersion.scala
index 666d0e7..d955225 100644
--- a/core/src/main/scala/kafka/api/ApiVersion.scala
+++ b/core/src/main/scala/kafka/api/ApiVersion.scala
@@ -51,7 +51,10 @@ object ApiVersion {
     "0.10.0-IV0" -> KAFKA_0_10_0_IV0,
     // 0.10.0-IV1 is introduced for KIP-36(rack awareness) and KIP-43(SASL handshake).
     "0.10.0-IV1" -> KAFKA_0_10_0_IV1,
-    "0.10.0" -> KAFKA_0_10_0_IV1
+    "0.10.0" -> KAFKA_0_10_0_IV1,
+
+    // introduced for JoinGroup protocol change in KIP-62
+    "0.10.1-IV0" -> KAFKA_0_10_1_IV0
   )
 
   private val versionPattern = "\\.".r
@@ -111,3 +114,9 @@ case object KAFKA_0_10_0_IV1 extends ApiVersion {
   val messageFormatVersion: Byte = Message.MagicValue_V1
   val id: Int = 5
 }
+
+case object KAFKA_0_10_1_IV0 extends ApiVersion {
+  val version: String = "0.10.1-IV0"
+  val messageFormatVersion: Byte = Message.MagicValue_V1
+  val id: Int = 6
+}

http://git-wip-us.apache.org/repos/asf/kafka/blob/40b1dd3f/core/src/main/scala/kafka/coordinator/GroupCoordinator.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/kafka/coordinator/GroupCoordinator.scala b/core/src/main/scala/kafka/coordinator/GroupCoordinator.scala
index 0d02a4c..726426a 100644
--- a/core/src/main/scala/kafka/coordinator/GroupCoordinator.scala
+++ b/core/src/main/scala/kafka/coordinator/GroupCoordinator.scala
@@ -93,6 +93,7 @@ class GroupCoordinator(val brokerId: Int,
                       memberId: String,
                       clientId: String,
                       clientHost: String,
+                      rebalanceTimeoutMs: Int,
                       sessionTimeoutMs: Int,
                       protocolType: String,
                       protocols: List[(String, Array[Byte])],
@@ -118,11 +119,11 @@ class GroupCoordinator(val brokerId: Int,
             responseCallback(joinError(memberId, Errors.UNKNOWN_MEMBER_ID.code))
           } else {
             val group = groupManager.addGroup(new GroupMetadata(groupId))
-            doJoinGroup(group, memberId, clientId, clientHost, sessionTimeoutMs, protocolType, protocols, responseCallback)
+            doJoinGroup(group, memberId, clientId, clientHost, rebalanceTimeoutMs, sessionTimeoutMs, protocolType, protocols, responseCallback)
           }
 
         case Some(group) =>
-          doJoinGroup(group, memberId, clientId, clientHost, sessionTimeoutMs, protocolType, protocols, responseCallback)
+          doJoinGroup(group, memberId, clientId, clientHost, rebalanceTimeoutMs, sessionTimeoutMs, protocolType, protocols, responseCallback)
       }
     }
   }
@@ -131,6 +132,7 @@ class GroupCoordinator(val brokerId: Int,
                           memberId: String,
                           clientId: String,
                           clientHost: String,
+                          rebalanceTimeoutMs: Int,
                           sessionTimeoutMs: Int,
                           protocolType: String,
                           protocols: List[(String, Array[Byte])],
@@ -154,7 +156,7 @@ class GroupCoordinator(val brokerId: Int,
 
           case PreparingRebalance =>
             if (memberId == JoinGroupRequest.UNKNOWN_MEMBER_ID) {
-              addMemberAndRebalance(sessionTimeoutMs, clientId, clientHost, protocolType, protocols, group, responseCallback)
+              addMemberAndRebalance(rebalanceTimeoutMs, sessionTimeoutMs, clientId, clientHost, protocolType, protocols, group, responseCallback)
             } else {
               val member = group.get(memberId)
               updateMemberAndRebalance(group, member, protocols, responseCallback)
@@ -162,7 +164,7 @@ class GroupCoordinator(val brokerId: Int,
 
           case AwaitingSync =>
             if (memberId == JoinGroupRequest.UNKNOWN_MEMBER_ID) {
-              addMemberAndRebalance(sessionTimeoutMs, clientId, clientHost, protocolType, protocols, group, responseCallback)
+              addMemberAndRebalance(rebalanceTimeoutMs, sessionTimeoutMs, clientId, clientHost, protocolType, protocols, group, responseCallback)
             } else {
               val member = group.get(memberId)
               if (member.matches(protocols)) {
@@ -189,7 +191,7 @@ class GroupCoordinator(val brokerId: Int,
           case Empty | Stable =>
             if (memberId == JoinGroupRequest.UNKNOWN_MEMBER_ID) {
               // if the member id is unknown, register the member to the group
-              addMemberAndRebalance(sessionTimeoutMs, clientId, clientHost, protocolType, protocols, group, responseCallback)
+              addMemberAndRebalance(rebalanceTimeoutMs, sessionTimeoutMs, clientId, clientHost, protocolType, protocols, group, responseCallback)
             } else {
               val member = group.get(memberId)
               if (memberId == group.leaderId || !member.matches(protocols)) {
@@ -256,7 +258,6 @@ class GroupCoordinator(val brokerId: Int,
 
           case AwaitingSync =>
             group.get(memberId).awaitingSyncCallback = responseCallback
-            completeAndScheduleNextHeartbeatExpiration(group, group.get(memberId))
 
             // if this is the leader, then we can attempt to persist state and transition to stable
             if (memberId == group.leaderId) {
@@ -299,7 +300,7 @@ class GroupCoordinator(val brokerId: Int,
     delayedGroupStore.foreach(groupManager.store)
   }
 
-  def handleLeaveGroup(groupId: String, consumerId: String, responseCallback: Short => Unit) {
+  def handleLeaveGroup(groupId: String, memberId: String, responseCallback: Short => Unit) {
     if (!isActive.get) {
       responseCallback(Errors.GROUP_COORDINATOR_NOT_AVAILABLE.code)
     } else if (!isCoordinatorForGroup(groupId)) {
@@ -317,10 +318,10 @@ class GroupCoordinator(val brokerId: Int,
 
         case Some(group) =>
           group synchronized {
-            if (group.is(Dead) || !group.has(consumerId)) {
+            if (group.is(Dead) || !group.has(memberId)) {
               responseCallback(Errors.UNKNOWN_MEMBER_ID.code)
             } else {
-              val member = group.get(consumerId)
+              val member = group.get(memberId)
               removeHeartbeatForLeavingMember(group, member)
               onMemberFailure(group, member)
               responseCallback(Errors.NONE.code)
@@ -343,27 +344,49 @@ class GroupCoordinator(val brokerId: Int,
       responseCallback(Errors.NONE.code)
     } else {
       groupManager.getGroup(groupId) match {
-        case None => responseCallback(Errors.UNKNOWN_MEMBER_ID.code)
+        case None =>
+          responseCallback(Errors.UNKNOWN_MEMBER_ID.code)
+
         case Some(group) =>
           group synchronized {
-            if (group.is(Empty)) {
-              responseCallback(Errors.UNKNOWN_MEMBER_ID.code)
-            } else if (group.is(Dead)) {
-              // if the group is marked as dead, it means some other thread has just removed the group
-              // from the coordinator metadata; this is likely that the group has migrated to some other
-              // coordinator OR the group is in a transient unstable phase. Let the member retry
-              // joining without the specified member id,
-              responseCallback(Errors.UNKNOWN_MEMBER_ID.code)
-            } else if (!group.is(Stable)) {
-              responseCallback(Errors.REBALANCE_IN_PROGRESS.code)
-            } else if (!group.has(memberId)) {
-              responseCallback(Errors.UNKNOWN_MEMBER_ID.code)
-            } else if (generationId != group.generationId) {
-              responseCallback(Errors.ILLEGAL_GENERATION.code)
-            } else {
-              val member = group.get(memberId)
-              completeAndScheduleNextHeartbeatExpiration(group, member)
-              responseCallback(Errors.NONE.code)
+            group.currentState match {
+              case Dead =>
+                // if the group is marked as dead, it means some other thread has just removed the group
+                // from the coordinator metadata; it is likely that the group has migrated to some other
+                // coordinator OR the group is in a transient unstable phase. Let the member retry
+                // joining without the specified member id.
+                responseCallback(Errors.UNKNOWN_MEMBER_ID.code)
+
+              case Empty =>
+                responseCallback(Errors.UNKNOWN_MEMBER_ID.code)
+
+              case AwaitingSync =>
+                if (!group.has(memberId))
+                  responseCallback(Errors.UNKNOWN_MEMBER_ID.code)
+                else
+                  responseCallback(Errors.REBALANCE_IN_PROGRESS.code)
+
+              case PreparingRebalance =>
+                if (!group.has(memberId)) {
+                  responseCallback(Errors.UNKNOWN_MEMBER_ID.code)
+                } else if (generationId != group.generationId) {
+                  responseCallback(Errors.ILLEGAL_GENERATION.code)
+                } else {
+                  val member = group.get(memberId)
+                  completeAndScheduleNextHeartbeatExpiration(group, member)
+                  responseCallback(Errors.REBALANCE_IN_PROGRESS.code)
+                }
+
+              case Stable =>
+                if (!group.has(memberId)) {
+                  responseCallback(Errors.UNKNOWN_MEMBER_ID.code)
+                } else if (generationId != group.generationId) {
+                  responseCallback(Errors.ILLEGAL_GENERATION.code)
+                } else {
+                  val member = group.get(memberId)
+                  completeAndScheduleNextHeartbeatExpiration(group, member)
+                  responseCallback(Errors.NONE.code)
+                }
             }
           }
       }
@@ -585,7 +608,8 @@ class GroupCoordinator(val brokerId: Int,
     heartbeatPurgatory.checkAndComplete(memberKey)
   }
 
-  private def addMemberAndRebalance(sessionTimeoutMs: Int,
+  private def addMemberAndRebalance(rebalanceTimeoutMs: Int,
+                                    sessionTimeoutMs: Int,
                                     clientId: String,
                                     clientHost: String,
                                     protocolType: String,
@@ -594,7 +618,8 @@ class GroupCoordinator(val brokerId: Int,
                                     callback: JoinCallback) = {
     // use the client-id with a random id suffix as the member-id
     val memberId = clientId + "-" + group.generateMemberIdSuffix
-    val member = new MemberMetadata(memberId, group.groupId, clientId, clientHost, sessionTimeoutMs, protocolType, protocols)
+    val member = new MemberMetadata(memberId, group.groupId, clientId, clientHost, rebalanceTimeoutMs,
+      sessionTimeoutMs, protocolType, protocols)
     member.awaitingJoinCallback = callback
     group.add(member.memberId, member)
     maybePrepareRebalance(group)
@@ -625,7 +650,7 @@ class GroupCoordinator(val brokerId: Int,
     group.transitionTo(PreparingRebalance)
     info("Preparing to restabilize group %s with old generation %s".format(group.groupId, group.generationId))
 
-    val rebalanceTimeout = group.rebalanceTimeout
+    val rebalanceTimeout = group.rebalanceTimeoutMs
     val delayedRebalance = new DelayedJoin(this, group, rebalanceTimeout)
     val groupKey = GroupKey(group.groupId)
     joinPurgatory.tryCompleteElseWatch(delayedRebalance, Seq(groupKey))
@@ -770,7 +795,8 @@ object GroupCoordinator {
     val groupConfig = GroupConfig(groupMinSessionTimeoutMs = config.groupMinSessionTimeoutMs,
       groupMaxSessionTimeoutMs = config.groupMaxSessionTimeoutMs)
 
-    val groupMetadataManager = new GroupMetadataManager(config.brokerId, offsetConfig, replicaManager, zkUtils, time)
+    val groupMetadataManager = new GroupMetadataManager(config.brokerId, config.interBrokerProtocolVersion,
+      offsetConfig, replicaManager, zkUtils, time)
     new GroupCoordinator(config.brokerId, groupConfig, offsetConfig, groupMetadataManager, heartbeatPurgatory, joinPurgatory, time)
   }
 

http://git-wip-us.apache.org/repos/asf/kafka/blob/40b1dd3f/core/src/main/scala/kafka/coordinator/GroupMetadata.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/kafka/coordinator/GroupMetadata.scala b/core/src/main/scala/kafka/coordinator/GroupMetadata.scala
index b455964..c86c7f8 100644
--- a/core/src/main/scala/kafka/coordinator/GroupMetadata.scala
+++ b/core/src/main/scala/kafka/coordinator/GroupMetadata.scala
@@ -190,8 +190,8 @@ private[coordinator] class GroupMetadata(val groupId: String, initialState: Grou
 
   def allMemberMetadata = members.values.toList
 
-  def rebalanceTimeout = members.values.foldLeft(0) {(timeout, member) =>
-    timeout.max(member.sessionTimeoutMs)
+  def rebalanceTimeoutMs = members.values.foldLeft(0) { (timeout, member) =>
+    timeout.max(member.rebalanceTimeoutMs)
   }
 
   // TODO: decide if ids should be predictable or random

http://git-wip-us.apache.org/repos/asf/kafka/blob/40b1dd3f/core/src/main/scala/kafka/coordinator/GroupMetadataManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/kafka/coordinator/GroupMetadataManager.scala b/core/src/main/scala/kafka/coordinator/GroupMetadataManager.scala
index ef8b295..cf8ae91 100644
--- a/core/src/main/scala/kafka/coordinator/GroupMetadataManager.scala
+++ b/core/src/main/scala/kafka/coordinator/GroupMetadataManager.scala
@@ -47,10 +47,12 @@ import java.util.concurrent.TimeUnit
 import java.util.concurrent.locks.ReentrantLock
 
 import com.yammer.metrics.core.Gauge
+import kafka.api.{ApiVersion, KAFKA_0_10_1_IV0}
 import kafka.utils.CoreUtils.inLock
 
 
 class GroupMetadataManager(val brokerId: Int,
+                           val interBrokerProtocolVersion: ApiVersion,
                            val config: OffsetConfig,
                            replicaManager: ReplicaManager,
                            zkUtils: ZkUtils,
@@ -175,9 +177,11 @@ class GroupMetadataManager(val brokerId: Int,
                         groupAssignment: Map[String, Array[Byte]],
                         responseCallback: Short => Unit): DelayedStore = {
     val (magicValue, timestamp) = getMessageFormatVersionAndTimestamp(partitionFor(group.groupId))
+    val groupMetadataValueVersion = if (interBrokerProtocolVersion < KAFKA_0_10_1_IV0) 0.toShort else GroupMetadataManager.CURRENT_GROUP_VALUE_SCHEMA_VERSION
+
     val message = new Message(
       key = GroupMetadataManager.groupMetadataKey(group.groupId),
-      bytes = GroupMetadataManager.groupMetadataValue(group, groupAssignment),
+      bytes = GroupMetadataManager.groupMetadataValue(group, groupAssignment, version = groupMetadataValueVersion),
       timestamp = timestamp,
       magicValue = magicValue)
 
@@ -704,30 +708,51 @@ object GroupMetadataManager {
   private val GROUP_METADATA_KEY_SCHEMA = new Schema(new Field("group", STRING))
   private val GROUP_KEY_GROUP_FIELD = GROUP_METADATA_KEY_SCHEMA.get("group")
 
-  private val MEMBER_METADATA_V0 = new Schema(new Field("member_id", STRING),
-    new Field("client_id", STRING),
-    new Field("client_host", STRING),
-    new Field("session_timeout", INT32),
-    new Field("subscription", BYTES),
-    new Field("assignment", BYTES))
-  private val MEMBER_METADATA_MEMBER_ID_V0 = MEMBER_METADATA_V0.get("member_id")
-  private val MEMBER_METADATA_CLIENT_ID_V0 = MEMBER_METADATA_V0.get("client_id")
-  private val MEMBER_METADATA_CLIENT_HOST_V0 = MEMBER_METADATA_V0.get("client_host")
-  private val MEMBER_METADATA_SESSION_TIMEOUT_V0 = MEMBER_METADATA_V0.get("session_timeout")
-  private val MEMBER_METADATA_SUBSCRIPTION_V0 = MEMBER_METADATA_V0.get("subscription")
-  private val MEMBER_METADATA_ASSIGNMENT_V0 = MEMBER_METADATA_V0.get("assignment")
-
-
-  private val GROUP_METADATA_VALUE_SCHEMA_V0 = new Schema(new Field("protocol_type", STRING),
-    new Field("generation", INT32),
-    new Field("protocol", NULLABLE_STRING),
-    new Field("leader", NULLABLE_STRING),
-    new Field("members", new ArrayOf(MEMBER_METADATA_V0)))
-  private val GROUP_METADATA_PROTOCOL_TYPE_V0 = GROUP_METADATA_VALUE_SCHEMA_V0.get("protocol_type")
-  private val GROUP_METADATA_GENERATION_V0 = GROUP_METADATA_VALUE_SCHEMA_V0.get("generation")
-  private val GROUP_METADATA_PROTOCOL_V0 = GROUP_METADATA_VALUE_SCHEMA_V0.get("protocol")
-  private val GROUP_METADATA_LEADER_V0 = GROUP_METADATA_VALUE_SCHEMA_V0.get("leader")
-  private val GROUP_METADATA_MEMBERS_V0 = GROUP_METADATA_VALUE_SCHEMA_V0.get("members")
+  private val MEMBER_ID_KEY = "member_id"
+  private val CLIENT_ID_KEY = "client_id"
+  private val CLIENT_HOST_KEY = "client_host"
+  private val REBALANCE_TIMEOUT_KEY = "rebalance_timeout"
+  private val SESSION_TIMEOUT_KEY = "session_timeout"
+  private val SUBSCRIPTION_KEY = "subscription"
+  private val ASSIGNMENT_KEY = "assignment"
+
+  private val MEMBER_METADATA_V0 = new Schema(
+    new Field(MEMBER_ID_KEY, STRING),
+    new Field(CLIENT_ID_KEY, STRING),
+    new Field(CLIENT_HOST_KEY, STRING),
+    new Field(SESSION_TIMEOUT_KEY, INT32),
+    new Field(SUBSCRIPTION_KEY, BYTES),
+    new Field(ASSIGNMENT_KEY, BYTES))
+
+  private val MEMBER_METADATA_V1 = new Schema(
+    new Field(MEMBER_ID_KEY, STRING),
+    new Field(CLIENT_ID_KEY, STRING),
+    new Field(CLIENT_HOST_KEY, STRING),
+    new Field(REBALANCE_TIMEOUT_KEY, INT32),
+    new Field(SESSION_TIMEOUT_KEY, INT32),
+    new Field(SUBSCRIPTION_KEY, BYTES),
+    new Field(ASSIGNMENT_KEY, BYTES))
+
+  private val PROTOCOL_TYPE_KEY = "protocol_type"
+  private val GENERATION_KEY = "generation"
+  private val PROTOCOL_KEY = "protocol"
+  private val LEADER_KEY = "leader"
+  private val MEMBERS_KEY = "members"
+
+  private val GROUP_METADATA_VALUE_SCHEMA_V0 = new Schema(
+    new Field(PROTOCOL_TYPE_KEY, STRING),
+    new Field(GENERATION_KEY, INT32),
+    new Field(PROTOCOL_KEY, NULLABLE_STRING),
+    new Field(LEADER_KEY, NULLABLE_STRING),
+    new Field(MEMBERS_KEY, new ArrayOf(MEMBER_METADATA_V0)))
+
+  private val GROUP_METADATA_VALUE_SCHEMA_V1 = new Schema(
+    new Field(PROTOCOL_TYPE_KEY, STRING),
+    new Field(GENERATION_KEY, INT32),
+    new Field(PROTOCOL_KEY, NULLABLE_STRING),
+    new Field(LEADER_KEY, NULLABLE_STRING),
+    new Field(MEMBERS_KEY, new ArrayOf(MEMBER_METADATA_V1)))
+
 
   // map of versions to key schemas as data types
   private val MESSAGE_TYPE_SCHEMAS = Map(
@@ -742,8 +767,10 @@ object GroupMetadataManager {
   private val CURRENT_OFFSET_VALUE_SCHEMA_VERSION = 1.toShort
 
   // map of version of group metadata value schemas
-  private val GROUP_VALUE_SCHEMAS = Map(0 -> GROUP_METADATA_VALUE_SCHEMA_V0)
-  private val CURRENT_GROUP_VALUE_SCHEMA_VERSION = 0.toShort
+  private val GROUP_VALUE_SCHEMAS = Map(
+    0 -> GROUP_METADATA_VALUE_SCHEMA_V0,
+    1 -> GROUP_METADATA_VALUE_SCHEMA_V1)
+  private val CURRENT_GROUP_VALUE_SCHEMA_VERSION = 1.toShort
 
   private val CURRENT_OFFSET_KEY_SCHEMA = schemaForKey(CURRENT_OFFSET_KEY_SCHEMA_VERSION)
   private val CURRENT_GROUP_KEY_SCHEMA = schemaForKey(CURRENT_GROUP_KEY_SCHEMA_VERSION)
@@ -830,40 +857,47 @@ object GroupMetadataManager {
    * Generates the payload for group metadata message from given offset and metadata
    * assuming the generation id, selected protocol, leader and member assignment are all available
    *
-   * @param groupMetadata
+   * @param groupMetadata current group metadata
+   * @param assignment the assignment for the rebalancing generation
+   * @param version the version of the value message to use
   * @return payload for group metadata message
    */
-  def groupMetadataValue(groupMetadata: GroupMetadata, assignment: Map[String, Array[Byte]]): Array[Byte] = {
-    // generate commit value with schema version 1
-    val value = new Struct(CURRENT_GROUP_VALUE_SCHEMA)
-    value.set(GROUP_METADATA_PROTOCOL_TYPE_V0, groupMetadata.protocolType.getOrElse(""))
-    value.set(GROUP_METADATA_GENERATION_V0, groupMetadata.generationId)
-    value.set(GROUP_METADATA_PROTOCOL_V0, groupMetadata.protocol)
-    value.set(GROUP_METADATA_LEADER_V0, groupMetadata.leaderId)
+  def groupMetadataValue(groupMetadata: GroupMetadata,
+                         assignment: Map[String, Array[Byte]],
+                         version: Short = 0): Array[Byte] = {
+    val value = if (version == 0) new Struct(GROUP_METADATA_VALUE_SCHEMA_V0) else new Struct(CURRENT_GROUP_VALUE_SCHEMA)
+
+    value.set(PROTOCOL_TYPE_KEY, groupMetadata.protocolType.getOrElse(""))
+    value.set(GENERATION_KEY, groupMetadata.generationId)
+    value.set(PROTOCOL_KEY, groupMetadata.protocol)
+    value.set(LEADER_KEY, groupMetadata.leaderId)
 
     val memberArray = groupMetadata.allMemberMetadata.map {
       case memberMetadata =>
-        val memberStruct = value.instance(GROUP_METADATA_MEMBERS_V0)
-        memberStruct.set(MEMBER_METADATA_MEMBER_ID_V0, memberMetadata.memberId)
-        memberStruct.set(MEMBER_METADATA_CLIENT_ID_V0, memberMetadata.clientId)
-        memberStruct.set(MEMBER_METADATA_CLIENT_HOST_V0, memberMetadata.clientHost)
-        memberStruct.set(MEMBER_METADATA_SESSION_TIMEOUT_V0, memberMetadata.sessionTimeoutMs)
+        val memberStruct = value.instance(MEMBERS_KEY)
+        memberStruct.set(MEMBER_ID_KEY, memberMetadata.memberId)
+        memberStruct.set(CLIENT_ID_KEY, memberMetadata.clientId)
+        memberStruct.set(CLIENT_HOST_KEY, memberMetadata.clientHost)
+        memberStruct.set(SESSION_TIMEOUT_KEY, memberMetadata.sessionTimeoutMs)
+
+        if (version > 0)
+          memberStruct.set(REBALANCE_TIMEOUT_KEY, memberMetadata.rebalanceTimeoutMs)
 
         val metadata = memberMetadata.metadata(groupMetadata.protocol)
-        memberStruct.set(MEMBER_METADATA_SUBSCRIPTION_V0, ByteBuffer.wrap(metadata))
+        memberStruct.set(SUBSCRIPTION_KEY, ByteBuffer.wrap(metadata))
 
         val memberAssignment = assignment(memberMetadata.memberId)
         assert(memberAssignment != null)
 
-        memberStruct.set(MEMBER_METADATA_ASSIGNMENT_V0, ByteBuffer.wrap(memberAssignment))
+        memberStruct.set(ASSIGNMENT_KEY, ByteBuffer.wrap(memberAssignment))
 
         memberStruct
     }
 
-    value.set(GROUP_METADATA_MEMBERS_V0, memberArray.toArray)
+    value.set(MEMBERS_KEY, memberArray.toArray)
 
     val byteBuffer = ByteBuffer.allocate(2 /* version */ + value.sizeOf)
-    byteBuffer.putShort(CURRENT_GROUP_VALUE_SCHEMA_VERSION)
+    byteBuffer.putShort(version)
     value.writeTo(byteBuffer)
     byteBuffer.array()
   }
@@ -944,31 +978,33 @@ object GroupMetadataManager {
       val valueSchema = schemaForGroup(version)
       val value = valueSchema.read(buffer)
 
-      if (version == 0) {
-        val protocolType = value.get(GROUP_METADATA_PROTOCOL_TYPE_V0).asInstanceOf[String]
+      if (version == 0 || version == 1) {
+        val protocolType = value.get(PROTOCOL_TYPE_KEY).asInstanceOf[String]
 
-        val memberMetadataArray = value.getArray(GROUP_METADATA_MEMBERS_V0)
+        val memberMetadataArray = value.getArray(MEMBERS_KEY)
         val initialState = if (memberMetadataArray.isEmpty) Empty else Stable
 
         val group = new GroupMetadata(groupId, initialState)
 
-        group.generationId = value.get(GROUP_METADATA_GENERATION_V0).asInstanceOf[Int]
-        group.leaderId = value.get(GROUP_METADATA_LEADER_V0).asInstanceOf[String]
-        group.protocol = value.get(GROUP_METADATA_PROTOCOL_V0).asInstanceOf[String]
+        group.generationId = value.get(GENERATION_KEY).asInstanceOf[Int]
+        group.leaderId = value.get(LEADER_KEY).asInstanceOf[String]
+        group.protocol = value.get(PROTOCOL_KEY).asInstanceOf[String]
 
         memberMetadataArray.foreach {
           case memberMetadataObj =>
             val memberMetadata = memberMetadataObj.asInstanceOf[Struct]
-            val memberId = memberMetadata.get(MEMBER_METADATA_MEMBER_ID_V0).asInstanceOf[String]
-            val clientId = memberMetadata.get(MEMBER_METADATA_CLIENT_ID_V0).asInstanceOf[String]
-            val clientHost = memberMetadata.get(MEMBER_METADATA_CLIENT_HOST_V0).asInstanceOf[String]
-            val sessionTimeout = memberMetadata.get(MEMBER_METADATA_SESSION_TIMEOUT_V0).asInstanceOf[Int]
-            val subscription = Utils.toArray(memberMetadata.get(MEMBER_METADATA_SUBSCRIPTION_V0).asInstanceOf[ByteBuffer])
+            val memberId = memberMetadata.get(MEMBER_ID_KEY).asInstanceOf[String]
+            val clientId = memberMetadata.get(CLIENT_ID_KEY).asInstanceOf[String]
+            val clientHost = memberMetadata.get(CLIENT_HOST_KEY).asInstanceOf[String]
+            val sessionTimeout = memberMetadata.get(SESSION_TIMEOUT_KEY).asInstanceOf[Int]
+            val rebalanceTimeout = if (version == 0) sessionTimeout else memberMetadata.get(REBALANCE_TIMEOUT_KEY).asInstanceOf[Int]
+
+            val subscription = Utils.toArray(memberMetadata.get(SUBSCRIPTION_KEY).asInstanceOf[ByteBuffer])
 
-            val member = new MemberMetadata(memberId, groupId, clientId, clientHost, sessionTimeout,
+            val member = new MemberMetadata(memberId, groupId, clientId, clientHost, rebalanceTimeout, sessionTimeout,
               protocolType, List((group.protocol, subscription)))
 
-            member.assignment = Utils.toArray(memberMetadata.get(MEMBER_METADATA_ASSIGNMENT_V0).asInstanceOf[ByteBuffer])
+            member.assignment = Utils.toArray(memberMetadata.get(ASSIGNMENT_KEY).asInstanceOf[ByteBuffer])
 
             group.add(memberId, member)
         }


[3/4] kafka git commit: KAFKA-3888: send consumer heartbeats from a background thread (KIP-62)

Posted by gu...@apache.org.
http://git-wip-us.apache.org/repos/asf/kafka/blob/40b1dd3f/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerNetworkClient.java
----------------------------------------------------------------------
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerNetworkClient.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerNetworkClient.java
index b65a5b7..07edd3c 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerNetworkClient.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerNetworkClient.java
@@ -22,6 +22,7 @@ import org.apache.kafka.common.errors.DisconnectException;
 import org.apache.kafka.common.errors.TimeoutException;
 import org.apache.kafka.common.errors.WakeupException;
 import org.apache.kafka.common.protocol.ApiKeys;
+import org.apache.kafka.common.protocol.ProtoUtils;
 import org.apache.kafka.common.requests.AbstractRequest;
 import org.apache.kafka.common.requests.RequestHeader;
 import org.apache.kafka.common.requests.RequestSend;
@@ -36,27 +37,34 @@ import java.util.HashMap;
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
+import java.util.concurrent.ConcurrentLinkedQueue;
 import java.util.concurrent.atomic.AtomicBoolean;
 
 /**
- * Higher level consumer access to the network layer with basic support for futures and
- * task scheduling. This class is not thread-safe, except for wakeup().
+ * Higher level consumer access to the network layer with basic support for request futures. This class
+ * is thread-safe, but provides no synchronization for response callbacks. This guarantees that no locks
+ * are held when they are invoked.
  */
 public class ConsumerNetworkClient implements Closeable {
     private static final Logger log = LoggerFactory.getLogger(ConsumerNetworkClient.class);
 
+    // the mutable state of this class is protected by the object's monitor (excluding the wakeup
+    // flag and the request completion queue below).
     private final KafkaClient client;
-    private final AtomicBoolean wakeup = new AtomicBoolean(false);
-    private final DelayedTaskQueue delayedTasks = new DelayedTaskQueue();
     private final Map<Node, List<ClientRequest>> unsent = new HashMap<>();
     private final Metadata metadata;
     private final Time time;
     private final long retryBackoffMs;
     private final long unsentExpiryMs;
-
-    // this count is only accessed from the consumer's main thread
     private int wakeupDisabledCount = 0;
 
+    // when requests complete, they are transferred to this queue prior to invocation. The purpose
+    // is to avoid invoking them while holding the lock above.
+    private final ConcurrentLinkedQueue<RequestFutureCompletionHandler> pendingCompletion = new ConcurrentLinkedQueue<>();
+
+    // this flag allows the client to be safely woken up without waiting on the lock above. It is
+    // atomic to avoid the need to acquire the lock above in order to enable it concurrently.
+    private final AtomicBoolean wakeup = new AtomicBoolean(false);
 
     public ConsumerNetworkClient(KafkaClient client,
                                  Metadata metadata,
@@ -71,25 +79,6 @@ public class ConsumerNetworkClient implements Closeable {
     }
 
     /**
-     * Schedule a new task to be executed at the given time. This is "best-effort" scheduling and
-     * should only be used for coarse synchronization.
-     * @param task The task to be scheduled
-     * @param at The time it should run
-     */
-    public void schedule(DelayedTask task, long at) {
-        delayedTasks.add(task, at);
-    }
-
-    /**
-     * Unschedule a task. This will remove all instances of the task from the task queue.
-     * This is a no-op if the task is not scheduled.
-     * @param task The task to be unscheduled.
-     */
-    public void unschedule(DelayedTask task) {
-        delayedTasks.remove(task);
-    }
-
-    /**
      * Send a new request. Note that the request is not actually transmitted on the
      * network until one of the {@link #poll(long)} variants is invoked. At this
      * point the request will either be transmitted successfully or will fail.
@@ -104,25 +93,36 @@ public class ConsumerNetworkClient implements Closeable {
     public RequestFuture<ClientResponse> send(Node node,
                                               ApiKeys api,
                                               AbstractRequest request) {
+        return send(node, api, ProtoUtils.latestVersion(api.id), request);
+    }
+
+    private RequestFuture<ClientResponse> send(Node node,
+                                              ApiKeys api,
+                                              short version,
+                                              AbstractRequest request) {
         long now = time.milliseconds();
-        RequestFutureCompletionHandler future = new RequestFutureCompletionHandler();
-        RequestHeader header = client.nextRequestHeader(api);
+        RequestFutureCompletionHandler completionHandler = new RequestFutureCompletionHandler();
+        RequestHeader header = client.nextRequestHeader(api, version);
         RequestSend send = new RequestSend(node.idString(), header, request.toStruct());
-        put(node, new ClientRequest(now, true, send, future));
-        return future;
+        put(node, new ClientRequest(now, true, send, completionHandler));
+        return completionHandler.future;
     }
 
     private void put(Node node, ClientRequest request) {
-        List<ClientRequest> nodeUnsent = unsent.get(node);
-        if (nodeUnsent == null) {
-            nodeUnsent = new ArrayList<>();
-            unsent.put(node, nodeUnsent);
+        synchronized (this) {
+            List<ClientRequest> nodeUnsent = unsent.get(node);
+            if (nodeUnsent == null) {
+                nodeUnsent = new ArrayList<>();
+                unsent.put(node, nodeUnsent);
+            }
+            nodeUnsent.add(request);
         }
-        nodeUnsent.add(request);
     }
 
     public Node leastLoadedNode() {
-        return client.leastLoadedNode(time.milliseconds());
+        synchronized (this) {
+            return client.leastLoadedNode(time.milliseconds());
+        }
     }
 
     /**
@@ -149,6 +149,8 @@ public class ConsumerNetworkClient implements Closeable {
      * on the current poll if one is active, or the next poll.
      */
     public void wakeup() {
+        // wakeup should be safe without holding the client lock since it simply delegates to
+        // Selector's wakeup, which is thread-safe
         this.wakeup.set(true);
         this.client.wakeup();
     }
@@ -175,7 +177,7 @@ public class ConsumerNetworkClient implements Closeable {
         long remaining = timeout;
         long now = begin;
         do {
-            poll(remaining, now, true);
+            poll(remaining, now);
             now = time.milliseconds();
             long elapsed = now - begin;
             remaining = timeout - elapsed;
@@ -189,7 +191,7 @@ public class ConsumerNetworkClient implements Closeable {
      * @throws WakeupException if {@link #wakeup()} is called from another thread
      */
     public void poll(long timeout) {
-        poll(timeout, time.milliseconds(), true);
+        poll(timeout, time.milliseconds());
     }
 
     /**
@@ -198,7 +200,37 @@ public class ConsumerNetworkClient implements Closeable {
      * @param now current time in milliseconds
      */
     public void poll(long timeout, long now) {
-        poll(timeout, now, true);
+        // there may be handlers which still need to be invoked if the previous call to poll was woken up
+        firePendingCompletedRequests();
+
+        synchronized (this) {
+            // send all the requests we can send now
+            trySend(now);
+
+            // do the network I/O and wait for responses, blocking for at
+            // most the requested timeout
+            client.poll(timeout, now);
+            now = time.milliseconds();
+
+            // handle any disconnects by failing the active requests. note that disconnects must
+            // be checked immediately following poll since any subsequent call to client.ready()
+            // will reset the disconnect status
+            checkDisconnects(now);
+
+            // trigger wakeups after checking for disconnects so that the callbacks will be ready
+            // to be fired on the next call to poll()
+            maybeTriggerWakeup();
+
+            // try again to send requests since buffer space may have been
+            // cleared or a connect finished in the poll
+            trySend(now);
+
+            // fail requests that couldn't be sent if they have expired
+            failExpiredRequests(now);
+        }
+
+        // called without the lock to avoid deadlock potential if handlers need to acquire locks
+        firePendingCompletedRequests();
     }
 
     /**
@@ -208,49 +240,12 @@ public class ConsumerNetworkClient implements Closeable {
     public void pollNoWakeup() {
         disableWakeups();
         try {
-            poll(0, time.milliseconds(), false);
+            poll(0, time.milliseconds());
         } finally {
             enableWakeups();
         }
     }
 
-    private void poll(long timeout, long now, boolean executeDelayedTasks) {
-        // send all the requests we can send now
-        trySend(now);
-
-        // ensure we don't poll any longer than the deadline for
-        // the next scheduled task
-        timeout = Math.min(timeout, delayedTasks.nextTimeout(now));
-        clientPoll(timeout, now);
-        now = time.milliseconds();
-
-        // handle any disconnects by failing the active requests. note that disconnects must
-        // be checked immediately following poll since any subsequent call to client.ready()
-        // will reset the disconnect status
-        checkDisconnects(now);
-
-        // execute scheduled tasks
-        if (executeDelayedTasks)
-            delayedTasks.poll(now);
-
-        // try again to send requests since buffer space may have been
-        // cleared or a connect finished in the poll
-        trySend(now);
-
-        // fail requests that couldn't be sent if they have expired
-        failExpiredRequests(now);
-    }
-
-    /**
-     * Execute delayed tasks now.
-     * @param now current time in milliseconds
-     * @throws WakeupException if a wakeup has been requested
-     */
-    public void executeDelayedTasks(long now) {
-        delayedTasks.poll(now);
-        maybeTriggerWakeup();
-    }
-
     /**
      * Block until all pending requests from the given node have finished.
      * @param node The node to await requests from
@@ -267,9 +262,11 @@ public class ConsumerNetworkClient implements Closeable {
      * @return The number of pending requests
      */
     public int pendingRequestCount(Node node) {
-        List<ClientRequest> pending = unsent.get(node);
-        int unsentCount = pending == null ? 0 : pending.size();
-        return unsentCount + client.inFlightRequestCount(node.idString());
+        synchronized (this) {
+            List<ClientRequest> pending = unsent.get(node);
+            int unsentCount = pending == null ? 0 : pending.size();
+            return unsentCount + client.inFlightRequestCount(node.idString());
+        }
     }
 
     /**
@@ -278,10 +275,22 @@ public class ConsumerNetworkClient implements Closeable {
      * @return The total count of pending requests
      */
     public int pendingRequestCount() {
-        int total = 0;
-        for (List<ClientRequest> requests: unsent.values())
-            total += requests.size();
-        return total + client.inFlightRequestCount();
+        synchronized (this) {
+            int total = 0;
+            for (List<ClientRequest> requests: unsent.values())
+                total += requests.size();
+            return total + client.inFlightRequestCount();
+        }
+    }
+
+    private void firePendingCompletedRequests() {
+        for (;;) {
+            RequestFutureCompletionHandler completionHandler = pendingCompletion.poll();
+            if (completionHandler == null)
+                break;
+
+            completionHandler.fireCompletion();
+        }
     }
 
     private void checkDisconnects(long now) {
@@ -315,9 +324,8 @@ public class ConsumerNetworkClient implements Closeable {
             while (requestIterator.hasNext()) {
                 ClientRequest request = requestIterator.next();
                 if (request.createdTimeMs() < now - unsentExpiryMs) {
-                    RequestFutureCompletionHandler handler =
-                            (RequestFutureCompletionHandler) request.callback();
-                    handler.raise(new TimeoutException("Failed to send request after " + unsentExpiryMs + " ms."));
+                    RequestFutureCompletionHandler handler = (RequestFutureCompletionHandler) request.callback();
+                    handler.onFailure(new TimeoutException("Failed to send request after " + unsentExpiryMs + " ms."));
                     requestIterator.remove();
                 } else
                     break;
@@ -327,15 +335,20 @@ public class ConsumerNetworkClient implements Closeable {
         }
     }
 
-    protected void failUnsentRequests(Node node, RuntimeException e) {
+    public void failUnsentRequests(Node node, RuntimeException e) {
         // clear unsent requests to node and fail their corresponding futures
-        List<ClientRequest> unsentRequests = unsent.remove(node);
-        if (unsentRequests != null) {
-            for (ClientRequest request : unsentRequests) {
-                RequestFutureCompletionHandler handler = (RequestFutureCompletionHandler) request.callback();
-                handler.raise(e);
+        synchronized (this) {
+            List<ClientRequest> unsentRequests = unsent.remove(node);
+            if (unsentRequests != null) {
+                for (ClientRequest request : unsentRequests) {
+                    RequestFutureCompletionHandler handler = (RequestFutureCompletionHandler) request.callback();
+                    handler.onFailure(e);
+                }
             }
         }
+
+        // called without the lock to avoid deadlock potential
+        firePendingCompletedRequests();
     }
 
     private boolean trySend(long now) {
@@ -356,11 +369,6 @@ public class ConsumerNetworkClient implements Closeable {
         return requestsSent;
     }
 
-    private void clientPoll(long timeout, long now) {
-        client.poll(timeout, now);
-        maybeTriggerWakeup();
-    }
-
     private void maybeTriggerWakeup() {
         if (wakeupDisabledCount == 0 && wakeup.get()) {
             wakeup.set(false);
@@ -369,24 +377,30 @@ public class ConsumerNetworkClient implements Closeable {
     }
 
     public void disableWakeups() {
-        wakeupDisabledCount++;
+        synchronized (this) {
+            wakeupDisabledCount++;
+        }
     }
 
     public void enableWakeups() {
-        if (wakeupDisabledCount <= 0)
-            throw new IllegalStateException("Cannot enable wakeups since they were never disabled");
+        synchronized (this) {
+            if (wakeupDisabledCount <= 0)
+                throw new IllegalStateException("Cannot enable wakeups since they were never disabled");
 
-        wakeupDisabledCount--;
+            wakeupDisabledCount--;
 
-        // re-wakeup the client if the flag was set since previous wake-up call
-        // could be cleared by poll(0) while wakeups were disabled
-        if (wakeupDisabledCount == 0 && wakeup.get())
-            this.client.wakeup();
+            // re-wakeup the client if the flag was set since previous wake-up call
+            // could be cleared by poll(0) while wakeups were disabled
+            if (wakeupDisabledCount == 0 && wakeup.get())
+                this.client.wakeup();
+        }
     }
 
     @Override
     public void close() throws IOException {
-        client.close();
+        synchronized (this) {
+            client.close();
+        }
     }
 
     /**
@@ -395,7 +409,9 @@ public class ConsumerNetworkClient implements Closeable {
      * @param node Node to connect to if possible
      */
     public boolean connectionFailed(Node node) {
-        return client.connectionFailed(node);
+        synchronized (this) {
+            return client.connectionFailed(node);
+        }
     }
 
     /**
@@ -405,26 +421,45 @@ public class ConsumerNetworkClient implements Closeable {
      * @param node The node to connect to
      */
     public void tryConnect(Node node) {
-        client.ready(node, time.milliseconds());
+        synchronized (this) {
+            client.ready(node, time.milliseconds());
+        }
     }
 
-    public static class RequestFutureCompletionHandler
-            extends RequestFuture<ClientResponse>
-            implements RequestCompletionHandler {
+    public class RequestFutureCompletionHandler implements RequestCompletionHandler {
+        private final RequestFuture<ClientResponse> future;
+        private ClientResponse response;
+        private RuntimeException e;
 
-        @Override
-        public void onComplete(ClientResponse response) {
-            if (response.wasDisconnected()) {
+        public RequestFutureCompletionHandler() {
+            this.future = new RequestFuture<>();
+        }
+
+        public void fireCompletion() {
+            if (e != null) {
+                future.raise(e);
+            } else if (response.wasDisconnected()) {
                 ClientRequest request = response.request();
                 RequestSend send = request.request();
                 ApiKeys api = ApiKeys.forId(send.header().apiKey());
                 int correlation = send.header().correlationId();
                 log.debug("Cancelled {} request {} with correlation id {} due to node {} being disconnected",
                         api, request, correlation, send.destination());
-                raise(DisconnectException.INSTANCE);
+                future.raise(DisconnectException.INSTANCE);
             } else {
-                complete(response);
+                future.complete(response);
             }
         }
+
+        public void onFailure(RuntimeException e) {
+            this.e = e;
+            pendingCompletion.add(this);
+        }
+
+        @Override
+        public void onComplete(ClientResponse response) {
+            this.response = response;
+            pendingCompletion.add(this);
+        }
     }
 }
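
As a reading aid for the locking scheme above, here is a minimal, self-contained sketch of the same
"complete outside the lock" pattern. It is not the Kafka class itself; the CompletionQueueSketch name
and the plain Runnable handlers are stand-ins. Handlers are only enqueued while the monitor is held
and are fired before and after the synchronized section, so no lock is held while callbacks run.

    import java.util.concurrent.ConcurrentLinkedQueue;

    public class CompletionQueueSketch {
        // handlers are enqueued while the lock is held and fired only after it is released
        private final ConcurrentLinkedQueue<Runnable> pendingCompletion = new ConcurrentLinkedQueue<>();

        public void poll() {
            // fire anything left over from a previous call that returned early (e.g. after a wakeup)
            firePendingCompletedRequests();

            synchronized (this) {
                // ... network I/O goes here; completed requests call enqueue(handler) ...
            }

            // fire outside the lock so a handler that takes its own locks cannot deadlock against us
            firePendingCompletedRequests();
        }

        public void enqueue(Runnable completionHandler) {
            pendingCompletion.add(completionHandler);
        }

        private void firePendingCompletedRequests() {
            Runnable handler;
            while ((handler = pendingCompletion.poll()) != null)
                handler.run();
        }
    }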

http://git-wip-us.apache.org/repos/asf/kafka/blob/40b1dd3f/clients/src/main/java/org/apache/kafka/clients/consumer/internals/DelayedTask.java
----------------------------------------------------------------------
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/DelayedTask.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/DelayedTask.java
deleted file mode 100644
index 61663f8..0000000
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/DelayedTask.java
+++ /dev/null
@@ -1,24 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE
- * file distributed with this work for additional information regarding copyright ownership. The ASF licenses this file
- * to You under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the
- * License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
- * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations under the License.
- */
-
-package org.apache.kafka.clients.consumer.internals;
-
-
-public interface DelayedTask {
-
-    /**
-     * Execute the task.
-     * @param now current time in milliseconds
-     */
-    void run(long now);
-}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/kafka/blob/40b1dd3f/clients/src/main/java/org/apache/kafka/clients/consumer/internals/DelayedTaskQueue.java
----------------------------------------------------------------------
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/DelayedTaskQueue.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/DelayedTaskQueue.java
deleted file mode 100644
index 61cab20..0000000
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/DelayedTaskQueue.java
+++ /dev/null
@@ -1,96 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE
- * file distributed with this work for additional information regarding copyright ownership. The ASF licenses this file
- * to You under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the
- * License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
- * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations under the License.
- */
-
-package org.apache.kafka.clients.consumer.internals;
-
-import java.util.Iterator;
-import java.util.PriorityQueue;
-
-/**
- * Tracks a set of tasks to be executed after a delay.
- */
-public class DelayedTaskQueue {
-
-    private PriorityQueue<Entry> tasks;
-
-    public DelayedTaskQueue() {
-        tasks = new PriorityQueue<Entry>();
-    }
-
-    /**
-     * Schedule a task for execution in the future.
-     *
-     * @param task the task to execute
-     * @param at the time at which to
-     */
-    public void add(DelayedTask task, long at) {
-        tasks.add(new Entry(task, at));
-    }
-
-    /**
-     * Remove a task from the queue if it is present
-     * @param task the task to be removed
-     * @returns true if a task was removed as a result of this call
-     */
-    public boolean remove(DelayedTask task) {
-        boolean wasRemoved = false;
-        Iterator<Entry> iterator = tasks.iterator();
-        while (iterator.hasNext()) {
-            Entry entry = iterator.next();
-            if (entry.task.equals(task)) {
-                iterator.remove();
-                wasRemoved = true;
-            }
-        }
-        return wasRemoved;
-    }
-
-    /**
-     * Get amount of time in milliseconds until the next event. Returns Long.MAX_VALUE if no tasks are scheduled.
-     *
-     * @return the remaining time in milliseconds
-     */
-    public long nextTimeout(long now) {
-        if (tasks.isEmpty())
-            return Long.MAX_VALUE;
-        else
-            return Math.max(tasks.peek().timeout - now, 0);
-    }
-
-    /**
-     * Run any ready tasks.
-     *
-     * @param now the current time
-     */
-    public void poll(long now) {
-        while (!tasks.isEmpty() && tasks.peek().timeout <= now) {
-            Entry entry = tasks.poll();
-            entry.task.run(now);
-        }
-    }
-
-    private static class Entry implements Comparable<Entry> {
-        DelayedTask task;
-        long timeout;
-
-        public Entry(DelayedTask task, long timeout) {
-            this.task = task;
-            this.timeout = timeout;
-        }
-
-        @Override
-        public int compareTo(Entry entry) {
-            return Long.compare(timeout, entry.timeout);
-        }
-    }
-}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/kafka/blob/40b1dd3f/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Fetcher.java
----------------------------------------------------------------------
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Fetcher.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Fetcher.java
index 913ce9e..84278c6 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Fetcher.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Fetcher.java
@@ -65,6 +65,7 @@ import java.util.List;
 import java.util.Locale;
 import java.util.Map;
 import java.util.Set;
+import java.util.concurrent.ConcurrentLinkedQueue;
 
 /**
  * This class manage the fetching process with the brokers.
@@ -84,7 +85,7 @@ public class Fetcher<K, V> {
     private final Metadata metadata;
     private final FetchManagerMetrics sensors;
     private final SubscriptionState subscriptions;
-    private final List<CompletedFetch> completedFetches;
+    private final ConcurrentLinkedQueue<CompletedFetch> completedFetches;
     private final Deserializer<K> keyDeserializer;
     private final Deserializer<V> valueDeserializer;
 
@@ -115,7 +116,7 @@ public class Fetcher<K, V> {
         this.checkCrcs = checkCrcs;
         this.keyDeserializer = keyDeserializer;
         this.valueDeserializer = valueDeserializer;
-        this.completedFetches = new ArrayList<>();
+        this.completedFetches = new ConcurrentLinkedQueue<>();
         this.sensors = new FetchManagerMetrics(metrics, metricGrpPrefix);
         this.retryBackoffMs = retryBackoffMs;
     }
@@ -127,7 +128,8 @@ public class Fetcher<K, V> {
     public void sendFetches() {
         for (Map.Entry<Node, FetchRequest> fetchEntry: createFetchRequests().entrySet()) {
             final FetchRequest request = fetchEntry.getValue();
-            client.send(fetchEntry.getKey(), ApiKeys.FETCH, request)
+            final Node fetchTarget = fetchEntry.getKey();
+            client.send(fetchTarget, ApiKeys.FETCH, request)
                     .addListener(new RequestFutureListener<ClientResponse>() {
                         @Override
                         public void onSuccess(ClientResponse resp) {
@@ -148,7 +150,7 @@ public class Fetcher<K, V> {
 
                         @Override
                         public void onFailure(RuntimeException e) {
-                            log.debug("Fetch failed", e);
+                            log.debug("Fetch request to {} failed", fetchTarget, e);
                         }
                     });
         }
@@ -353,16 +355,14 @@ public class Fetcher<K, V> {
         } else {
             Map<TopicPartition, List<ConsumerRecord<K, V>>> drained = new HashMap<>();
             int recordsRemaining = maxPollRecords;
-            Iterator<CompletedFetch> completedFetchesIterator = completedFetches.iterator();
 
             while (recordsRemaining > 0) {
                 if (nextInLineRecords == null || nextInLineRecords.isEmpty()) {
-                    if (!completedFetchesIterator.hasNext())
+                    CompletedFetch completedFetch = completedFetches.poll();
+                    if (completedFetch == null)
                         break;
 
-                    CompletedFetch completion = completedFetchesIterator.next();
-                    completedFetchesIterator.remove();
-                    nextInLineRecords = parseFetchedData(completion);
+                    nextInLineRecords = parseFetchedData(completedFetch);
                 } else {
                     recordsRemaining -= append(drained, nextInLineRecords, recordsRemaining);
                 }
@@ -510,6 +510,8 @@ public class Fetcher<K, V> {
                 long position = this.subscriptions.position(partition);
                 fetch.put(partition, new FetchRequest.PartitionData(position, this.fetchSize));
                 log.trace("Added fetch request for partition {} at offset {}", partition, position);
+            } else {
+                log.trace("Skipping fetch for partition {} because there is an inflight request to {}", partition, node);
             }
         }
 
@@ -845,4 +847,5 @@ public class Fetcher<K, V> {
             recordsFetched.record(records);
         }
     }
+
 }
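
The completedFetches change above swaps an ArrayList (drained with an explicit iterator) for a
ConcurrentLinkedQueue, so response callbacks can append completed fetches without extra locking while
the polling thread drains them. A small illustrative sketch of that drain loop follows; the DrainSketch
class and the use of plain string batches instead of CompletedFetch are hypothetical simplifications.

    import java.util.ArrayList;
    import java.util.List;
    import java.util.concurrent.ConcurrentLinkedQueue;

    public class DrainSketch {
        // callbacks append completed batches here; no explicit lock is needed
        private final ConcurrentLinkedQueue<List<String>> completedFetches = new ConcurrentLinkedQueue<>();

        public void onCompletedFetch(List<String> batch) {
            completedFetches.add(batch);
        }

        // drain batches until the queue is empty or maxRecords has been reached
        public List<String> drain(int maxRecords) {
            List<String> drained = new ArrayList<>();
            while (drained.size() < maxRecords) {
                List<String> batch = completedFetches.poll();
                if (batch == null)
                    break;          // nothing else is ready right now
                drained.addAll(batch);
            }
            return drained;
        }
    }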

http://git-wip-us.apache.org/repos/asf/kafka/blob/40b1dd3f/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Heartbeat.java
----------------------------------------------------------------------
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Heartbeat.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Heartbeat.java
index 79e17e2..dff1006 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Heartbeat.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Heartbeat.java
@@ -16,26 +16,41 @@ package org.apache.kafka.clients.consumer.internals;
  * A helper class for managing the heartbeat to the coordinator
  */
 public final class Heartbeat {
-    private final long timeout;
-    private final long interval;
+    private final long sessionTimeout;
+    private final long heartbeatInterval;
+    private final long maxPollInterval;
+    private final long retryBackoffMs;
 
-    private long lastHeartbeatSend;
+    private volatile long lastHeartbeatSend; // volatile since it is read by metrics
     private long lastHeartbeatReceive;
     private long lastSessionReset;
+    private long lastPoll;
+    private boolean heartbeatFailed;
 
-    public Heartbeat(long timeout,
-                     long interval,
-                     long now) {
-        if (interval >= timeout)
+    public Heartbeat(long sessionTimeout,
+                     long heartbeatInterval,
+                     long maxPollInterval,
+                     long retryBackoffMs) {
+        if (heartbeatInterval >= sessionTimeout)
             throw new IllegalArgumentException("Heartbeat must be set lower than the session timeout");
 
-        this.timeout = timeout;
-        this.interval = interval;
-        this.lastSessionReset = now;
+        this.sessionTimeout = sessionTimeout;
+        this.heartbeatInterval = heartbeatInterval;
+        this.maxPollInterval = maxPollInterval;
+        this.retryBackoffMs = retryBackoffMs;
+    }
+
+    public void poll(long now) {
+        this.lastPoll = now;
     }
 
     public void sentHeartbeat(long now) {
         this.lastHeartbeatSend = now;
+        this.heartbeatFailed = false;
+    }
+
+    public void failHeartbeat() {
+        this.heartbeatFailed = true;
     }
 
     public void receiveHeartbeat(long now) {
@@ -52,23 +67,34 @@ public final class Heartbeat {
 
     public long timeToNextHeartbeat(long now) {
         long timeSinceLastHeartbeat = now - Math.max(lastHeartbeatSend, lastSessionReset);
+        final long delayToNextHeartbeat;
+        if (heartbeatFailed)
+            delayToNextHeartbeat = retryBackoffMs;
+        else
+            delayToNextHeartbeat = heartbeatInterval;
 
-        if (timeSinceLastHeartbeat > interval)
+        if (timeSinceLastHeartbeat > delayToNextHeartbeat)
             return 0;
         else
-            return interval - timeSinceLastHeartbeat;
+            return delayToNextHeartbeat - timeSinceLastHeartbeat;
     }
 
     public boolean sessionTimeoutExpired(long now) {
-        return now - Math.max(lastSessionReset, lastHeartbeatReceive) > timeout;
+        return now - Math.max(lastSessionReset, lastHeartbeatReceive) > sessionTimeout;
     }
 
     public long interval() {
-        return interval;
+        return heartbeatInterval;
     }
 
-    public void resetSessionTimeout(long now) {
+    public void resetTimeouts(long now) {
         this.lastSessionReset = now;
+        this.lastPoll = now;
+        this.heartbeatFailed = false;
+    }
+
+    public boolean pollTimeoutExpired(long now) {
+        return now - lastPoll > maxPollInterval;
     }
 
 }
\ No newline at end of file
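
A short usage illustration of the reworked Heartbeat class above (the timeout values are arbitrary,
not defaults): after a failed heartbeat the next attempt is scheduled after retryBackoffMs rather than
the full interval, and pollTimeoutExpired() enforces the new poll-based liveness bound from KIP-62.

    import org.apache.kafka.clients.consumer.internals.Heartbeat;

    public class HeartbeatSketch {
        public static void main(String[] args) {
            // session timeout, heartbeat interval, max poll interval, retry backoff (all in ms)
            Heartbeat heartbeat = new Heartbeat(10000L, 3000L, 300000L, 100L);

            long now = 0L;
            heartbeat.poll(now);           // the application called poll()
            heartbeat.sentHeartbeat(now);  // a heartbeat request went out

            System.out.println(heartbeat.timeToNextHeartbeat(now)); // 3000: wait the full interval

            heartbeat.failHeartbeat();
            System.out.println(heartbeat.timeToNextHeartbeat(now)); // 100: retry after the backoff

            // expires once poll() has not been called within the max poll interval
            System.out.println(heartbeat.pollTimeoutExpired(now + 300001L)); // true
        }
    }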

http://git-wip-us.apache.org/repos/asf/kafka/blob/40b1dd3f/clients/src/main/java/org/apache/kafka/clients/consumer/internals/RequestFuture.java
----------------------------------------------------------------------
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/RequestFuture.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/RequestFuture.java
index 71c16fa..b21d13e 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/RequestFuture.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/RequestFuture.java
@@ -196,7 +196,7 @@ public class RequestFuture<T> {
     }
 
     public static RequestFuture<Void> voidSuccess() {
-        RequestFuture<Void> future = new RequestFuture<Void>();
+        RequestFuture<Void> future = new RequestFuture<>();
         future.complete(null);
         return future;
     }

http://git-wip-us.apache.org/repos/asf/kafka/blob/40b1dd3f/clients/src/main/java/org/apache/kafka/common/protocol/Protocol.java
----------------------------------------------------------------------
diff --git a/clients/src/main/java/org/apache/kafka/common/protocol/Protocol.java b/clients/src/main/java/org/apache/kafka/common/protocol/Protocol.java
index d27ec8a..313477f 100644
--- a/clients/src/main/java/org/apache/kafka/common/protocol/Protocol.java
+++ b/clients/src/main/java/org/apache/kafka/common/protocol/Protocol.java
@@ -572,9 +572,28 @@ public class Protocol {
                                                                             new ArrayOf(JOIN_GROUP_REQUEST_PROTOCOL_V0),
                                                                             "List of protocols that the member supports"));
 
+    public static final Schema JOIN_GROUP_REQUEST_V1 = new Schema(new Field("group_id",
+                                                                            STRING,
+                                                                            "The group id."),
+                                                                  new Field("session_timeout",
+                                                                            INT32,
+                                                                            "The coordinator considers the consumer dead if it receives no heartbeat after this timeout in ms."),
+                                                                  new Field("rebalance_timeout",
+                                                                            INT32,
+                                                                            "The maximum time that the coordinator will wait for each member to rejoin when rebalancing the group"),
+                                                                  new Field("member_id",
+                                                                            STRING,
+                                                                            "The assigned consumer id or an empty string for a new consumer."),
+                                                                  new Field("protocol_type",
+                                                                            STRING,
+                                                                            "Unique name for class of protocols implemented by group"),
+                                                                  new Field("group_protocols",
+                                                                            new ArrayOf(JOIN_GROUP_REQUEST_PROTOCOL_V0),
+                                                                            "List of protocols that the member supports"));
 
     public static final Schema JOIN_GROUP_RESPONSE_MEMBER_V0 = new Schema(new Field("member_id", STRING),
                                                                           new Field("member_metadata", BYTES));
+
     public static final Schema JOIN_GROUP_RESPONSE_V0 = new Schema(new Field("error_code", INT16),
                                                                    new Field("generation_id",
                                                                              INT32,
@@ -591,8 +610,10 @@ public class Protocol {
                                                                    new Field("members",
                                                                              new ArrayOf(JOIN_GROUP_RESPONSE_MEMBER_V0)));
 
-    public static final Schema[] JOIN_GROUP_REQUEST = new Schema[] {JOIN_GROUP_REQUEST_V0};
-    public static final Schema[] JOIN_GROUP_RESPONSE = new Schema[] {JOIN_GROUP_RESPONSE_V0};
+    public static final Schema JOIN_GROUP_RESPONSE_V1 = JOIN_GROUP_RESPONSE_V0;
+
+    public static final Schema[] JOIN_GROUP_REQUEST = new Schema[] {JOIN_GROUP_REQUEST_V0, JOIN_GROUP_REQUEST_V1};
+    public static final Schema[] JOIN_GROUP_RESPONSE = new Schema[] {JOIN_GROUP_RESPONSE_V0, JOIN_GROUP_RESPONSE_V1};
 
     /* SyncGroup api */
     public static final Schema SYNC_GROUP_REQUEST_MEMBER_V0 = new Schema(new Field("member_id", STRING),

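At the wire level, the only difference between the two request versions is the rebalance_timeout field
inserted after session_timeout. A hedged sketch of populating and serializing the v1 schema directly
(field values are placeholders; the empty group_protocols array keeps the sketch short):

    import java.nio.ByteBuffer;
    import org.apache.kafka.common.protocol.Protocol;
    import org.apache.kafka.common.protocol.types.Struct;

    public class JoinGroupV1SchemaSketch {
        public static void main(String[] args) {
            // populate every field of the v1 request schema shown above
            Struct request = new Struct(Protocol.JOIN_GROUP_REQUEST_V1);
            request.set("group_id", "my-group");
            request.set("session_timeout", 10000);
            request.set("rebalance_timeout", 300000);   // new in v1
            request.set("member_id", "");
            request.set("protocol_type", "consumer");
            request.set("group_protocols", new Object[0]);

            ByteBuffer buffer = ByteBuffer.allocate(request.sizeOf());
            request.writeTo(buffer);
            System.out.println("serialized size: " + buffer.position());
        }
    }
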
http://git-wip-us.apache.org/repos/asf/kafka/blob/40b1dd3f/clients/src/main/java/org/apache/kafka/common/requests/JoinGroupRequest.java
----------------------------------------------------------------------
diff --git a/clients/src/main/java/org/apache/kafka/common/requests/JoinGroupRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/JoinGroupRequest.java
index 14a6c1d..2845ee0 100644
--- a/clients/src/main/java/org/apache/kafka/common/requests/JoinGroupRequest.java
+++ b/clients/src/main/java/org/apache/kafka/common/requests/JoinGroupRequest.java
@@ -24,10 +24,11 @@ import java.util.Collections;
 import java.util.List;
 
 public class JoinGroupRequest extends AbstractRequest {
-    
+
     private static final Schema CURRENT_SCHEMA = ProtoUtils.currentRequestSchema(ApiKeys.JOIN_GROUP.id);
     private static final String GROUP_ID_KEY_NAME = "group_id";
     private static final String SESSION_TIMEOUT_KEY_NAME = "session_timeout";
+    private static final String REBALANCE_TIMEOUT_KEY_NAME = "rebalance_timeout";
     private static final String MEMBER_ID_KEY_NAME = "member_id";
     private static final String PROTOCOL_TYPE_KEY_NAME = "protocol_type";
     private static final String GROUP_PROTOCOLS_KEY_NAME = "group_protocols";
@@ -38,6 +39,7 @@ public class JoinGroupRequest extends AbstractRequest {
 
     private final String groupId;
     private final int sessionTimeout;
+    private final int rebalanceTimeout;
     private final String memberId;
     private final String protocolType;
     private final List<ProtocolMetadata> groupProtocols;
@@ -60,14 +62,40 @@ public class JoinGroupRequest extends AbstractRequest {
         }
     }
 
+    // v0 constructor
+    @Deprecated
+    public JoinGroupRequest(String groupId,
+                            int sessionTimeout,
+                            String memberId,
+                            String protocolType,
+                            List<ProtocolMetadata> groupProtocols) {
+        this(0, groupId, sessionTimeout, sessionTimeout, memberId, protocolType, groupProtocols);
+    }
+
     public JoinGroupRequest(String groupId,
                             int sessionTimeout,
+                            int rebalanceTimeout,
                             String memberId,
                             String protocolType,
                             List<ProtocolMetadata> groupProtocols) {
-        super(new Struct(CURRENT_SCHEMA));
+        this(1, groupId, sessionTimeout, rebalanceTimeout, memberId, protocolType, groupProtocols);
+    }
+
+    private JoinGroupRequest(int version,
+                             String groupId,
+                             int sessionTimeout,
+                             int rebalanceTimeout,
+                             String memberId,
+                             String protocolType,
+                             List<ProtocolMetadata> groupProtocols) {
+        super(new Struct(ProtoUtils.requestSchema(ApiKeys.JOIN_GROUP.id, version)));
+
         struct.set(GROUP_ID_KEY_NAME, groupId);
         struct.set(SESSION_TIMEOUT_KEY_NAME, sessionTimeout);
+
+        if (version >= 1)
+            struct.set(REBALANCE_TIMEOUT_KEY_NAME, rebalanceTimeout);
+
         struct.set(MEMBER_ID_KEY_NAME, memberId);
         struct.set(PROTOCOL_TYPE_KEY_NAME, protocolType);
 
@@ -82,6 +110,7 @@ public class JoinGroupRequest extends AbstractRequest {
         struct.set(GROUP_PROTOCOLS_KEY_NAME, groupProtocolsList.toArray());
         this.groupId = groupId;
         this.sessionTimeout = sessionTimeout;
+        this.rebalanceTimeout = rebalanceTimeout;
         this.memberId = memberId;
         this.protocolType = protocolType;
         this.groupProtocols = groupProtocols;
@@ -89,8 +118,17 @@ public class JoinGroupRequest extends AbstractRequest {
 
     public JoinGroupRequest(Struct struct) {
         super(struct);
+
         groupId = struct.getString(GROUP_ID_KEY_NAME);
         sessionTimeout = struct.getInt(SESSION_TIMEOUT_KEY_NAME);
+
+        if (struct.hasField(REBALANCE_TIMEOUT_KEY_NAME))
+            // rebalance timeout is added in v1
+            rebalanceTimeout = struct.getInt(REBALANCE_TIMEOUT_KEY_NAME);
+        else
+            // v0 had no rebalance timeout but used session timeout implicitly
+            rebalanceTimeout = sessionTimeout;
+
         memberId = struct.getString(MEMBER_ID_KEY_NAME);
         protocolType = struct.getString(PROTOCOL_TYPE_KEY_NAME);
 
@@ -107,13 +145,16 @@ public class JoinGroupRequest extends AbstractRequest {
     public AbstractRequestResponse getErrorResponse(int versionId, Throwable e) {
         switch (versionId) {
             case 0:
+            case 1:
                 return new JoinGroupResponse(
+                        versionId,
                         Errors.forException(e).code(),
                         JoinGroupResponse.UNKNOWN_GENERATION_ID,
                         JoinGroupResponse.UNKNOWN_PROTOCOL,
                         JoinGroupResponse.UNKNOWN_MEMBER_ID, // memberId
                         JoinGroupResponse.UNKNOWN_MEMBER_ID, // leaderId
                         Collections.<String, ByteBuffer>emptyMap());
+
             default:
                 throw new IllegalArgumentException(String.format("Version %d is not valid. Valid versions for %s are 0 to %d",
                         versionId, this.getClass().getSimpleName(), ProtoUtils.latestVersion(ApiKeys.JOIN_GROUP.id)));
@@ -128,6 +169,10 @@ public class JoinGroupRequest extends AbstractRequest {
         return sessionTimeout;
     }
 
+    public int rebalanceTimeout() {
+        return rebalanceTimeout;
+    }
+
     public String memberId() {
         return memberId;
     }
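
For illustration, the two public constructors above can be used as follows (group id, protocol type and
timeouts are placeholder values; the empty protocol list keeps the sketch short, whereas a real consumer
supplies its assignor metadata):

    import java.util.Collections;
    import org.apache.kafka.common.requests.JoinGroupRequest;

    public class JoinGroupRequestSketch {
        public static void main(String[] args) {
            // v1 (current): session timeout and rebalance timeout are independent
            JoinGroupRequest v1 = new JoinGroupRequest(
                    "my-group", 10000, 300000, "", "consumer",
                    Collections.<JoinGroupRequest.ProtocolMetadata>emptyList());

            // deprecated v0 path: the rebalance timeout implicitly equals the session timeout
            JoinGroupRequest v0 = new JoinGroupRequest(
                    "my-group", 10000, "", "consumer",
                    Collections.<JoinGroupRequest.ProtocolMetadata>emptyList());

            System.out.println(v1.rebalanceTimeout()); // 300000
            System.out.println(v0.rebalanceTimeout()); // 10000, same as the session timeout
        }
    }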

http://git-wip-us.apache.org/repos/asf/kafka/blob/40b1dd3f/clients/src/main/java/org/apache/kafka/common/requests/JoinGroupResponse.java
----------------------------------------------------------------------
diff --git a/clients/src/main/java/org/apache/kafka/common/requests/JoinGroupResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/JoinGroupResponse.java
index dd829ed..8895ace 100644
--- a/clients/src/main/java/org/apache/kafka/common/requests/JoinGroupResponse.java
+++ b/clients/src/main/java/org/apache/kafka/common/requests/JoinGroupResponse.java
@@ -24,7 +24,8 @@ import java.util.List;
 import java.util.Map;
 
 public class JoinGroupResponse extends AbstractRequestResponse {
-    
+
+    private static final short CURRENT_VERSION = ProtoUtils.latestVersion(ApiKeys.JOIN_GROUP.id);
     private static final Schema CURRENT_SCHEMA = ProtoUtils.currentResponseSchema(ApiKeys.JOIN_GROUP.id);
     private static final String ERROR_CODE_KEY_NAME = "error_code";
 
@@ -65,7 +66,17 @@ public class JoinGroupResponse extends AbstractRequestResponse {
                              String memberId,
                              String leaderId,
                              Map<String, ByteBuffer> groupMembers) {
-        super(new Struct(CURRENT_SCHEMA));
+        this(CURRENT_VERSION, errorCode, generationId, groupProtocol, memberId, leaderId, groupMembers);
+    }
+
+    public JoinGroupResponse(int version,
+                             short errorCode,
+                             int generationId,
+                             String groupProtocol,
+                             String memberId,
+                             String leaderId,
+                             Map<String, ByteBuffer> groupMembers) {
+        super(new Struct(ProtoUtils.responseSchema(ApiKeys.JOIN_GROUP.id, version)));
 
         struct.set(ERROR_CODE_KEY_NAME, errorCode);
         struct.set(GENERATION_ID_KEY_NAME, generationId);

http://git-wip-us.apache.org/repos/asf/kafka/blob/40b1dd3f/clients/src/main/java/org/apache/kafka/common/requests/OffsetFetchResponse.java
----------------------------------------------------------------------
diff --git a/clients/src/main/java/org/apache/kafka/common/requests/OffsetFetchResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/OffsetFetchResponse.java
index a76f48e..6cf93a0 100644
--- a/clients/src/main/java/org/apache/kafka/common/requests/OffsetFetchResponse.java
+++ b/clients/src/main/java/org/apache/kafka/common/requests/OffsetFetchResponse.java
@@ -50,8 +50,6 @@ public class OffsetFetchResponse extends AbstractRequestResponse {
      *  UNKNOWN_TOPIC_OR_PARTITION (3)  <- only for request v0
      *  GROUP_LOAD_IN_PROGRESS (14)
      *  NOT_COORDINATOR_FOR_GROUP (16)
-     *  ILLEGAL_GENERATION (22)
-     *  UNKNOWN_MEMBER_ID (25)
      *  TOPIC_AUTHORIZATION_FAILED (29)
      *  GROUP_AUTHORIZATION_FAILED (30)
      */

http://git-wip-us.apache.org/repos/asf/kafka/blob/40b1dd3f/clients/src/test/java/org/apache/kafka/clients/consumer/KafkaConsumerTest.java
----------------------------------------------------------------------
diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/KafkaConsumerTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/KafkaConsumerTest.java
index 8b52664..8d2ac00 100644
--- a/clients/src/test/java/org/apache/kafka/clients/consumer/KafkaConsumerTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/consumer/KafkaConsumerTest.java
@@ -90,8 +90,7 @@ public class KafkaConsumerTest {
         final int oldInitCount = MockMetricsReporter.INIT_COUNT.get();
         final int oldCloseCount = MockMetricsReporter.CLOSE_COUNT.get();
         try {
-            KafkaConsumer<byte[], byte[]> consumer = new KafkaConsumer<>(
-                    props, new ByteArrayDeserializer(), new ByteArrayDeserializer());
+            new KafkaConsumer<>(props, new ByteArrayDeserializer(), new ByteArrayDeserializer());
         } catch (KafkaException e) {
             assertEquals(oldInitCount + 1, MockMetricsReporter.INIT_COUNT.get());
             assertEquals(oldCloseCount + 1, MockMetricsReporter.CLOSE_COUNT.get());
@@ -314,17 +313,17 @@ public class KafkaConsumerTest {
         props.setProperty(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9999");
         props.setProperty(ConsumerConfig.METRIC_REPORTER_CLASSES_CONFIG, MockMetricsReporter.class.getName());
 
-        return new KafkaConsumer<byte[], byte[]>(
-            props, new ByteArrayDeserializer(), new ByteArrayDeserializer());
+        return new KafkaConsumer<>(props, new ByteArrayDeserializer(), new ByteArrayDeserializer());
     }
 
     @Test
-    public void verifyHeartbeatSent() {
+    public void verifyHeartbeatSent() throws Exception {
         String topic = "topic";
         TopicPartition partition = new TopicPartition(topic, 0);
 
+        int rebalanceTimeoutMs = 60000;
         int sessionTimeoutMs = 30000;
-        int heartbeatIntervalMs = 3000;
+        int heartbeatIntervalMs = 1000;
         int autoCommitIntervalMs = 10000;
 
         Time time = new MockTime();
@@ -337,7 +336,7 @@ public class KafkaConsumerTest {
         PartitionAssignor assignor = new RoundRobinAssignor();
 
         final KafkaConsumer<String, String> consumer = newConsumer(time, client, metadata, assignor,
-                sessionTimeoutMs, heartbeatIntervalMs, autoCommitIntervalMs);
+                rebalanceTimeoutMs, sessionTimeoutMs, heartbeatIntervalMs, autoCommitIntervalMs);
 
         consumer.subscribe(Arrays.asList(topic), new ConsumerRebalanceListener() {
             @Override
@@ -370,9 +369,6 @@ public class KafkaConsumerTest {
         consumer.poll(0);
         assertEquals(Collections.singleton(partition), consumer.assignment());
 
-        // heartbeat interval is 2 seconds
-        time.sleep(heartbeatIntervalMs);
-
         final AtomicBoolean heartbeatReceived = new AtomicBoolean(false);
         client.prepareResponseFrom(new MockClient.RequestMatcher() {
             @Override
@@ -382,18 +378,23 @@ public class KafkaConsumerTest {
             }
         }, new HeartbeatResponse(Errors.NONE.code()).toStruct(), coordinator);
 
+        // advance the mock clock past the heartbeat interval and give the background heartbeat thread wall-clock time to run
+        time.sleep(heartbeatIntervalMs);
+        Thread.sleep(heartbeatIntervalMs);
+
         consumer.poll(0);
 
         assertTrue(heartbeatReceived.get());
     }
 
     @Test
-    public void verifyHeartbeatSentWhenFetchedDataReady() {
+    public void verifyHeartbeatSentWhenFetchedDataReady() throws Exception {
         String topic = "topic";
         TopicPartition partition = new TopicPartition(topic, 0);
 
+        int rebalanceTimeoutMs = 60000;
         int sessionTimeoutMs = 30000;
-        int heartbeatIntervalMs = 3000;
+        int heartbeatIntervalMs = 1000;
         int autoCommitIntervalMs = 10000;
 
         Time time = new MockTime();
@@ -406,7 +407,7 @@ public class KafkaConsumerTest {
         PartitionAssignor assignor = new RoundRobinAssignor();
 
         final KafkaConsumer<String, String> consumer = newConsumer(time, client, metadata, assignor,
-                sessionTimeoutMs, heartbeatIntervalMs, autoCommitIntervalMs);
+                rebalanceTimeoutMs, sessionTimeoutMs, heartbeatIntervalMs, autoCommitIntervalMs);
         consumer.subscribe(Arrays.asList(topic), new ConsumerRebalanceListener() {
             @Override
             public void onPartitionsRevoked(Collection<TopicPartition> partitions) {
@@ -438,8 +439,6 @@ public class KafkaConsumerTest {
         client.respondFrom(fetchResponse(partition, 0, 5), node);
         client.poll(0, time.milliseconds());
 
-        time.sleep(heartbeatIntervalMs);
-
         client.prepareResponseFrom(fetchResponse(partition, 5, 0), node);
         final AtomicBoolean heartbeatReceived = new AtomicBoolean(false);
         client.prepareResponseFrom(new MockClient.RequestMatcher() {
@@ -450,6 +449,9 @@ public class KafkaConsumerTest {
             }
         }, new HeartbeatResponse(Errors.NONE.code()).toStruct(), coordinator);
 
+        time.sleep(heartbeatIntervalMs);
+        Thread.sleep(heartbeatIntervalMs);
+
         consumer.poll(0);
 
         assertTrue(heartbeatReceived.get());
@@ -459,6 +461,7 @@ public class KafkaConsumerTest {
     public void verifyNoCoordinatorLookupForManualAssignmentWithSeek() {
         String topic = "topic";
         final TopicPartition partition = new TopicPartition(topic, 0);
+        int rebalanceTimeoutMs = 60000;
         int sessionTimeoutMs = 3000;
         int heartbeatIntervalMs = 2000;
         int autoCommitIntervalMs = 1000;
@@ -473,7 +476,7 @@ public class KafkaConsumerTest {
         PartitionAssignor assignor = new RoundRobinAssignor();
 
         final KafkaConsumer<String, String> consumer = newConsumer(time, client, metadata, assignor,
-                sessionTimeoutMs, heartbeatIntervalMs, autoCommitIntervalMs);
+                rebalanceTimeoutMs, sessionTimeoutMs, heartbeatIntervalMs, autoCommitIntervalMs);
         consumer.assign(Arrays.asList(partition));
         consumer.seekToBeginning(Arrays.asList(partition));
 
@@ -496,6 +499,7 @@ public class KafkaConsumerTest {
         long offset1 = 10000;
         long offset2 = 20000;
 
+        int rebalanceTimeoutMs = 6000;
         int sessionTimeoutMs = 3000;
         int heartbeatIntervalMs = 2000;
         int autoCommitIntervalMs = 1000;
@@ -510,7 +514,7 @@ public class KafkaConsumerTest {
         PartitionAssignor assignor = new RoundRobinAssignor();
 
         final KafkaConsumer<String, String> consumer = newConsumer(time, client, metadata, assignor,
-                sessionTimeoutMs, heartbeatIntervalMs, autoCommitIntervalMs);
+                rebalanceTimeoutMs, sessionTimeoutMs, heartbeatIntervalMs, autoCommitIntervalMs);
         consumer.assign(Arrays.asList(partition1));
 
         // lookup coordinator
@@ -541,6 +545,7 @@ public class KafkaConsumerTest {
         String topic = "topic";
         final TopicPartition partition = new TopicPartition(topic, 0);
 
+        int rebalanceTimeoutMs = 60000;
         int sessionTimeoutMs = 30000;
         int heartbeatIntervalMs = 3000;
 
@@ -558,7 +563,7 @@ public class KafkaConsumerTest {
         PartitionAssignor assignor = new RoundRobinAssignor();
 
         final KafkaConsumer<String, String> consumer = newConsumer(time, client, metadata, assignor,
-                sessionTimeoutMs, heartbeatIntervalMs, autoCommitIntervalMs);
+                rebalanceTimeoutMs, sessionTimeoutMs, heartbeatIntervalMs, autoCommitIntervalMs);
         consumer.subscribe(Arrays.asList(topic), new ConsumerRebalanceListener() {
             @Override
             public void onPartitionsRevoked(Collection<TopicPartition> partitions) {
@@ -619,6 +624,7 @@ public class KafkaConsumerTest {
         String topic = "topic";
         final TopicPartition partition = new TopicPartition(topic, 0);
 
+        int rebalanceTimeoutMs = 60000;
         int sessionTimeoutMs = 30000;
         int heartbeatIntervalMs = 3000;
 
@@ -636,7 +642,7 @@ public class KafkaConsumerTest {
         PartitionAssignor assignor = new RoundRobinAssignor();
 
         final KafkaConsumer<String, String> consumer = newConsumer(time, client, metadata, assignor,
-                sessionTimeoutMs, heartbeatIntervalMs, autoCommitIntervalMs);
+                rebalanceTimeoutMs, sessionTimeoutMs, heartbeatIntervalMs, autoCommitIntervalMs);
         consumer.subscribe(Arrays.asList(topic), new ConsumerRebalanceListener() {
             @Override
             public void onPartitionsRevoked(Collection<TopicPartition> partitions) {
@@ -725,6 +731,7 @@ public class KafkaConsumerTest {
                                                       KafkaClient client,
                                                       Metadata metadata,
                                                       PartitionAssignor assignor,
+                                                      int rebalanceTimeoutMs,
                                                       int sessionTimeoutMs,
                                                       int heartbeatIntervalMs,
                                                       int autoCommitIntervalMs) {
@@ -757,6 +764,7 @@ public class KafkaConsumerTest {
         ConsumerCoordinator consumerCoordinator = new ConsumerCoordinator(
                 consumerClient,
                 groupId,
+                rebalanceTimeoutMs,
                 sessionTimeoutMs,
                 heartbeatIntervalMs,
                 assignors,
@@ -800,6 +808,9 @@ public class KafkaConsumerTest {
                 metrics,
                 subscriptions,
                 metadata,
+                autoCommitEnabled,
+                autoCommitIntervalMs,
+                heartbeatIntervalMs,
                 retryBackoffMs,
                 requestTimeoutMs);
     }
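
The test changes above reflect what KIP-62 means for applications: the rebalance timeout sent to the
coordinator is driven by the maximum allowed gap between poll() calls, configured independently of the
session timeout and heartbeat interval. A configuration sketch under that assumption follows; the values
are examples, not recommendations, and the string literal is used for the new setting so the sketch does
not depend on the exact constant name added to ConsumerConfig by this patch.

    import java.util.Properties;
    import org.apache.kafka.clients.consumer.ConsumerConfig;
    import org.apache.kafka.clients.consumer.KafkaConsumer;
    import org.apache.kafka.common.serialization.ByteArrayDeserializer;

    public class Kip62ConfigSketch {
        public static void main(String[] args) {
            Properties props = new Properties();
            props.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092");
            props.put(ConsumerConfig.GROUP_ID_CONFIG, "my-group");
            // liveness of the background heartbeat thread
            props.put(ConsumerConfig.SESSION_TIMEOUT_MS_CONFIG, 10000);
            props.put(ConsumerConfig.HEARTBEAT_INTERVAL_MS_CONFIG, 3000);
            // upper bound on the time between poll() calls; also used as the rebalance timeout
            props.put("max.poll.interval.ms", 300000);

            KafkaConsumer<byte[], byte[]> consumer =
                    new KafkaConsumer<>(props, new ByteArrayDeserializer(), new ByteArrayDeserializer());
            consumer.close();
        }
    }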

http://git-wip-us.apache.org/repos/asf/kafka/blob/40b1dd3f/clients/src/test/java/org/apache/kafka/clients/consumer/internals/AbstractCoordinatorTest.java
----------------------------------------------------------------------
diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/AbstractCoordinatorTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/AbstractCoordinatorTest.java
index 7a05eb1..77f9df5 100644
--- a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/AbstractCoordinatorTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/AbstractCoordinatorTest.java
@@ -16,15 +16,20 @@
  **/
 package org.apache.kafka.clients.consumer.internals;
 
+import org.apache.kafka.clients.ClientRequest;
 import org.apache.kafka.clients.Metadata;
 import org.apache.kafka.clients.MockClient;
 import org.apache.kafka.common.Cluster;
 import org.apache.kafka.common.Node;
 import org.apache.kafka.common.metrics.Metrics;
+import org.apache.kafka.common.protocol.ApiKeys;
 import org.apache.kafka.common.protocol.Errors;
 import org.apache.kafka.common.protocol.types.Struct;
 import org.apache.kafka.common.requests.GroupCoordinatorResponse;
+import org.apache.kafka.common.requests.HeartbeatResponse;
 import org.apache.kafka.common.requests.JoinGroupRequest;
+import org.apache.kafka.common.requests.JoinGroupResponse;
+import org.apache.kafka.common.requests.SyncGroupResponse;
 import org.apache.kafka.common.utils.MockTime;
 import org.apache.kafka.common.utils.Time;
 import org.apache.kafka.test.TestUtils;
@@ -37,12 +42,15 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 
+import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
 
 public class AbstractCoordinatorTest {
 
     private static final ByteBuffer EMPTY_DATA = ByteBuffer.wrap(new byte[0]);
-    private static final int SESSION_TIMEOUT_MS = 30000;
+    private static final int REBALANCE_TIMEOUT_MS = 60000;
+    private static final int SESSION_TIMEOUT_MS = 10000;
     private static final int HEARTBEAT_INTERVAL_MS = 3000;
     private static final long RETRY_BACKOFF_MS = 100;
     private static final long REQUEST_TIMEOUT_MS = 40000;
@@ -77,8 +85,8 @@ public class AbstractCoordinatorTest {
 
     @Test
     public void testCoordinatorDiscoveryBackoff() {
-        mockClient.prepareResponse(groupCoordinatorResponse(node, Errors.NONE.code()));
-        mockClient.prepareResponse(groupCoordinatorResponse(node, Errors.NONE.code()));
+        mockClient.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
+        mockClient.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
 
         // blackout the coordinator for 50 milliseconds to simulate a disconnect.
         // after backing off, we should be able to connect.
@@ -91,17 +99,65 @@ public class AbstractCoordinatorTest {
         assertTrue(endTime - initialTime >= RETRY_BACKOFF_MS);
     }
 
-    private Struct groupCoordinatorResponse(Node node, short error) {
-        GroupCoordinatorResponse response = new GroupCoordinatorResponse(error, node);
+    @Test
+    public void testUncaughtExceptionInHeartbeatThread() throws Exception {
+        mockClient.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
+        mockClient.prepareResponse(joinGroupFollowerResponse(1, "memberId", "leaderId", Errors.NONE));
+        mockClient.prepareResponse(syncGroupResponse(Errors.NONE));
+
+
+        final RuntimeException e = new RuntimeException();
+
+        // raise the error when the background thread tries to send a heartbeat
+        mockClient.prepareResponse(new MockClient.RequestMatcher() {
+            @Override
+            public boolean matches(ClientRequest request) {
+                if (request.request().header().apiKey() == ApiKeys.HEARTBEAT.id)
+                    throw e;
+                return false;
+            }
+        }, heartbeatResponse(Errors.UNKNOWN));
+
+        try {
+            coordinator.ensureActiveGroup();
+            mockTime.sleep(HEARTBEAT_INTERVAL_MS);
+            synchronized (coordinator) {
+                coordinator.notify();
+            }
+            Thread.sleep(100);
+
+            coordinator.pollHeartbeat(mockTime.milliseconds());
+            fail("Expected pollHeartbeat to raise an error");
+        } catch (RuntimeException exception) {
+            assertEquals(exception, e);
+        }
+    }
+
+    private Struct groupCoordinatorResponse(Node node, Errors error) {
+        GroupCoordinatorResponse response = new GroupCoordinatorResponse(error.code(), node);
         return response.toStruct();
     }
 
+    private Struct heartbeatResponse(Errors error) {
+        HeartbeatResponse response = new HeartbeatResponse(error.code());
+        return response.toStruct();
+    }
+
+    private Struct joinGroupFollowerResponse(int generationId, String memberId, String leaderId, Errors error) {
+        return new JoinGroupResponse(error.code(), generationId, "dummy-subprotocol", memberId, leaderId,
+                Collections.<String, ByteBuffer>emptyMap()).toStruct();
+    }
+
+    private Struct syncGroupResponse(Errors error) {
+        return new SyncGroupResponse(error.code(), ByteBuffer.allocate(0)).toStruct();
+    }
+
     public class DummyCoordinator extends AbstractCoordinator {
 
         public DummyCoordinator(ConsumerNetworkClient client,
                                 Metrics metrics,
                                 Time time) {
-            super(client, GROUP_ID, SESSION_TIMEOUT_MS, HEARTBEAT_INTERVAL_MS, metrics,
+            super(client, GROUP_ID, REBALANCE_TIMEOUT_MS, SESSION_TIMEOUT_MS, HEARTBEAT_INTERVAL_MS, metrics,
                     METRIC_GROUP_PREFIX, time, RETRY_BACKOFF_MS);
         }
 


[4/4] kafka git commit: KAFKA-3888: send consumer heartbeats from a background thread (KIP-62)

Posted by gu...@apache.org.
KAFKA-3888: send consumer heartbeats from a background thread (KIP-62)

Author: Jason Gustafson <ja...@confluent.io>

Reviewers: Ewen Cheslack-Postava <ew...@confluent.io>, Ismael Juma <is...@juma.me.uk>, Guozhang Wang <wa...@gmail.com>

Closes #1627 from hachikuji/KAFKA-3888


Project: http://git-wip-us.apache.org/repos/asf/kafka/repo
Commit: http://git-wip-us.apache.org/repos/asf/kafka/commit/40b1dd3f
Tree: http://git-wip-us.apache.org/repos/asf/kafka/tree/40b1dd3f
Diff: http://git-wip-us.apache.org/repos/asf/kafka/diff/40b1dd3f

Branch: refs/heads/trunk
Commit: 40b1dd3f495a59abef8a0cba5450526994c92c04
Parents: 19997ed
Author: Jason Gustafson <ja...@confluent.io>
Authored: Wed Aug 17 11:50:04 2016 -0700
Committer: Guozhang Wang <wa...@gmail.com>
Committed: Wed Aug 17 11:50:04 2016 -0700

----------------------------------------------------------------------
 .../clients/consumer/CommitFailedException.java |   9 +-
 .../kafka/clients/consumer/ConsumerConfig.java  |  39 +-
 .../kafka/clients/consumer/KafkaConsumer.java   |  78 ++-
 .../consumer/internals/AbstractCoordinator.java | 532 ++++++++++++-------
 .../consumer/internals/ConsumerCoordinator.java | 264 +++++----
 .../internals/ConsumerNetworkClient.java        | 267 ++++++----
 .../clients/consumer/internals/DelayedTask.java |  24 -
 .../consumer/internals/DelayedTaskQueue.java    |  96 ----
 .../clients/consumer/internals/Fetcher.java     |  21 +-
 .../clients/consumer/internals/Heartbeat.java   |  56 +-
 .../consumer/internals/RequestFuture.java       |   2 +-
 .../apache/kafka/common/protocol/Protocol.java  |  25 +-
 .../kafka/common/requests/JoinGroupRequest.java |  49 +-
 .../common/requests/JoinGroupResponse.java      |  15 +-
 .../common/requests/OffsetFetchResponse.java    |   2 -
 .../clients/consumer/KafkaConsumerTest.java     |  49 +-
 .../internals/AbstractCoordinatorTest.java      |  68 ++-
 .../internals/ConsumerCoordinatorTest.java      | 179 ++++---
 .../internals/ConsumerNetworkClientTest.java    |  26 +-
 .../internals/DelayedTaskQueueTest.java         |  89 ----
 .../clients/consumer/internals/FetcherTest.java |   1 +
 .../consumer/internals/HeartbeatTest.java       |   6 +-
 .../common/requests/RequestResponseTest.java    |  14 +-
 .../runtime/distributed/DistributedConfig.java  |  29 +-
 .../runtime/distributed/WorkerCoordinator.java  |  38 +-
 .../runtime/distributed/WorkerGroupMember.java  |  15 +-
 .../distributed/WorkerCoordinatorTest.java      |   2 +
 core/src/main/scala/kafka/api/ApiVersion.scala  |  11 +-
 .../kafka/coordinator/GroupCoordinator.scala    |  90 ++--
 .../scala/kafka/coordinator/GroupMetadata.scala |   4 +-
 .../coordinator/GroupMetadataManager.scala      | 150 ++++--
 .../kafka/coordinator/MemberMetadata.scala      |   1 +
 .../src/main/scala/kafka/server/KafkaApis.scala |   6 +-
 .../kafka/api/AuthorizerIntegrationTest.scala   |   2 +-
 .../kafka/api/BaseConsumerTest.scala            | 170 +-----
 .../kafka/api/ConsumerBounceTest.scala          |  13 +-
 .../kafka/api/PlaintextConsumerTest.scala       | 223 ++++++--
 .../SaslPlainSslEndToEndAuthorizationTest.scala |   1 -
 .../GroupCoordinatorResponseTest.scala          | 213 +++++---
 .../coordinator/GroupMetadataManagerTest.scala  |  11 +-
 .../kafka/coordinator/GroupMetadataTest.scala   |  59 +-
 .../kafka/coordinator/MemberMetadataTest.scala  |  16 +-
 .../unit/kafka/utils/timer/MockTimer.scala      |   2 +-
 43 files changed, 1686 insertions(+), 1281 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/kafka/blob/40b1dd3f/clients/src/main/java/org/apache/kafka/clients/consumer/CommitFailedException.java
----------------------------------------------------------------------
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/CommitFailedException.java b/clients/src/main/java/org/apache/kafka/clients/consumer/CommitFailedException.java
index 26ef48e..5695be8 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/CommitFailedException.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/CommitFailedException.java
@@ -28,7 +28,12 @@ public class CommitFailedException extends KafkaException {
 
     private static final long serialVersionUID = 1L;
 
-    public CommitFailedException(String message) {
-        super(message);
+    public CommitFailedException() {
+        super("Commit cannot be completed since the group has already " +
+                "rebalanced and assigned the partitions to another member. This means that the time " +
+                "between subsequent calls to poll() was longer than the configured max.poll.interval.ms, " +
+                "which typically implies that the poll loop is spending too much time on message processing. " +
+                "You can address this either by increasing max.poll.interval.ms or by reducing the maximum " +
+                "size of batches returned in poll() with max.poll.records.");
     }
 }

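For illustration only (outside the diff): the new default message above is what a caller sees when commitSync() races with a rebalance, e.g. because handling a batch overran max.poll.interval.ms. The consumer instance, its subscription, and the process() helper in this fragment are assumed to exist already.

    ConsumerRecords<String, String> records = consumer.poll(100);
    for (ConsumerRecord<String, String> record : records)
        process(record);   // hypothetical handler; if this overruns max.poll.interval.ms ...
    try {
        consumer.commitSync();
    } catch (CommitFailedException e) {
        // ... the group has already rebalanced and reassigned our partitions, so the commit
        // is rejected; the next poll() rejoins the group and the records are re-consumed
    }
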
http://git-wip-us.apache.org/repos/asf/kafka/blob/40b1dd3f/clients/src/main/java/org/apache/kafka/clients/consumer/ConsumerConfig.java
----------------------------------------------------------------------
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/ConsumerConfig.java b/clients/src/main/java/org/apache/kafka/clients/consumer/ConsumerConfig.java
index de10bed..509c3a1 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/ConsumerConfig.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/ConsumerConfig.java
@@ -48,24 +48,33 @@ public class ConsumerConfig extends AbstractConfig {
     public static final String MAX_POLL_RECORDS_CONFIG = "max.poll.records";
     private static final String MAX_POLL_RECORDS_DOC = "The maximum number of records returned in a single call to poll().";
 
+    /** <code>max.poll.interval.ms</code> */
+    public static final String MAX_POLL_INTERVAL_MS_CONFIG = "max.poll.interval.ms";
+    private static final String MAX_POLL_INTERVAL_MS_DOC = "The maximum delay between invocations of poll() when using " +
+            "consumer group management. This places an upper bound on the amount of time that the consumer can be idle " +
+            "before fetching more records. If poll() is not called before expiration of this timeout, then the consumer " +
+            "is considered failed and the group will rebalance in order to reassign the partitions to another member. ";
+
     /**
      * <code>session.timeout.ms</code>
      */
     public static final String SESSION_TIMEOUT_MS_CONFIG = "session.timeout.ms";
-    private static final String SESSION_TIMEOUT_MS_DOC = "The timeout used to detect failures when using Kafka's " +
-            "group management facilities. When a consumer's heartbeat is not received within the session timeout, " +
-            "the broker will mark the consumer as failed and rebalance the group. Since heartbeats are sent only " +
-            "when poll() is invoked, a higher session timeout allows more time for message processing in the consumer's " +
-            "poll loop at the cost of a longer time to detect hard failures. See also <code>" + MAX_POLL_RECORDS_CONFIG + "</code> for " +
-            "another option to control the processing time in the poll loop. Note that the value must be in the " +
-            "allowable range as configured in the broker configuration by <code>group.min.session.timeout.ms</code> " +
+    private static final String SESSION_TIMEOUT_MS_DOC = "The timeout used to detect consumer failures when using " +
+            "Kafka's group management facility. The consumer sends periodic heartbeats to indicate its liveness " +
+            "to the broker. If no heartbeats are received by the broker before the expiration of this session timeout, " +
+            "then the broker will remove this consumer from the group and initiate a rebalance. Note that the value " +
+            "must be in the allowable range as configured in the broker configuration by <code>group.min.session.timeout.ms</code> " +
             "and <code>group.max.session.timeout.ms</code>.";
 
     /**
      * <code>heartbeat.interval.ms</code>
      */
     public static final String HEARTBEAT_INTERVAL_MS_CONFIG = "heartbeat.interval.ms";
-    private static final String HEARTBEAT_INTERVAL_MS_DOC = "The expected time between heartbeats to the consumer coordinator when using Kafka's group management facilities. Heartbeats are used to ensure that the consumer's session stays active and to facilitate rebalancing when new consumers join or leave the group. The value must be set lower than <code>session.timeout.ms</code>, but typically should be set no higher than 1/3 of that value. It can be adjusted even lower to control the expected time for normal rebalances.";
+    private static final String HEARTBEAT_INTERVAL_MS_DOC = "The expected time between heartbeats to the consumer " +
+            "coordinator when using Kafka's group management facilities. Heartbeats are used to ensure that the " +
+            "consumer's session stays active and to facilitate rebalancing when new consumers join or leave the group. " +
+            "The value must be set lower than <code>session.timeout.ms</code>, but typically should be set no higher " +
+            "than 1/3 of that value. It can be adjusted even lower to control the expected time for normal rebalances.";
 
     /**
      * <code>bootstrap.servers</code>
@@ -196,7 +205,7 @@ public class ConsumerConfig extends AbstractConfig {
                                 .define(GROUP_ID_CONFIG, Type.STRING, "", Importance.HIGH, GROUP_ID_DOC)
                                 .define(SESSION_TIMEOUT_MS_CONFIG,
                                         Type.INT,
-                                        30000,
+                                        10000,
                                         Importance.HIGH,
                                         SESSION_TIMEOUT_MS_DOC)
                                 .define(HEARTBEAT_INTERVAL_MS_CONFIG,
@@ -221,7 +230,7 @@ public class ConsumerConfig extends AbstractConfig {
                                         Importance.MEDIUM,
                                         ENABLE_AUTO_COMMIT_DOC)
                                 .define(AUTO_COMMIT_INTERVAL_MS_CONFIG,
-                                        Type.LONG,
+                                        Type.INT,
                                         5000,
                                         atLeast(0),
                                         Importance.LOW,
@@ -311,7 +320,7 @@ public class ConsumerConfig extends AbstractConfig {
                                         VALUE_DESERIALIZER_CLASS_DOC)
                                 .define(REQUEST_TIMEOUT_MS_CONFIG,
                                         Type.INT,
-                                        40 * 1000,
+                                        305000, // chosen to be higher than the default of max.poll.interval.ms
                                         atLeast(0),
                                         Importance.MEDIUM,
                                         REQUEST_TIMEOUT_MS_DOC)
@@ -328,10 +337,16 @@ public class ConsumerConfig extends AbstractConfig {
                                         INTERCEPTOR_CLASSES_DOC)
                                 .define(MAX_POLL_RECORDS_CONFIG,
                                         Type.INT,
-                                        Integer.MAX_VALUE,
+                                        500,
                                         atLeast(1),
                                         Importance.MEDIUM,
                                         MAX_POLL_RECORDS_DOC)
+                                .define(MAX_POLL_INTERVAL_MS_CONFIG,
+                                        Type.INT,
+                                        300000,
+                                        atLeast(1),
+                                        Importance.MEDIUM,
+                                        MAX_POLL_INTERVAL_MS_DOC)
                                 .define(EXCLUDE_INTERNAL_TOPICS_CONFIG,
                                         Type.BOOLEAN,
                                         DEFAULT_EXCLUDE_INTERNAL_TOPICS,

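For reference (outside the diff), the settings touched above can be pinned explicitly in a consumer configuration. The values below simply restate the defaults introduced by this commit and are not tuning advice; the broker address is an assumption.

    Properties props = new Properties();
    props.put("bootstrap.servers", "localhost:9092");     // assumed broker address
    props.put("group.id", "test");
    props.put("session.timeout.ms", "10000");              // new default (was 30000)
    props.put("heartbeat.interval.ms", "3000");
    props.put("max.poll.interval.ms", "300000");            // new KIP-62 config (5 minutes)
    props.put("max.poll.records", "500");                   // new default (was Integer.MAX_VALUE)
    props.put("request.timeout.ms", "305000");              // new default, kept above max.poll.interval.ms
    props.put("auto.commit.interval.ms", "5000");           // now parsed as an int rather than a long
    props.put("key.deserializer", "org.apache.kafka.common.serialization.StringDeserializer");
    props.put("value.deserializer", "org.apache.kafka.common.serialization.StringDeserializer");
    KafkaConsumer<String, String> consumer = new KafkaConsumer<>(props);
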
http://git-wip-us.apache.org/repos/asf/kafka/blob/40b1dd3f/clients/src/main/java/org/apache/kafka/clients/consumer/KafkaConsumer.java
----------------------------------------------------------------------
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/KafkaConsumer.java b/clients/src/main/java/org/apache/kafka/clients/consumer/KafkaConsumer.java
index 522cfde..ef91302 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/KafkaConsumer.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/KafkaConsumer.java
@@ -137,32 +137,31 @@ import java.util.regex.Pattern;
  * After subscribing to a set of topics, the consumer will automatically join the group when {@link #poll(long)} is
  * invoked. The poll API is designed to ensure consumer liveness. As long as you continue to call poll, the consumer
  * will stay in the group and continue to receive messages from the partitions it was assigned. Underneath the covers,
- * the poll API sends periodic heartbeats to the server; when you stop calling poll (perhaps because an exception was thrown),
- * then no heartbeats will be sent. If a period of the configured <i>session timeout</i> elapses before the server
- * has received a heartbeat, then the consumer will be kicked out of the group and its partitions will be reassigned.
- * This is designed to prevent situations where the consumer has failed, yet continues to hold onto the partitions
- * it was assigned (thus preventing active consumers in the group from taking them). To stay in the group, you
- * have to prove you are still alive by calling poll.
+ * the consumer sends periodic heartbeats to the server. If the consumer crashes or is unable to send heartbeats for
+ * a duration of <code>session.timeout.ms</code>, then the consumer will be considered dead and its partitions will
+ * be reassigned. It is also possible that the consumer could encounter a "livelock" situation where it is continuing
+ * to send heartbeats, but no progress is being made. To prevent the consumer from holding onto its partitions
+ * indefinitely in this case, we provide a liveness detection mechanism: basically if you don't call poll at least
+ * as frequently as the configured <code>max.poll.interval.ms</code>, then the client will proactively leave the group
+ * so that another consumer can take over its partitions. So to stay in the group, you must continue to call poll.
  * <p>
  * The implication of this design is that message processing time in the poll loop must be bounded so that
- * heartbeats can be sent before expiration of the session timeout. What typically happens when processing time
- * exceeds the session timeout is that the consumer won't be able to commit offsets for any of the processed records.
- * For example, this is indicated by a {@link CommitFailedException} thrown from {@link #commitSync()}. This
- * guarantees that only active members of the group are allowed to commit offsets. If the consumer
- * has been kicked out of the group, then its partitions will have been assigned to another member, which will be
- * committing its own offsets as it handles new records. This gives offset commits an isolation guarantee.
+ * you always ensure that poll() is called at least once every poll interval. If not, then the consumer leaves
+ * the group, which typically results in an offset commit failure when the processing of the polled records
+ * finally completes (this is indicated by a {@link CommitFailedException} thrown from {@link #commitSync()}).
+ * This is a safety mechanism which guarantees that only active members of the group are able to commit offsets.
+ * If the consumer has been kicked out of the group, then its partitions will have been assigned to another member,
+ * which will be committing its own offsets as it handles new records. This gives offset commits an isolation guarantee.
  * <p>
- * The consumer provides two configuration settings to control this behavior:
+ * The consumer provides two configuration settings to control the behavior of the poll loop:
  * <ol>
- *     <li><code>session.timeout.ms</code>: By increasing the session timeout, you can give the consumer more
- *     time to handle a batch of records returned from {@link #poll(long)}. The only drawback is that it
- *     will take longer for the server to detect hard consumer failures, which can cause a delay before
- *     a rebalance can be completed. However, clean shutdown with {@link #close()} is not impacted since
- *     the consumer will send an explicit message to the server to leave the group and cause an immediate
- *     rebalance.</li>
- *     <li><code>max.poll.records</code>: Processing time in the poll loop is typically proportional to the number
- *     of records processed, so it's natural to want to set a limit on the number of records handled at once.
- *     This setting provides that. By default, there is essentially no limit.</li>
+ *     <li><code>max.poll.interval.ms</code>: By increasing the interval between expected polls, you can give
+ *     the consumer more time to handle a batch of records returned from {@link #poll(long)}. The drawback
+ *     is that increasing this value may delay a group rebalance since the consumer will only join the rebalance
+ *     inside the call to poll.</li>
+ *     <li><code>max.poll.records</code>: Use this setting to limit the total records returned from a single
+ *     call to poll. This can make it easier to predict the maximum that must be handled within each poll
+ *     interval.</li>
  * </ol>
  * <p>
  * For use cases where message processing time varies unpredictably, neither of these options may be viable.
@@ -187,7 +186,6 @@ import java.util.regex.Pattern;
  *     props.put(&quot;group.id&quot;, &quot;test&quot;);
  *     props.put(&quot;enable.auto.commit&quot;, &quot;true&quot;);
  *     props.put(&quot;auto.commit.interval.ms&quot;, &quot;1000&quot;);
- *     props.put(&quot;session.timeout.ms&quot;, &quot;30000&quot;);
  *     props.put(&quot;key.deserializer&quot;, &quot;org.apache.kafka.common.serialization.StringDeserializer&quot;);
  *     props.put(&quot;value.deserializer&quot;, &quot;org.apache.kafka.common.serialization.StringDeserializer&quot;);
  *     KafkaConsumer&lt;String, String&gt; consumer = new KafkaConsumer&lt;&gt;(props);
@@ -210,13 +208,6 @@ import java.util.regex.Pattern;
  * In this example the client is subscribing to the topics <i>foo</i> and <i>bar</i> as part of a group of consumers
  * called <i>test</i> as described above.
  * <p>
- * The broker will automatically detect failed processes in the <i>test</i> group by using a heartbeat mechanism. The
- * consumer will automatically ping the cluster periodically, which lets the cluster know that it is alive. Note that
- * the consumer is single-threaded, so periodic heartbeats can only be sent when {@link #poll(long)} is called. As long as
- * the consumer is able to do this it is considered alive and retains the right to consume from the partitions assigned
- * to it. If it stops heartbeating by failing to call {@link #poll(long)} for a period of time longer than <code>session.timeout.ms</code>
- * then it will be considered dead and its partitions will be assigned to another process.
- * <p>
  * The deserializer settings specify how to turn bytes into objects. For example, by specifying string deserializers, we
  * are saying that our record's key and value will just be simple strings.
  *
@@ -242,7 +233,6 @@ import java.util.regex.Pattern;
  *     props.put(&quot;bootstrap.servers&quot;, &quot;localhost:9092&quot;);
  *     props.put(&quot;group.id&quot;, &quot;test&quot;);
  *     props.put(&quot;enable.auto.commit&quot;, &quot;false&quot;);
- *     props.put(&quot;session.timeout.ms&quot;, &quot;30000&quot;);
  *     props.put(&quot;key.deserializer&quot;, &quot;org.apache.kafka.common.serialization.StringDeserializer&quot;);
  *     props.put(&quot;value.deserializer&quot;, &quot;org.apache.kafka.common.serialization.StringDeserializer&quot;);
  *     KafkaConsumer&lt;String, String&gt; consumer = new KafkaConsumer&lt;&gt;(props);
@@ -645,6 +635,7 @@ public class KafkaConsumer<K, V> implements Consumer<K, V> {
             this.interceptors = interceptorList.isEmpty() ? null : new ConsumerInterceptors<>(interceptorList);
             this.coordinator = new ConsumerCoordinator(this.client,
                     config.getString(ConsumerConfig.GROUP_ID_CONFIG),
+                    config.getInt(ConsumerConfig.MAX_POLL_INTERVAL_MS_CONFIG),
                     config.getInt(ConsumerConfig.SESSION_TIMEOUT_MS_CONFIG),
                     config.getInt(ConsumerConfig.HEARTBEAT_INTERVAL_MS_CONFIG),
                     assignors,
@@ -656,7 +647,7 @@ public class KafkaConsumer<K, V> implements Consumer<K, V> {
                     retryBackoffMs,
                     new ConsumerCoordinator.DefaultOffsetCommitCallback(),
                     config.getBoolean(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG),
-                    config.getLong(ConsumerConfig.AUTO_COMMIT_INTERVAL_MS_CONFIG),
+                    config.getInt(ConsumerConfig.AUTO_COMMIT_INTERVAL_MS_CONFIG),
                     this.interceptors,
                     config.getBoolean(ConsumerConfig.EXCLUDE_INTERNAL_TOPICS_CONFIG));
             if (keyDeserializer == null) {
@@ -715,6 +706,9 @@ public class KafkaConsumer<K, V> implements Consumer<K, V> {
                   Metrics metrics,
                   SubscriptionState subscriptions,
                   Metadata metadata,
+                  boolean autoCommitEnabled,
+                  int autoCommitIntervalMs,
+                  int heartbeatIntervalMs,
                   long retryBackoffMs,
                   long requestTimeoutMs) {
         this.clientId = clientId;
@@ -970,7 +964,6 @@ public class KafkaConsumer<K, V> implements Consumer<K, V> {
                     //
                     // NOTE: since the consumed position has already been updated, we must not allow
                     // wakeups or any other errors to be triggered prior to returning the fetched records.
-                    // Additionally, pollNoWakeup does not allow automatic commits to get triggered.
                     fetcher.sendFetches();
                     client.pollNoWakeup();
 
@@ -997,30 +990,23 @@ public class KafkaConsumer<K, V> implements Consumer<K, V> {
      * @return The fetched records (may be empty)
      */
     private Map<TopicPartition, List<ConsumerRecord<K, V>>> pollOnce(long timeout) {
-        // ensure we have partitions assigned if we expect to
-        if (subscriptions.partitionsAutoAssigned())
-            coordinator.ensurePartitionAssignment();
+        coordinator.poll(time.milliseconds());
 
         // fetch positions if we have partitions we're subscribed to that we
         // don't know the offset for
         if (!subscriptions.hasAllFetchPositions())
             updateFetchPositions(this.subscriptions.missingFetchPositions());
 
-        long now = time.milliseconds();
-
-        // execute delayed tasks (e.g. autocommits and heartbeats) prior to fetching records
-        client.executeDelayedTasks(now);
-
-        // init any new fetches (won't resend pending fetches)
+        // if data is available already, return it immediately
         Map<TopicPartition, List<ConsumerRecord<K, V>>> records = fetcher.fetchedRecords();
-
-        // if data is available already, e.g. from a previous network client poll() call to commit,
-        // then just return it immediately
         if (!records.isEmpty())
             return records;
 
+        // send any new fetches (won't resend pending fetches)
         fetcher.sendFetches();
-        client.poll(timeout, now);
+
+        long now = time.milliseconds();
+        client.poll(Math.min(coordinator.timeToNextPoll(now), timeout), now);
         return fetcher.fetchedRecords();
     }
 

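To make the two knobs in the updated javadoc concrete, here is a minimal sketch of a bounded poll loop (not taken from the patch; the broker address, topic names, and printed output are assumptions):

    import java.util.Arrays;
    import java.util.Properties;
    import org.apache.kafka.clients.consumer.ConsumerRecord;
    import org.apache.kafka.clients.consumer.ConsumerRecords;
    import org.apache.kafka.clients.consumer.KafkaConsumer;

    public class BoundedPollLoop {
        public static void main(String[] args) {
            Properties props = new Properties();
            props.put("bootstrap.servers", "localhost:9092");    // assumed broker address
            props.put("group.id", "test");
            props.put("enable.auto.commit", "true");
            props.put("max.poll.records", "100");                 // cap the work done per iteration
            props.put("max.poll.interval.ms", "300000");           // budget for handling one batch
            props.put("key.deserializer", "org.apache.kafka.common.serialization.StringDeserializer");
            props.put("value.deserializer", "org.apache.kafka.common.serialization.StringDeserializer");
            try (KafkaConsumer<String, String> consumer = new KafkaConsumer<>(props)) {
                consumer.subscribe(Arrays.asList("foo", "bar"));  // hypothetical topics
                while (true) {
                    // heartbeats are sent from the background thread; this call only has to
                    // happen once per max.poll.interval.ms to keep the member in the group
                    ConsumerRecords<String, String> records = consumer.poll(100);
                    for (ConsumerRecord<String, String> record : records)
                        System.out.printf("offset = %d, key = %s, value = %s%n",
                                record.offset(), record.key(), record.value());
                }
            }
        }
    }
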
http://git-wip-us.apache.org/repos/asf/kafka/blob/40b1dd3f/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AbstractCoordinator.java
----------------------------------------------------------------------
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AbstractCoordinator.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AbstractCoordinator.java
index e957856..690df26 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AbstractCoordinator.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AbstractCoordinator.java
@@ -20,6 +20,7 @@ import org.apache.kafka.common.errors.GroupAuthorizationException;
 import org.apache.kafka.common.errors.GroupCoordinatorNotAvailableException;
 import org.apache.kafka.common.errors.IllegalGenerationException;
 import org.apache.kafka.common.errors.RebalanceInProgressException;
+import org.apache.kafka.common.errors.RetriableException;
 import org.apache.kafka.common.errors.UnknownMemberIdException;
 import org.apache.kafka.common.metrics.Measurable;
 import org.apache.kafka.common.metrics.MetricConfig;
@@ -53,6 +54,7 @@ import java.util.Collections;
 import java.util.List;
 import java.util.Map;
 import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicReference;
 
 /**
  * AbstractCoordinator implements group management for a single group member by interacting with
@@ -77,26 +79,38 @@ import java.util.concurrent.TimeUnit;
  * by the leader in {@link #performAssignment(String, String, Map)} and becomes available to members in
  * {@link #onJoinComplete(int, String, String, ByteBuffer)}.
  *
+ * Note on locking: this class shares state between the caller and a background thread which is
+ * used for sending heartbeats after the client has joined the group. All mutable state as well as
+ * state transitions are protected with the class's monitor. Generally this means acquiring the lock
+ * before reading or writing the state of the group (e.g. generation, memberId) and holding the lock
+ * when sending a request that affects the state of the group (e.g. JoinGroup, LeaveGroup).
  */
 public abstract class AbstractCoordinator implements Closeable {
 
     private static final Logger log = LoggerFactory.getLogger(AbstractCoordinator.class);
 
-    private final Heartbeat heartbeat;
-    private final HeartbeatTask heartbeatTask;
+    private enum MemberState {
+        UNJOINED,    // the client is not part of a group
+        REBALANCING, // the client has begun rebalancing
+        STABLE,      // the client has joined and is sending heartbeats
+    }
+
+    private final int rebalanceTimeoutMs;
     private final int sessionTimeoutMs;
     private final GroupCoordinatorMetrics sensors;
+    private final Heartbeat heartbeat;
     protected final String groupId;
     protected final ConsumerNetworkClient client;
     protected final Time time;
     protected final long retryBackoffMs;
 
-    private boolean needsJoinPrepare = true;
+    private HeartbeatThread heartbeatThread = null;
     private boolean rejoinNeeded = true;
-    protected Node coordinator;
-    protected String memberId;
-    protected String protocol;
-    protected int generation;
+    private boolean needsJoinPrepare = true;
+    private MemberState state = MemberState.UNJOINED;
+    private RequestFuture<ByteBuffer> joinFuture = null;
+    private Node coordinator = null;
+    private Generation generation = Generation.NO_GENERATION;
 
     private RequestFuture<Void> findCoordinatorFuture = null;
 
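The locking note above boils down to a single-monitor pattern shared between the caller and the background heartbeat thread. What follows is a stripped-down, hypothetical sketch of that pattern only; the class, field, and method names are invented here and do not appear in the patch.

    // All shared state is guarded by the object's monitor; the heartbeat thread parks on
    // the same monitor until the caller enables it after a successful join.
    public class CoordinatorSketch {
        private boolean heartbeatEnabled = false;   // guarded by "this"
        private int generationId = -1;              // guarded by "this"

        private final Thread heartbeatThread = new Thread(() -> {
            try {
                while (true) {
                    synchronized (CoordinatorSketch.this) {
                        if (!heartbeatEnabled) {
                            CoordinatorSketch.this.wait();     // parked until enable() notifies
                            continue;
                        }
                        // a real implementation would send a heartbeat for generationId here
                        CoordinatorSketch.this.wait(3000L);    // placeholder heartbeat interval
                    }
                }
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();            // shutting down
            }
        });

        public synchronized void enable(int newGenerationId) {
            this.generationId = newGenerationId;
            this.heartbeatEnabled = true;
            notifyAll();                                       // wake the parked heartbeat thread
        }

        public void start() {
            heartbeatThread.start();
        }
    }
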
@@ -105,6 +119,7 @@ public abstract class AbstractCoordinator implements Closeable {
      */
     public AbstractCoordinator(ConsumerNetworkClient client,
                                String groupId,
+                               int rebalanceTimeoutMs,
                                int sessionTimeoutMs,
                                int heartbeatIntervalMs,
                                Metrics metrics,
@@ -113,19 +128,16 @@ public abstract class AbstractCoordinator implements Closeable {
                                long retryBackoffMs) {
         this.client = client;
         this.time = time;
-        this.generation = OffsetCommitRequest.DEFAULT_GENERATION_ID;
-        this.memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID;
         this.groupId = groupId;
-        this.coordinator = null;
+        this.rebalanceTimeoutMs = rebalanceTimeoutMs;
         this.sessionTimeoutMs = sessionTimeoutMs;
-        this.heartbeat = new Heartbeat(this.sessionTimeoutMs, heartbeatIntervalMs, time.milliseconds());
-        this.heartbeatTask = new HeartbeatTask();
+        this.heartbeat = new Heartbeat(sessionTimeoutMs, heartbeatIntervalMs, rebalanceTimeoutMs, retryBackoffMs);
         this.sensors = new GroupCoordinatorMetrics(metrics, metricGrpPrefix);
         this.retryBackoffMs = retryBackoffMs;
     }
 
     /**
-     * Unique identifier for the class of protocols implements (e.g. "consumer" or "connect").
+     * Unique identifier for the class of supported protocols (e.g. "consumer" or "connect").
      * @return Non-null protocol type name
      */
     protected abstract String protocolType();
@@ -175,7 +187,7 @@ public abstract class AbstractCoordinator implements Closeable {
     /**
      * Block until the coordinator for this group is known and is ready to receive requests.
      */
-    public void ensureCoordinatorReady() {
+    public synchronized void ensureCoordinatorReady() {
         while (coordinatorUnknown()) {
             RequestFuture<Void> future = lookupCoordinator();
             client.poll(future);
@@ -216,14 +228,44 @@ public abstract class AbstractCoordinator implements Closeable {
      * Check whether the group should be rejoined (e.g. if metadata changes)
      * @return true if it should, false otherwise
      */
-    protected boolean needRejoin() {
+    protected synchronized boolean needRejoin() {
         return rejoinNeeded;
     }
 
     /**
+     * Check the status of the heartbeat thread (if it is active) and indicate the liveness
+     * of the client. This must be called periodically after joining with {@link #ensureActiveGroup()}
+     * to ensure that the member stays in the group. If an interval of time longer than the
+     * provided rebalance timeout expires without calling this method, then the client will proactively
+     * leave the group.
+     * @param now current time in milliseconds
+     * @throws RuntimeException for unexpected errors raised from the heartbeat thread
+     */
+    protected synchronized void pollHeartbeat(long now) {
+        if (heartbeatThread != null) {
+            if (heartbeatThread.hasFailed()) {
+                // set the heartbeat thread to null and raise an exception. If the user catches it,
+                // the next call to ensureActiveGroup() will spawn a new heartbeat thread.
+                RuntimeException cause = heartbeatThread.failureCause();
+                heartbeatThread = null;
+                throw cause;
+            }
+
+            heartbeat.poll(now);
+        }
+    }
+
+    protected synchronized long timeToNextHeartbeat(long now) {
+        // if we have not joined the group, we don't need to send heartbeats
+        if (state == MemberState.UNJOINED)
+            return Long.MAX_VALUE;
+        return heartbeat.timeToNextHeartbeat(now);
+    }
+
+    /**
      * Ensure that the group is active (i.e. joined and synced)
      */
-    public void ensureActiveGroup() {
+    public synchronized void ensureActiveGroup() {
         // always ensure that the coordinator is ready because we may have been disconnected
         // when sending heartbeats, which does not necessarily require us to rejoin the group.
         ensureCoordinatorReady();
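Concretely, a subclass is expected to interleave the two calls documented above on every pass through its own poll path, roughly as in this hypothetical fragment (only ensureActiveGroup() and pollHeartbeat() come from the patch; everything else is assumed):

    // Hypothetical per-poll driver inside an AbstractCoordinator subclass:
    void pollCoordinator(long nowMs) {
        ensureActiveGroup();     // joins/syncs if needed and starts the heartbeat thread
        pollHeartbeat(nowMs);    // records caller liveness; rethrows any heartbeat-thread failure
        // ... application work; if the next call is delayed past the rebalance timeout,
        // the heartbeat thread will proactively leave the group ...
    }
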
@@ -231,11 +273,18 @@ public abstract class AbstractCoordinator implements Closeable {
         if (!needRejoin())
             return;
 
+        // call onJoinPrepare if needed. We set a flag to make sure that we do not call it a second
+        // time if the client is woken up before a pending rebalance completes.
         if (needsJoinPrepare) {
-            onJoinPrepare(generation, memberId);
+            onJoinPrepare(generation.generationId, generation.memberId);
             needsJoinPrepare = false;
         }
 
+        if (heartbeatThread == null) {
+            heartbeatThread = new HeartbeatThread();
+            heartbeatThread.start();
+        }
+
         while (needRejoin()) {
             ensureCoordinatorReady();
 
@@ -246,23 +295,41 @@ public abstract class AbstractCoordinator implements Closeable {
                 continue;
             }
 
-            RequestFuture<ByteBuffer> future = sendJoinGroupRequest();
-            future.addListener(new RequestFutureListener<ByteBuffer>() {
-                @Override
-                public void onSuccess(ByteBuffer value) {
-                    // handle join completion in the callback so that the callback will be invoked
-                    // even if the consumer is woken up before finishing the rebalance
-                    onJoinComplete(generation, memberId, protocol, value);
-                    needsJoinPrepare = true;
-                    heartbeatTask.reset();
-                }
+            // we store the join future in case we are woken up by the user after beginning the
+            // rebalance in the call to poll below. This ensures that we do not mistakenly attempt
+            // to rejoin before the pending rebalance has completed.
+            if (joinFuture == null) {
+                state = MemberState.REBALANCING;
+                joinFuture = sendJoinGroupRequest();
+                joinFuture.addListener(new RequestFutureListener<ByteBuffer>() {
+                    @Override
+                    public void onSuccess(ByteBuffer value) {
+                        // handle join completion in the callback so that the callback will be invoked
+                        // even if the consumer is woken up before finishing the rebalance
+                        synchronized (AbstractCoordinator.this) {
+                            log.info("Successfully joined group {} with generation {}", groupId, generation.generationId);
+                            joinFuture = null;
+                            state = MemberState.STABLE;
+                            needsJoinPrepare = true;
+                            heartbeatThread.enable();
+                        }
 
-                @Override
-                public void onFailure(RuntimeException e) {
-                    // we handle failures below after the request finishes. if the join completes
-                    // after having been woken up, the exception is ignored and we will rejoin
-                }
-            });
+                        onJoinComplete(generation.generationId, generation.memberId, generation.protocol, value);
+                    }
+
+                    @Override
+                    public void onFailure(RuntimeException e) {
+                        // we handle failures below after the request finishes. if the join completes
+                        // after having been woken up, the exception is ignored and we will rejoin
+                        synchronized (AbstractCoordinator.this) {
+                            joinFuture = null;
+                            state = MemberState.UNJOINED;
+                        }
+                    }
+                });
+            }
+
+            RequestFuture<ByteBuffer> future = joinFuture;
             client.poll(future);
 
             if (future.failed()) {
@@ -278,63 +345,6 @@ public abstract class AbstractCoordinator implements Closeable {
         }
     }
 
-    private class HeartbeatTask implements DelayedTask {
-
-        private boolean requestInFlight = false;
-
-        public void reset() {
-            // start or restart the heartbeat task to be executed at the next chance
-            long now = time.milliseconds();
-            heartbeat.resetSessionTimeout(now);
-            client.unschedule(this);
-
-            if (!requestInFlight)
-                client.schedule(this, now);
-        }
-
-        @Override
-        public void run(final long now) {
-            if (generation < 0 || needRejoin() || coordinatorUnknown()) {
-                // no need to send the heartbeat we're not using auto-assignment or if we are
-                // awaiting a rebalance
-                return;
-            }
-
-            if (heartbeat.sessionTimeoutExpired(now)) {
-                // we haven't received a successful heartbeat in one session interval
-                // so mark the coordinator dead
-                coordinatorDead();
-                return;
-            }
-
-            if (!heartbeat.shouldHeartbeat(now)) {
-                // we don't need to heartbeat now, so reschedule for when we do
-                client.schedule(this, now + heartbeat.timeToNextHeartbeat(now));
-            } else {
-                heartbeat.sentHeartbeat(now);
-                requestInFlight = true;
-
-                RequestFuture<Void> future = sendHeartbeatRequest();
-                future.addListener(new RequestFutureListener<Void>() {
-                    @Override
-                    public void onSuccess(Void value) {
-                        requestInFlight = false;
-                        long now = time.milliseconds();
-                        heartbeat.receiveHeartbeat(now);
-                        long nextHeartbeatTime = now + heartbeat.timeToNextHeartbeat(now);
-                        client.schedule(HeartbeatTask.this, nextHeartbeatTime);
-                    }
-
-                    @Override
-                    public void onFailure(RuntimeException e) {
-                        requestInFlight = false;
-                        client.schedule(HeartbeatTask.this, time.milliseconds() + retryBackoffMs);
-                    }
-                });
-            }
-        }
-    }
-
     /**
      * Join the group and return the assignment for the next generation. This function handles both
      * JoinGroup and SyncGroup, delegating to {@link #performAssignment(String, String, Map)} if
@@ -350,7 +360,8 @@ public abstract class AbstractCoordinator implements Closeable {
         JoinGroupRequest request = new JoinGroupRequest(
                 groupId,
                 this.sessionTimeoutMs,
-                this.memberId,
+                this.rebalanceTimeoutMs,
+                this.generation.memberId,
                 protocolType(),
                 metadata());
 
@@ -359,7 +370,6 @@ public abstract class AbstractCoordinator implements Closeable {
                 .compose(new JoinGroupResponseHandler());
     }
 
-
     private class JoinGroupResponseHandler extends CoordinatorResponseHandler<JoinGroupResponse, ByteBuffer> {
 
         @Override
@@ -372,24 +382,32 @@ public abstract class AbstractCoordinator implements Closeable {
             Errors error = Errors.forCode(joinResponse.errorCode());
             if (error == Errors.NONE) {
                 log.debug("Received successful join group response for group {}: {}", groupId, joinResponse.toStruct());
-                AbstractCoordinator.this.memberId = joinResponse.memberId();
-                AbstractCoordinator.this.generation = joinResponse.generationId();
-                AbstractCoordinator.this.rejoinNeeded = false;
-                AbstractCoordinator.this.protocol = joinResponse.groupProtocol();
                 sensors.joinLatency.record(response.requestLatencyMs());
-                if (joinResponse.isLeader()) {
-                    onJoinLeader(joinResponse).chain(future);
-                } else {
-                    onJoinFollower().chain(future);
+
+                synchronized (AbstractCoordinator.this) {
+                    if (state != MemberState.REBALANCING) {
+                        // if the consumer was woken up before a rebalance completes, we may have already left
+                        // the group. In this case, we do not want to continue with the sync group.
+                        future.raise(new UnjoinedGroupException());
+                    } else {
+                        AbstractCoordinator.this.generation = new Generation(joinResponse.generationId(),
+                                joinResponse.memberId(), joinResponse.groupProtocol());
+                        AbstractCoordinator.this.rejoinNeeded = false;
+                        if (joinResponse.isLeader()) {
+                            onJoinLeader(joinResponse).chain(future);
+                        } else {
+                            onJoinFollower().chain(future);
+                        }
+                    }
                 }
             } else if (error == Errors.GROUP_LOAD_IN_PROGRESS) {
                 log.debug("Attempt to join group {} rejected since coordinator {} is loading the group.", groupId,
-                        coordinator);
+                        coordinator());
                 // backoff and retry
                 future.raise(error);
             } else if (error == Errors.UNKNOWN_MEMBER_ID) {
                 // reset the member id and retry immediately
-                AbstractCoordinator.this.memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID;
+                resetGeneration();
                 log.debug("Attempt to join group {} failed due to unknown member id.", groupId);
                 future.raise(Errors.UNKNOWN_MEMBER_ID);
             } else if (error == Errors.GROUP_COORDINATOR_NOT_AVAILABLE
@@ -415,8 +433,8 @@ public abstract class AbstractCoordinator implements Closeable {
 
     private RequestFuture<ByteBuffer> onJoinFollower() {
         // send follower's sync group with an empty assignment
-        SyncGroupRequest request = new SyncGroupRequest(groupId, generation,
-                memberId, Collections.<String, ByteBuffer>emptyMap());
+        SyncGroupRequest request = new SyncGroupRequest(groupId, generation.generationId,
+                generation.memberId, Collections.<String, ByteBuffer>emptyMap());
         log.debug("Sending follower SyncGroup for group {} to coordinator {}: {}", groupId, this.coordinator, request);
         return sendSyncGroupRequest(request);
     }
@@ -427,7 +445,7 @@ public abstract class AbstractCoordinator implements Closeable {
             Map<String, ByteBuffer> groupAssignment = performAssignment(joinResponse.leaderId(), joinResponse.groupProtocol(),
                     joinResponse.members());
 
-            SyncGroupRequest request = new SyncGroupRequest(groupId, generation, memberId, groupAssignment);
+            SyncGroupRequest request = new SyncGroupRequest(groupId, generation.generationId, generation.memberId, groupAssignment);
             log.debug("Sending leader SyncGroup for group {} to coordinator {}: {}", groupId, this.coordinator, request);
             return sendSyncGroupRequest(request);
         } catch (RuntimeException e) {
@@ -454,11 +472,11 @@ public abstract class AbstractCoordinator implements Closeable {
                            RequestFuture<ByteBuffer> future) {
             Errors error = Errors.forCode(syncResponse.errorCode());
             if (error == Errors.NONE) {
-                log.info("Successfully joined group {} with generation {}", groupId, generation);
                 sensors.syncLatency.record(response.requestLatencyMs());
                 future.complete(syncResponse.memberAssignment());
             } else {
-                AbstractCoordinator.this.rejoinNeeded = true;
+                requestRejoin();
+
                 if (error == Errors.GROUP_AUTHORIZATION_FAILED) {
                     future.raise(new GroupAuthorizationException(groupId));
                 } else if (error == Errors.REBALANCE_IN_PROGRESS) {
@@ -467,7 +485,7 @@ public abstract class AbstractCoordinator implements Closeable {
                 } else if (error == Errors.UNKNOWN_MEMBER_ID
                         || error == Errors.ILLEGAL_GENERATION) {
                     log.debug("SyncGroup for group {} failed due to {}", groupId, error);
-                    AbstractCoordinator.this.memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID;
+                    resetGeneration();
                     future.raise(error);
                 } else if (error == Errors.GROUP_COORDINATOR_NOT_AVAILABLE
                         || error == Errors.NOT_COORDINATOR_FOR_GROUP) {
@@ -499,43 +517,36 @@ public abstract class AbstractCoordinator implements Closeable {
             log.debug("Sending coordinator request for group {} to broker {}", groupId, node);
             GroupCoordinatorRequest metadataRequest = new GroupCoordinatorRequest(this.groupId);
             return client.send(node, ApiKeys.GROUP_COORDINATOR, metadataRequest)
-                    .compose(new RequestFutureAdapter<ClientResponse, Void>() {
-                        @Override
-                        public void onSuccess(ClientResponse response, RequestFuture<Void> future) {
-                            handleGroupMetadataResponse(response, future);
-                        }
-                    });
+                    .compose(new GroupCoordinatorResponseHandler());
         }
     }
 
-    private void handleGroupMetadataResponse(ClientResponse resp, RequestFuture<Void> future) {
-        log.debug("Received group coordinator response {}", resp);
+    private class GroupCoordinatorResponseHandler extends RequestFutureAdapter<ClientResponse, Void> {
+
+        @Override
+        public void onSuccess(ClientResponse resp, RequestFuture<Void> future) {
+            log.debug("Received group coordinator response {}", resp);
 
-        if (!coordinatorUnknown()) {
-            // We already found the coordinator, so ignore the request
-            future.complete(null);
-        } else {
             GroupCoordinatorResponse groupCoordinatorResponse = new GroupCoordinatorResponse(resp.responseBody());
             // use MAX_VALUE - node.id as the coordinator id to mimic separate connections
             // for the coordinator in the underlying network client layer
             // TODO: this needs to be better handled in KAFKA-1935
             Errors error = Errors.forCode(groupCoordinatorResponse.errorCode());
             if (error == Errors.NONE) {
-                this.coordinator = new Node(Integer.MAX_VALUE - groupCoordinatorResponse.node().id(),
-                        groupCoordinatorResponse.node().host(),
-                        groupCoordinatorResponse.node().port());
-
-                log.info("Discovered coordinator {} for group {}.", coordinator, groupId);
-
-                client.tryConnect(coordinator);
-
-                // start sending heartbeats only if we have a valid generation
-                if (generation > 0)
-                    heartbeatTask.reset();
+                synchronized (AbstractCoordinator.this) {
+                    AbstractCoordinator.this.coordinator = new Node(
+                            Integer.MAX_VALUE - groupCoordinatorResponse.node().id(),
+                            groupCoordinatorResponse.node().host(),
+                            groupCoordinatorResponse.node().port());
+                    log.info("Discovered coordinator {} for group {}.", coordinator, groupId);
+                    client.tryConnect(coordinator);
+                    heartbeat.resetTimeouts(time.milliseconds());
+                }
                 future.complete(null);
             } else if (error == Errors.GROUP_AUTHORIZATION_FAILED) {
                 future.raise(new GroupAuthorizationException(groupId));
             } else {
+                log.debug("Group coordinator lookup for group {} failed: {}", groupId, error.message());
                 future.raise(error);
             }
         }
@@ -546,21 +557,25 @@ public abstract class AbstractCoordinator implements Closeable {
      * @return true if the coordinator is unknown
      */
     public boolean coordinatorUnknown() {
-        if (coordinator == null)
-            return true;
+        return coordinator() == null;
+    }
 
-        if (client.connectionFailed(coordinator)) {
+    /**
+     * Get the current coordinator
+     * @return the current coordinator or null if it is unknown
+     */
+    protected synchronized Node coordinator() {
+        if (coordinator != null && client.connectionFailed(coordinator)) {
             coordinatorDead();
-            return true;
+            return null;
         }
-
-        return false;
+        return this.coordinator;
     }
 
     /**
      * Mark the current coordinator as dead.
      */
-    protected void coordinatorDead() {
+    protected synchronized void coordinatorDead() {
         if (this.coordinator != null) {
             log.info("Marking the coordinator {} dead for group {}", this.coordinator, groupId);
             client.failUnsentRequests(this.coordinator, GroupCoordinatorNotAvailableException.INSTANCE);
@@ -569,50 +584,56 @@ public abstract class AbstractCoordinator implements Closeable {
     }
 
     /**
+     * Get the current generation state if the group is stable.
+     * @return the current generation or null if the group is unjoined/rebalancing
+     */
+    protected synchronized Generation generation() {
+        if (this.state != MemberState.STABLE)
+            return null;
+        return generation;
+    }
+
+    /**
+     * Reset the generation and memberId because we have fallen out of the group.
+     */
+    protected synchronized void resetGeneration() {
+        this.generation = Generation.NO_GENERATION;
+        this.rejoinNeeded = true;
+        this.state = MemberState.UNJOINED;
+    }
+
+    protected synchronized void requestRejoin() {
+        this.rejoinNeeded = true;
+    }
+
+    /**
      * Close the coordinator, waiting if needed to send LeaveGroup.
      */
     @Override
-    public void close() {
-        // we do not need to re-enable wakeups since we are closing already
-        client.disableWakeups();
+    public synchronized void close() {
+        if (heartbeatThread != null)
+            heartbeatThread.close();
         maybeLeaveGroup();
     }
 
     /**
      * Leave the current group and reset local generation/memberId.
      */
-    public void maybeLeaveGroup() {
-        client.unschedule(heartbeatTask);
-        if (!coordinatorUnknown() && generation > 0) {
+    public synchronized void maybeLeaveGroup() {
+        if (!coordinatorUnknown() && state != MemberState.UNJOINED && generation != Generation.NO_GENERATION) {
             // this is a minimal effort attempt to leave the group. we do not
             // attempt any resending if the request fails or times out.
-            sendLeaveGroupRequest();
+            LeaveGroupRequest request = new LeaveGroupRequest(groupId, generation.memberId);
+            client.send(coordinator, ApiKeys.LEAVE_GROUP, request)
+                    .compose(new LeaveGroupResponseHandler());
+            client.pollNoWakeup();
         }
 
-        this.generation = OffsetCommitRequest.DEFAULT_GENERATION_ID;
-        this.memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID;
-        rejoinNeeded = true;
-    }
-
-    private void sendLeaveGroupRequest() {
-        LeaveGroupRequest request = new LeaveGroupRequest(groupId, memberId);
-        RequestFuture<Void> future = client.send(coordinator, ApiKeys.LEAVE_GROUP, request)
-                .compose(new LeaveGroupResponseHandler());
-
-        future.addListener(new RequestFutureListener<Void>() {
-            @Override
-            public void onSuccess(Void value) {}
-
-            @Override
-            public void onFailure(RuntimeException e) {
-                log.debug("LeaveGroup request for group {} failed with error", groupId, e);
-            }
-        });
-
-        client.poll(future, 0);
+        resetGeneration();
     }
 
     private class LeaveGroupResponseHandler extends CoordinatorResponseHandler<LeaveGroupResponse, Void> {
+
         @Override
         public LeaveGroupResponse parse(ClientResponse response) {
             return new LeaveGroupResponse(response.responseBody());
@@ -620,25 +641,26 @@ public abstract class AbstractCoordinator implements Closeable {
 
         @Override
         public void handle(LeaveGroupResponse leaveResponse, RequestFuture<Void> future) {
-            // process the response
-            short errorCode = leaveResponse.errorCode();
-            if (errorCode == Errors.NONE.code())
+            Errors error = Errors.forCode(leaveResponse.errorCode());
+            if (error == Errors.NONE) {
+                log.debug("LeaveGroup request for group {} returned successfully", groupId);
                 future.complete(null);
-            else
-                future.raise(Errors.forCode(errorCode));
+            } else {
+                log.debug("LeaveGroup request for group {} failed with error: {}", groupId, error.message());
+                future.raise(error);
+            }
         }
     }
 
-    /**
-     * Send a heartbeat request now (visible only for testing).
-     */
-    public RequestFuture<Void> sendHeartbeatRequest() {
-        HeartbeatRequest req = new HeartbeatRequest(this.groupId, this.generation, this.memberId);
+    // visible for testing
+    synchronized RequestFuture<Void> sendHeartbeatRequest() {
+        HeartbeatRequest req = new HeartbeatRequest(this.groupId, this.generation.generationId, this.generation.memberId);
         return client.send(coordinator, ApiKeys.HEARTBEAT, req)
-                .compose(new HeartbeatCompletionHandler());
+                .compose(new HeartbeatResponseHandler());
     }
 
-    private class HeartbeatCompletionHandler extends CoordinatorResponseHandler<HeartbeatResponse, Void> {
+    private class HeartbeatResponseHandler extends CoordinatorResponseHandler<HeartbeatResponse, Void> {
+
         @Override
         public HeartbeatResponse parse(ClientResponse response) {
             return new HeartbeatResponse(response.responseBody());
@@ -654,21 +676,20 @@ public abstract class AbstractCoordinator implements Closeable {
             } else if (error == Errors.GROUP_COORDINATOR_NOT_AVAILABLE
                     || error == Errors.NOT_COORDINATOR_FOR_GROUP) {
                 log.debug("Attempt to heart beat failed for group {} since coordinator {} is either not started or not valid.",
-                        groupId, coordinator);
+                        groupId, coordinator());
                 coordinatorDead();
                 future.raise(error);
             } else if (error == Errors.REBALANCE_IN_PROGRESS) {
                 log.debug("Attempt to heart beat failed for group {} since it is rebalancing.", groupId);
-                AbstractCoordinator.this.rejoinNeeded = true;
+                requestRejoin();
                 future.raise(Errors.REBALANCE_IN_PROGRESS);
             } else if (error == Errors.ILLEGAL_GENERATION) {
                 log.debug("Attempt to heart beat failed for group {} since generation id is not legal.", groupId);
-                AbstractCoordinator.this.rejoinNeeded = true;
+                resetGeneration();
                 future.raise(Errors.ILLEGAL_GENERATION);
             } else if (error == Errors.UNKNOWN_MEMBER_ID) {
                 log.debug("Attempt to heart beat failed for group {} since member id is not valid.", groupId);
-                memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID;
-                AbstractCoordinator.this.rejoinNeeded = true;
+                resetGeneration();
                 future.raise(Errors.UNKNOWN_MEMBER_ID);
             } else if (error == Errors.GROUP_AUTHORIZATION_FAILED) {
                 future.raise(new GroupAuthorizationException(groupId));
@@ -678,8 +699,7 @@ public abstract class AbstractCoordinator implements Closeable {
         }
     }
 
-    protected abstract class CoordinatorResponseHandler<R, T>
-            extends RequestFutureAdapter<ClientResponse, T> {
+    protected abstract class CoordinatorResponseHandler<R, T> extends RequestFutureAdapter<ClientResponse, T> {
         protected ClientResponse response;
 
         public abstract R parse(ClientResponse response);
@@ -758,9 +778,149 @@ public abstract class AbstractCoordinator implements Closeable {
                 };
             metrics.addMetric(metrics.metricName("last-heartbeat-seconds-ago",
                 this.metricGrpName,
-                "The number of seconds since the last controller heartbeat"),
+                "The number of seconds since the last controller heartbeat was sent"),
                 lastHeartbeat);
         }
     }
 
+    private class HeartbeatThread extends Thread {
+        private boolean enabled = false;
+        private boolean closed = false;
+        private AtomicReference<RuntimeException> failed = new AtomicReference<>(null);
+
+        public void enable() {
+            synchronized (AbstractCoordinator.this) {
+                this.enabled = true;
+                heartbeat.resetTimeouts(time.milliseconds());
+                AbstractCoordinator.this.notify();
+            }
+        }
+
+        public void disable() {
+            synchronized (AbstractCoordinator.this) {
+                this.enabled = false;
+            }
+        }
+
+        public void close() {
+            synchronized (AbstractCoordinator.this) {
+                this.closed = true;
+                AbstractCoordinator.this.notify();
+            }
+        }
+
+        private boolean hasFailed() {
+            return failed.get() != null;
+        }
+
+        private RuntimeException failureCause() {
+            return failed.get();
+        }
+
+        @Override
+        public void run() {
+            try {
+                RequestFuture findCoordinatorFuture = null;
+
+                while (true) {
+                    synchronized (AbstractCoordinator.this) {
+                        if (closed)
+                            return;
+
+                        if (!enabled) {
+                            AbstractCoordinator.this.wait();
+                            continue;
+                        }
+
+                        if (state != MemberState.STABLE) {
+                            // the group is not stable (perhaps because we left the group or because the coordinator
+                            // kicked us out), so disable heartbeats and wait for the main thread to rejoin.
+                            disable();
+                            continue;
+                        }
+
+                        client.pollNoWakeup();
+                        long now = time.milliseconds();
+
+                        if (coordinatorUnknown()) {
+                            if (findCoordinatorFuture == null || findCoordinatorFuture.isDone())
+                                findCoordinatorFuture = lookupCoordinator();
+                            else
+                                AbstractCoordinator.this.wait(retryBackoffMs);
+                        } else if (heartbeat.sessionTimeoutExpired(now)) {
+                            // the session timeout has expired without seeing a successful heartbeat, so we should
+                            // probably make sure the coordinator is still healthy.
+                            coordinatorDead();
+                        } else if (heartbeat.pollTimeoutExpired(now)) {
+                            // the poll timeout has expired, which means that the foreground thread has stalled
+                            // in between calls to poll(), so we explicitly leave the group.
+                            maybeLeaveGroup();
+                        } else if (!heartbeat.shouldHeartbeat(now)) {
+                            // poll again after waiting for the retry backoff in case the heartbeat failed or the
+                            // coordinator disconnected
+                            AbstractCoordinator.this.wait(retryBackoffMs);
+                        } else {
+                            heartbeat.sentHeartbeat(now);
+
+                            sendHeartbeatRequest().addListener(new RequestFutureListener<Void>() {
+                                @Override
+                                public void onSuccess(Void value) {
+                                    synchronized (AbstractCoordinator.this) {
+                                        heartbeat.receiveHeartbeat(time.milliseconds());
+                                    }
+                                }
+
+                                @Override
+                                public void onFailure(RuntimeException e) {
+                                    synchronized (AbstractCoordinator.this) {
+                                        if (e instanceof RebalanceInProgressException) {
+                                            // it is valid to continue heartbeating while the group is rebalancing. This
+                                            // ensures that the coordinator keeps the member in the group for as long
+                                            // as the duration of the rebalance timeout. If we stop sending heartbeats,
+                                            // however, then the session timeout may expire before we can rejoin.
+                                            heartbeat.receiveHeartbeat(time.milliseconds());
+                                        } else {
+                                            heartbeat.failHeartbeat();
+
+                                            // wake up the thread if it's sleeping to reschedule the heartbeat
+                                            AbstractCoordinator.this.notify();
+                                        }
+                                    }
+                                }
+                            });
+                        }
+                    }
+                }
+            } catch (InterruptedException e) {
+                log.error("Unexpected interrupt received in heartbeat thread for group {}", groupId, e);
+                this.failed.set(new RuntimeException(e));
+            } catch (RuntimeException e) {
+                log.error("Heartbeat thread for group {} failed due to unexpected error" , groupId, e);
+                this.failed.set(e);
+            }
+        }
+
+    }
+
+    protected static class Generation {
+        public static final Generation NO_GENERATION = new Generation(
+                OffsetCommitRequest.DEFAULT_GENERATION_ID,
+                JoinGroupRequest.UNKNOWN_MEMBER_ID,
+                null);
+
+        public final int generationId;
+        public final String memberId;
+        public final String protocol;
+
+        public Generation(int generationId, String memberId, String protocol) {
+            this.generationId = generationId;
+            this.memberId = memberId;
+            this.protocol = protocol;
+        }
+    }
+
+    private static class UnjoinedGroupException extends RetriableException {
+
+    }
+
 }
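
The new HeartbeatThread above coordinates with the foreground consumer thread entirely through the AbstractCoordinator monitor: it parks in wait() while heartbeats are disabled and is woken by notify() when heartbeats are re-enabled or the coordinator is closed. The following is a minimal, self-contained sketch of that enable/disable/close handshake. It is not the Kafka class itself; all names here (HeartbeatLoopSketch, shutdown, intervalMs) are invented for illustration, and the actual heartbeat request is replaced by a print statement.

public class HeartbeatLoopSketch extends Thread {
    private final Object lock = new Object();   // stands in for AbstractCoordinator.this
    private boolean enabled = false;
    private boolean closed = false;
    private final long intervalMs;

    public HeartbeatLoopSketch(long intervalMs) {
        this.intervalMs = intervalMs;
        setDaemon(true);
    }

    public void enable() {
        synchronized (lock) {
            enabled = true;
            lock.notify();            // wake the loop if it is parked in wait()
        }
    }

    public void disable() {
        synchronized (lock) {
            enabled = false;          // the loop parks again on its next iteration
        }
    }

    public void shutdown() {
        synchronized (lock) {
            closed = true;
            lock.notify();
        }
    }

    @Override
    public void run() {
        try {
            while (true) {
                synchronized (lock) {
                    if (closed)
                        return;
                    if (!enabled) {
                        lock.wait();              // parked until enable() or shutdown()
                        continue;
                    }
                    // placeholder for "send heartbeat request and handle the response"
                    System.out.println("heartbeat at " + System.currentTimeMillis());
                    lock.wait(intervalMs);        // sleep until the next heartbeat is due
                }
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }

    public static void main(String[] args) throws Exception {
        HeartbeatLoopSketch loop = new HeartbeatLoopSketch(1000L);
        loop.start();
        loop.enable();                // heartbeats run while the member is STABLE
        Thread.sleep(3000L);
        loop.disable();               // e.g. a rebalance started or the member left the group
        Thread.sleep(2000L);
        loop.shutdown();
        loop.join();
    }
}

Using one shared monitor keeps the state transitions (enabled, closed, heartbeat deadlines) atomic with respect to the foreground thread, which is why the real thread synchronizes on AbstractCoordinator.this rather than on a private lock.
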

http://git-wip-us.apache.org/repos/asf/kafka/blob/40b1dd3f/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinator.java
----------------------------------------------------------------------
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinator.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinator.java
index 81a40f1..5fee45a 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinator.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinator.java
@@ -18,12 +18,13 @@ import org.apache.kafka.clients.consumer.CommitFailedException;
 import org.apache.kafka.clients.consumer.ConsumerRebalanceListener;
 import org.apache.kafka.clients.consumer.OffsetAndMetadata;
 import org.apache.kafka.clients.consumer.OffsetCommitCallback;
+import org.apache.kafka.clients.consumer.RetriableCommitFailedException;
 import org.apache.kafka.clients.consumer.internals.PartitionAssignor.Assignment;
 import org.apache.kafka.clients.consumer.internals.PartitionAssignor.Subscription;
 import org.apache.kafka.common.Cluster;
 import org.apache.kafka.common.KafkaException;
+import org.apache.kafka.common.Node;
 import org.apache.kafka.common.TopicPartition;
-import org.apache.kafka.clients.consumer.RetriableCommitFailedException;
 import org.apache.kafka.common.errors.GroupAuthorizationException;
 import org.apache.kafka.common.errors.RetriableException;
 import org.apache.kafka.common.errors.TopicAuthorizationException;
@@ -54,6 +55,7 @@ import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.concurrent.ConcurrentLinkedQueue;
 
 /**
  * This class manages the coordination process with the consumer coordinator.
@@ -68,18 +70,24 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
     private final SubscriptionState subscriptions;
     private final OffsetCommitCallback defaultOffsetCommitCallback;
     private final boolean autoCommitEnabled;
-    private final AutoCommitTask autoCommitTask;
+    private final int autoCommitIntervalMs;
     private final ConsumerInterceptors<?, ?> interceptors;
     private final boolean excludeInternalTopics;
 
+    // this collection must be thread-safe because it is modified from the response handler
+    // of offset commit requests, which may be invoked from the heartbeat thread
+    private final ConcurrentLinkedQueue<OffsetCommitCompletion> completedOffsetCommits;
+
     private MetadataSnapshot metadataSnapshot;
     private MetadataSnapshot assignmentSnapshot;
+    private long nextAutoCommitDeadline;
 
     /**
      * Initialize the coordination manager.
      */
     public ConsumerCoordinator(ConsumerNetworkClient client,
                                String groupId,
+                               int rebalanceTimeoutMs,
                                int sessionTimeoutMs,
                                int heartbeatIntervalMs,
                                List<PartitionAssignor> assignors,
@@ -91,11 +99,12 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
                                long retryBackoffMs,
                                OffsetCommitCallback defaultOffsetCommitCallback,
                                boolean autoCommitEnabled,
-                               long autoCommitIntervalMs,
+                               int autoCommitIntervalMs,
                                ConsumerInterceptors<?, ?> interceptors,
                                boolean excludeInternalTopics) {
         super(client,
                 groupId,
+                rebalanceTimeoutMs,
                 sessionTimeoutMs,
                 heartbeatIntervalMs,
                 metrics,
@@ -103,26 +112,22 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
                 time,
                 retryBackoffMs);
         this.metadata = metadata;
-
-        this.metadata.requestUpdate();
         this.metadataSnapshot = new MetadataSnapshot(subscriptions, metadata.fetch());
         this.subscriptions = subscriptions;
         this.defaultOffsetCommitCallback = defaultOffsetCommitCallback;
         this.autoCommitEnabled = autoCommitEnabled;
+        this.autoCommitIntervalMs = autoCommitIntervalMs;
         this.assignors = assignors;
-
-        addMetadataListener();
-
-        if (autoCommitEnabled) {
-            this.autoCommitTask = new AutoCommitTask(autoCommitIntervalMs);
-            this.autoCommitTask.reschedule();
-        } else {
-            this.autoCommitTask = null;
-        }
-
+        this.completedOffsetCommits = new ConcurrentLinkedQueue<>();
         this.sensors = new ConsumerCoordinatorMetrics(metrics, metricGrpPrefix);
         this.interceptors = interceptors;
         this.excludeInternalTopics = excludeInternalTopics;
+
+        if (autoCommitEnabled)
+            this.nextAutoCommitDeadline = time.milliseconds() + autoCommitIntervalMs;
+
+        this.metadata.requestUpdate();
+        addMetadataListener();
     }
 
     @Override
@@ -210,8 +215,7 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
         assignor.onAssignment(assignment);
 
         // reschedule the auto commit starting from now
-        if (autoCommitEnabled)
-            autoCommitTask.reschedule();
+        this.nextAutoCommitDeadline = time.milliseconds() + autoCommitIntervalMs;
 
         // execute the user's callback after rebalance
         ConsumerRebalanceListener listener = subscriptions.listener();
@@ -227,6 +231,54 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
         }
     }
 
+    /**
+     * Poll for coordinator events. This ensures that the coordinator is known and that the consumer
+     * has joined the group (if it is using group management). This also handles periodic offset commits
+     * if they are enabled.
+     *
+     * @param now current time in milliseconds
+     */
+    public void poll(long now) {
+        invokeCompletedOffsetCommitCallbacks();
+
+        if (subscriptions.partitionsAutoAssigned() && coordinatorUnknown()) {
+            ensureCoordinatorReady();
+            now = time.milliseconds();
+        }
+
+        if (subscriptions.partitionsAutoAssigned() && needRejoin()) {
+            // due to a race condition between the initial metadata fetch and the initial rebalance, we need to ensure that
+            // the metadata is fresh before joining initially, and then request the metadata update. If metadata update arrives
+            // while the rebalance is still pending (for example, when the join group is still inflight), then we will lose
+            // track of the fact that we need to rebalance again to reflect the change to the topic subscription. Without
+            // ensuring that the metadata is fresh, any metadata update that changes the topic subscriptions and arrives with a
+            // rebalance in progress will essentially be ignored. See KAFKA-3949 for the complete description of the problem.
+            if (subscriptions.hasPatternSubscription())
+                client.ensureFreshMetadata();
+
+            ensureActiveGroup();
+            now = time.milliseconds();
+        }
+
+        pollHeartbeat(now);
+        maybeAutoCommitOffsetsAsync(now);
+    }
+
+    /**
+     * Return the time to the next needed invocation of {@link #poll(long)}.
+     * @param now current time in milliseconds
+     * @return the maximum time in milliseconds the caller should wait before the next invocation of poll()
+     */
+    public long timeToNextPoll(long now) {
+        if (!autoCommitEnabled)
+            return timeToNextHeartbeat(now);
+
+        if (now > nextAutoCommitDeadline)
+            return 0;
+
+        return Math.min(nextAutoCommitDeadline - now, timeToNextHeartbeat(now));
+    }
+
     @Override
     protected Map<String, ByteBuffer> performAssignment(String leaderId,
                                                         String assignmentStrategy,
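
The poll()/timeToNextPoll() pair added in the hunk above replaces the old DelayedTask-based scheduling: coordinator discovery, group join, heartbeat bookkeeping, and auto-commit are all driven from the caller's thread on each poll, and timeToNextPoll() tells the caller how long it may block before the next invocation. Below is a rough, runnable sketch of that shape only; all names are invented (CoordinatorPollSketch and so on) and the network round trips are replaced by stand-in assignments.

public final class CoordinatorPollSketch {
    private boolean coordinatorKnown = false;
    private boolean needRejoin = true;
    private long nextAutoCommitDeadline;
    private long nextHeartbeatDeadline;
    private final long autoCommitIntervalMs;
    private final long heartbeatIntervalMs;

    public CoordinatorPollSketch(long autoCommitIntervalMs, long heartbeatIntervalMs, long now) {
        this.autoCommitIntervalMs = autoCommitIntervalMs;
        this.heartbeatIntervalMs = heartbeatIntervalMs;
        this.nextAutoCommitDeadline = now + autoCommitIntervalMs;
        this.nextHeartbeatDeadline = now + heartbeatIntervalMs;
    }

    // Everything runs on the caller's thread, once per consumer poll.
    public void poll(long now) {
        if (!coordinatorKnown) {
            coordinatorKnown = true;                              // stand-in for a FindCoordinator round trip
            now = System.currentTimeMillis();
        }
        if (needRejoin) {
            needRejoin = false;                                   // stand-in for JoinGroup/SyncGroup
            now = System.currentTimeMillis();
        }
        if (now >= nextHeartbeatDeadline) {
            nextHeartbeatDeadline = now + heartbeatIntervalMs;    // record poll progress / liveness
        }
        if (now >= nextAutoCommitDeadline) {
            nextAutoCommitDeadline = now + autoCommitIntervalMs;
            System.out.println("async auto-commit triggered at " + now);
        }
    }

    // Upper bound on how long the caller may block before poll() must run again.
    public long timeToNextPoll(long now) {
        long toCommit = Math.max(0L, nextAutoCommitDeadline - now);
        long toHeartbeat = Math.max(0L, nextHeartbeatDeadline - now);
        return Math.min(toCommit, toHeartbeat);
    }

    public static void main(String[] args) throws InterruptedException {
        CoordinatorPollSketch c = new CoordinatorPollSketch(500L, 200L, System.currentTimeMillis());
        for (int i = 0; i < 10; i++) {
            c.poll(System.currentTimeMillis());
            Thread.sleep(Math.min(100L, c.timeToNextPoll(System.currentTimeMillis())));
        }
    }
}
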
@@ -292,7 +344,7 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
     }
 
     @Override
-    public boolean needRejoin() {
+    protected boolean needRejoin() {
         return subscriptions.partitionsAutoAssigned() &&
                 (super.needRejoin() || subscriptions.partitionAssignmentNeeded());
     }
@@ -336,24 +388,6 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
         }
     }
 
-    /**
-     * Ensure that we have a valid partition assignment from the coordinator.
-     */
-    public void ensurePartitionAssignment() {
-        if (subscriptions.partitionsAutoAssigned()) {
-            // Due to a race condition between the initial metadata fetch and the initial rebalance, we need to ensure that
-            // the metadata is fresh before joining initially, and then request the metadata update. If metadata update arrives
-            // while the rebalance is still pending (for example, when the join group is still inflight), then we will lose
-            // track of the fact that we need to rebalance again to reflect the change to the topic subscription. Without
-            // ensuring that the metadata is fresh, any metadata update that changes the topic subscriptions and arrives with a
-            // rebalance in progress will essentially be ignored. See KAFKA-3949 for the complete description of the problem.
-            if (subscriptions.hasPatternSubscription())
-                client.ensureFreshMetadata();
-
-            ensureActiveGroup();
-        }
-    }
-
     @Override
     public void close() {
         // we do not need to re-enable wakeups since we are closing already
@@ -365,8 +399,20 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
         }
     }
 
+    // visible for testing
+    void invokeCompletedOffsetCommitCallbacks() {
+        while (true) {
+            OffsetCommitCompletion completion = completedOffsetCommits.poll();
+            if (completion == null)
+                break;
+            completion.invoke();
+        }
+    }
+
 
     public void commitOffsetsAsync(final Map<TopicPartition, OffsetAndMetadata> offsets, final OffsetCommitCallback callback) {
+        invokeCompletedOffsetCommitCallbacks();
+
         if (!coordinatorUnknown()) {
             doCommitOffsetsAsync(offsets, callback);
         } else {
@@ -384,7 +430,7 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
 
                 @Override
                 public void onFailure(RuntimeException e) {
-                    callback.onComplete(offsets, new RetriableCommitFailedException(e));
+                    completedOffsetCommits.add(new OffsetCommitCompletion(callback, offsets, new RetriableCommitFailedException(e)));
                 }
             });
         }
@@ -404,16 +450,18 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
             public void onSuccess(Void value) {
                 if (interceptors != null)
                     interceptors.onCommit(offsets);
-                cb.onComplete(offsets, null);
+
+                completedOffsetCommits.add(new OffsetCommitCompletion(cb, offsets, null));
             }
 
             @Override
             public void onFailure(RuntimeException e) {
-                if (e instanceof RetriableException) {
-                    cb.onComplete(offsets, new RetriableCommitFailedException(e));
-                } else {
-                    cb.onComplete(offsets, e);
-                }
+                Exception commitException = e;
+
+                if (e instanceof RetriableException)
+                    commitException = new RetriableCommitFailedException(e);
+
+                completedOffsetCommits.add(new OffsetCommitCompletion(cb, offsets, commitException));
             }
         });
     }
@@ -427,6 +475,8 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
      * @throws CommitFailedException if an unrecoverable error occurs before the commit can be completed
      */
     public void commitOffsetsSync(Map<TopicPartition, OffsetAndMetadata> offsets) {
+        invokeCompletedOffsetCommitCallbacks();
+
         if (offsets.isEmpty())
             return;
 
@@ -449,46 +499,25 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
         }
     }
 
-    private class AutoCommitTask implements DelayedTask {
-        private final long interval;
-
-        public AutoCommitTask(long interval) {
-            this.interval = interval;
-        }
-
-        private void reschedule() {
-            client.schedule(this, time.milliseconds() + interval);
-        }
-
-        private void reschedule(long at) {
-            client.schedule(this, at);
-        }
-
-        public void run(final long now) {
+    private void maybeAutoCommitOffsetsAsync(long now) {
+        if (autoCommitEnabled) {
             if (coordinatorUnknown()) {
-                log.debug("Cannot auto-commit offsets for group {} since the coordinator is unknown", groupId);
-                reschedule(now + retryBackoffMs);
-                return;
-            }
-
-            if (needRejoin()) {
-                // skip the commit when we're rejoining since we'll commit offsets synchronously
-                // before the revocation callback is invoked
-                reschedule(now + interval);
-                return;
-            }
-
-            commitOffsetsAsync(subscriptions.allConsumed(), new OffsetCommitCallback() {
-                @Override
-                public void onComplete(Map<TopicPartition, OffsetAndMetadata> offsets, Exception exception) {
-                    if (exception == null) {
-                        reschedule(now + interval);
-                    } else {
-                        log.warn("Auto offset commit failed for group {}: {}", groupId, exception.getMessage());
-                        reschedule(now + interval);
+                this.nextAutoCommitDeadline = now + retryBackoffMs;
+            } else if (now >= nextAutoCommitDeadline) {
+                this.nextAutoCommitDeadline = now + autoCommitIntervalMs;
+                commitOffsetsAsync(subscriptions.allConsumed(), new OffsetCommitCallback() {
+                    @Override
+                    public void onComplete(Map<TopicPartition, OffsetAndMetadata> offsets, Exception exception) {
+                        if (exception != null) {
+                            log.warn("Auto offset commit failed for group {}: {}", groupId, exception.getMessage());
+                            if (exception instanceof RetriableException)
+                                nextAutoCommitDeadline = Math.min(time.milliseconds() + retryBackoffMs, nextAutoCommitDeadline);
+                        } else {
+                            log.debug("Completed autocommit of offsets {} for group {}", offsets, groupId);
+                        }
                     }
-                }
-            });
+                });
+            }
         }
     }
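
The retry handling in the inlined auto-commit above comes down to simple deadline arithmetic: a retriable failure pulls the next attempt forward to now + retryBackoffMs, but never pushes it later than the deadline already scheduled, and an unknown coordinator defers the whole attempt by the backoff. A small sketch of that arithmetic, with values made up for illustration:

public final class AutoCommitDeadlineSketch {
    public static void main(String[] args) {
        long now = 10_000L;
        long autoCommitIntervalMs = 5_000L;
        long retryBackoffMs = 100L;

        // normal schedule once a commit has been kicked off
        long nextAutoCommitDeadline = now + autoCommitIntervalMs;                        // 15_000

        // a retriable failure arrives 50 ms later: retry soon, but never later
        // than what is already scheduled
        long failureTime = now + 50;
        nextAutoCommitDeadline = Math.min(failureTime + retryBackoffMs, nextAutoCommitDeadline);  // 10_150

        // while the coordinator is unknown, the attempt is simply deferred by the backoff
        long coordinatorUnknownDeadline = failureTime + retryBackoffMs;                  // 10_150

        System.out.println("next auto-commit attempt at " + nextAutoCommitDeadline);
        System.out.println("deadline while coordinator unknown: " + coordinatorUnknownDeadline);
    }
}
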
 
@@ -506,6 +535,14 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
         }
     }
 
+    public static class DefaultOffsetCommitCallback implements OffsetCommitCallback {
+        @Override
+        public void onComplete(Map<TopicPartition, OffsetAndMetadata> offsets, Exception exception) {
+            if (exception != null)
+                log.error("Offset commit failed.", exception);
+        }
+    }
+
     /**
      * Commit offsets for the specified list of topics and partitions. This is a non-blocking call
      * which returns a request future that can be polled in the case of a synchronous commit or ignored in the
@@ -515,12 +552,13 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
      * @return A request future whose value indicates whether the commit was successful or not
      */
     private RequestFuture<Void> sendOffsetCommitRequest(final Map<TopicPartition, OffsetAndMetadata> offsets) {
-        if (coordinatorUnknown())
-            return RequestFuture.coordinatorNotAvailable();
-
         if (offsets.isEmpty())
             return RequestFuture.voidSuccess();
 
+        Node coordinator = coordinator();
+        if (coordinator == null)
+            return RequestFuture.coordinatorNotAvailable();
+
         // create the offset commit request
         Map<TopicPartition, OffsetCommitRequest.PartitionData> offsetData = new HashMap<>(offsets.size());
         for (Map.Entry<TopicPartition, OffsetAndMetadata> entry : offsets.entrySet()) {
@@ -529,9 +567,21 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
                     offsetAndMetadata.offset(), offsetAndMetadata.metadata()));
         }
 
-        OffsetCommitRequest req = new OffsetCommitRequest(this.groupId,
-                this.generation,
-                this.memberId,
+        final Generation generation;
+        if (subscriptions.partitionsAutoAssigned())
+            generation = generation();
+        else
+            generation = Generation.NO_GENERATION;
+
+        // if the generation is null, we are not part of an active group (and we expect to be).
+        // the only thing we can do is fail the commit and let the user rejoin the group in poll()
+        if (generation == null)
+            return RequestFuture.failure(new CommitFailedException());
+
+        OffsetCommitRequest req = new OffsetCommitRequest(
+                this.groupId,
+                generation.generationId,
+                generation.memberId,
                 OffsetCommitRequest.DEFAULT_RETENTION_TIME,
                 offsetData);
 
@@ -541,14 +591,6 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
                 .compose(new OffsetCommitResponseHandler(offsets));
     }
 
-    public static class DefaultOffsetCommitCallback implements OffsetCommitCallback {
-        @Override
-        public void onComplete(Map<TopicPartition, OffsetAndMetadata> offsets, Exception exception) {
-            if (exception != null)
-                log.error("Offset commit failed.", exception);
-        }
-    }
-
     private class OffsetCommitResponseHandler extends CoordinatorResponseHandler<OffsetCommitResponse, Void> {
 
         private final Map<TopicPartition, OffsetAndMetadata> offsets;
@@ -607,13 +649,8 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
                         || error == Errors.REBALANCE_IN_PROGRESS) {
                     // need to re-join group
                     log.debug("Offset commit for group {} failed: {}", groupId, error.message());
-                    subscriptions.needReassignment();
-                    future.raise(new CommitFailedException("Commit cannot be completed since the group has already " +
-                            "rebalanced and assigned the partitions to another member. This means that the time " +
-                            "between subsequent calls to poll() was longer than the configured session.timeout.ms, " +
-                            "which typically implies that the poll loop is spending too much time message processing. " +
-                            "You can address this either by increasing the session timeout or by reducing the maximum " +
-                            "size of batches returned in poll() with max.poll.records."));
+                    resetGeneration();
+                    future.raise(new CommitFailedException());
                     return;
                 } else {
                     log.error("Group {} failed to commit partition {} at offset {}: {}", groupId, tp, offset, error.message());
@@ -639,7 +676,8 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
      * @return A request future containing the committed offsets.
      */
     private RequestFuture<Map<TopicPartition, OffsetAndMetadata>> sendOffsetFetchRequest(Set<TopicPartition> partitions) {
-        if (coordinatorUnknown())
+        Node coordinator = coordinator();
+        if (coordinator == null)
             return RequestFuture.coordinatorNotAvailable();
 
         log.debug("Group {} fetching committed offsets for partitions: {}", groupId, partitions);
@@ -675,11 +713,6 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
                         // re-discover the coordinator and retry
                         coordinatorDead();
                         future.raise(error);
-                    } else if (error == Errors.UNKNOWN_MEMBER_ID
-                            || error == Errors.ILLEGAL_GENERATION) {
-                        // need to re-join group
-                        subscriptions.needReassignment();
-                        future.raise(error);
                     } else {
                         future.raise(new KafkaException("Unexpected error in fetch offset response: " + error.message()));
                     }
@@ -753,5 +786,20 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
         }
     }
 
+    private static class OffsetCommitCompletion {
+        private final OffsetCommitCallback callback;
+        private final Map<TopicPartition, OffsetAndMetadata> offsets;
+        private final Exception exception;
+
+        public OffsetCommitCompletion(OffsetCommitCallback callback, Map<TopicPartition, OffsetAndMetadata> offsets, Exception exception) {
+            this.callback = callback;
+            this.offsets = offsets;
+            this.exception = exception;
+        }
+
+        public void invoke() {
+            callback.onComplete(offsets, exception);
+        }
+    }
 
 }
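
Finally, OffsetCommitCompletion and the completedOffsetCommits queue exist because offset-commit responses may now be handled on the heartbeat thread, while user callbacks must still run on the caller's thread: completions are enqueued wherever the response arrives and drained at the start of the next user-facing call (commitOffsetsAsync, commitOffsetsSync, poll). The following is a minimal, runnable sketch of that pattern with invented names and a simplified callback type; it is not the Kafka code.

import java.util.concurrent.ConcurrentLinkedQueue;

public final class CompletionQueueSketch {

    interface Callback {
        void onComplete(String result, Exception error);
    }

    // callback plus its arguments, captured on whatever thread completed the request
    private static final class Completion {
        private final Callback callback;
        private final String result;
        private final Exception error;

        Completion(Callback callback, String result, Exception error) {
            this.callback = callback;
            this.result = result;
            this.error = error;
        }

        void invoke() {
            callback.onComplete(result, error);
        }
    }

    private final ConcurrentLinkedQueue<Completion> completed = new ConcurrentLinkedQueue<>();

    // may run on a background (e.g. heartbeat) thread
    public void onResponse(Callback callback, String result, Exception error) {
        completed.add(new Completion(callback, result, error));
    }

    // runs on the caller's thread, at the start of every user-facing call
    public void invokeCompletedCallbacks() {
        Completion c;
        while ((c = completed.poll()) != null)
            c.invoke();
    }

    public static void main(String[] args) {
        CompletionQueueSketch sketch = new CompletionQueueSketch();
        sketch.onResponse(new Callback() {
            public void onComplete(String result, Exception error) {
                System.out.println("commit result: " + result);
            }
        }, "offsets committed", null);
        sketch.invokeCompletedCallbacks();   // the user callback runs here, never on the background thread
    }
}
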