You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by va...@apache.org on 2018/05/22 20:43:51 UTC

spark git commit: [SPARK-19185][DSTREAMS] Avoid concurrent use of cached consumers in CachedKafkaConsumer

Repository: spark
Updated Branches:
  refs/heads/master bc6ea614a -> 79e06faa4


[SPARK-19185][DSTREAMS] Avoid concurrent use of cached consumers in CachedKafkaConsumer

## What changes were proposed in this pull request?

`CachedKafkaConsumer` in the project streaming-kafka-0-10 is designed to maintain a pool of KafkaConsumers that can be reused. However, it was built on the assumption that only one thread at a time will try to read a given Kafka TopicPartition. This assumption does not always hold, and when it is violated the result can be a ConcurrentModificationException.

Here is a better way to design this. The consumer pool should be smart enough to avoid concurrent use of a cached consumer. If there is another request for the same TopicPartition as a currently in-use consumer, the pool should automatically return a fresh consumer.

- There are effectively two kinds of consumer that may be generated
  - Cached consumer - this should be returned to the pool at task end
  - Non-cached consumer - this should be closed at task end
- A trait called `KafkaDataConsumer` is introduced to hide this difference from the users of the consumer so that the client code does not have to reason about whether to stop and release. They simply call `val consumer = KafkaDataConsumer.acquire` and then `consumer.release`.
- If there is a request for a consumer that is in use, then a new consumer is generated.
- If a consumer is requested by a task reattempt, then the already existing cached consumer will be invalidated and a new consumer is generated. This could fix potential issues if the cause of the reattempt is a malfunctioning consumer.
- In addition, I renamed the `CachedKafkaConsumer` class to `KafkaDataConsumer` because the old name is a misnomer, given that what it returns may or may not be cached.

## How was this patch tested?

A new stress test that verifies it is safe to concurrently get consumers for the same TopicPartition from the consumer pool.

Author: Gabor Somogyi <ga...@gmail.com>

Closes #20997 from gaborgsomogyi/SPARK-19185.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/79e06faa
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/79e06faa
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/79e06faa

Branch: refs/heads/master
Commit: 79e06faa4ef6596c9e2d4be09c74b935064021bb
Parents: bc6ea61
Author: Gabor Somogyi <ga...@gmail.com>
Authored: Tue May 22 13:43:45 2018 -0700
Committer: Marcelo Vanzin <va...@cloudera.com>
Committed: Tue May 22 13:43:45 2018 -0700

----------------------------------------------------------------------
 .../spark/sql/kafka010/KafkaDataConsumer.scala  |   2 +-
 .../kafka010/CachedKafkaConsumer.scala          | 226 ------------
 .../streaming/kafka010/KafkaDataConsumer.scala  | 359 +++++++++++++++++++
 .../spark/streaming/kafka010/KafkaRDD.scala     |  20 +-
 .../kafka010/KafkaDataConsumerSuite.scala       | 131 +++++++
 5 files changed, 496 insertions(+), 242 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/79e06faa/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala
----------------------------------------------------------------------
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala
index 48508d0..941f0ab 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala
@@ -395,7 +395,7 @@ private[kafka010] object KafkaDataConsumer extends Logging {
         // likely running on a beefy machine that can handle a large number of simultaneously
         // active consumers.
 
-        if (entry.getValue.inUse == false && this.size > capacity) {
+        if (!entry.getValue.inUse && this.size > capacity) {
           logWarning(
             s"KafkaConsumer cache hitting max capacity of $capacity, " +
               s"removing consumer for ${entry.getKey}")

http://git-wip-us.apache.org/repos/asf/spark/blob/79e06faa/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/CachedKafkaConsumer.scala
----------------------------------------------------------------------
diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/CachedKafkaConsumer.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/CachedKafkaConsumer.scala
deleted file mode 100644
index aeb8c1d..0000000
--- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/CachedKafkaConsumer.scala
+++ /dev/null
@@ -1,226 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.streaming.kafka010
-
-import java.{ util => ju }
-
-import org.apache.kafka.clients.consumer.{ ConsumerConfig, ConsumerRecord, KafkaConsumer }
-import org.apache.kafka.common.{ KafkaException, TopicPartition }
-
-import org.apache.spark.internal.Logging
-
-/**
- * Consumer of single topicpartition, intended for cached reuse.
- * Underlying consumer is not threadsafe, so neither is this,
- * but processing the same topicpartition and group id in multiple threads is usually bad anyway.
- */
-private[kafka010]
-class CachedKafkaConsumer[K, V] private(
-  val groupId: String,
-  val topic: String,
-  val partition: Int,
-  val kafkaParams: ju.Map[String, Object]) extends Logging {
-
-  require(groupId == kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG),
-    "groupId used for cache key must match the groupId in kafkaParams")
-
-  val topicPartition = new TopicPartition(topic, partition)
-
-  protected val consumer = {
-    val c = new KafkaConsumer[K, V](kafkaParams)
-    val tps = new ju.ArrayList[TopicPartition]()
-    tps.add(topicPartition)
-    c.assign(tps)
-    c
-  }
-
-  // TODO if the buffer was kept around as a random-access structure,
-  // could possibly optimize re-calculating of an RDD in the same batch
-  protected var buffer = ju.Collections.emptyListIterator[ConsumerRecord[K, V]]()
-  protected var nextOffset = -2L
-
-  def close(): Unit = consumer.close()
-
-  /**
-   * Get the record for the given offset, waiting up to timeout ms if IO is necessary.
-   * Sequential forward access will use buffers, but random access will be horribly inefficient.
-   */
-  def get(offset: Long, timeout: Long): ConsumerRecord[K, V] = {
-    logDebug(s"Get $groupId $topic $partition nextOffset $nextOffset requested $offset")
-    if (offset != nextOffset) {
-      logInfo(s"Initial fetch for $groupId $topic $partition $offset")
-      seek(offset)
-      poll(timeout)
-    }
-
-    if (!buffer.hasNext()) { poll(timeout) }
-    require(buffer.hasNext(),
-      s"Failed to get records for $groupId $topic $partition $offset after polling for $timeout")
-    var record = buffer.next()
-
-    if (record.offset != offset) {
-      logInfo(s"Buffer miss for $groupId $topic $partition $offset")
-      seek(offset)
-      poll(timeout)
-      require(buffer.hasNext(),
-        s"Failed to get records for $groupId $topic $partition $offset after polling for $timeout")
-      record = buffer.next()
-      require(record.offset == offset,
-        s"Got wrong record for $groupId $topic $partition even after seeking to offset $offset " +
-          s"got offset ${record.offset} instead. If this is a compacted topic, consider enabling " +
-          "spark.streaming.kafka.allowNonConsecutiveOffsets"
-      )
-    }
-
-    nextOffset = offset + 1
-    record
-  }
-
-  /**
-   * Start a batch on a compacted topic
-   */
-  def compactedStart(offset: Long, timeout: Long): Unit = {
-    logDebug(s"compacted start $groupId $topic $partition starting $offset")
-    // This seek may not be necessary, but it's hard to tell due to gaps in compacted topics
-    if (offset != nextOffset) {
-      logInfo(s"Initial fetch for compacted $groupId $topic $partition $offset")
-      seek(offset)
-      poll(timeout)
-    }
-  }
-
-  /**
-   * Get the next record in the batch from a compacted topic.
-   * Assumes compactedStart has been called first, and ignores gaps.
-   */
-  def compactedNext(timeout: Long): ConsumerRecord[K, V] = {
-    if (!buffer.hasNext()) {
-      poll(timeout)
-    }
-    require(buffer.hasNext(),
-      s"Failed to get records for compacted $groupId $topic $partition after polling for $timeout")
-    val record = buffer.next()
-    nextOffset = record.offset + 1
-    record
-  }
-
-  /**
-   * Rewind to previous record in the batch from a compacted topic.
-   * @throws NoSuchElementException if no previous element
-   */
-  def compactedPrevious(): ConsumerRecord[K, V] = {
-    buffer.previous()
-  }
-
-  private def seek(offset: Long): Unit = {
-    logDebug(s"Seeking to $topicPartition $offset")
-    consumer.seek(topicPartition, offset)
-  }
-
-  private def poll(timeout: Long): Unit = {
-    val p = consumer.poll(timeout)
-    val r = p.records(topicPartition)
-    logDebug(s"Polled ${p.partitions()}  ${r.size}")
-    buffer = r.listIterator
-  }
-
-}
-
-private[kafka010]
-object CachedKafkaConsumer extends Logging {
-
-  private case class CacheKey(groupId: String, topic: String, partition: Int)
-
-  // Don't want to depend on guava, don't want a cleanup thread, use a simple LinkedHashMap
-  private var cache: ju.LinkedHashMap[CacheKey, CachedKafkaConsumer[_, _]] = null
-
-  /** Must be called before get, once per JVM, to configure the cache. Further calls are ignored */
-  def init(
-      initialCapacity: Int,
-      maxCapacity: Int,
-      loadFactor: Float): Unit = CachedKafkaConsumer.synchronized {
-    if (null == cache) {
-      logInfo(s"Initializing cache $initialCapacity $maxCapacity $loadFactor")
-      cache = new ju.LinkedHashMap[CacheKey, CachedKafkaConsumer[_, _]](
-        initialCapacity, loadFactor, true) {
-        override def removeEldestEntry(
-          entry: ju.Map.Entry[CacheKey, CachedKafkaConsumer[_, _]]): Boolean = {
-          if (this.size > maxCapacity) {
-            try {
-              entry.getValue.consumer.close()
-            } catch {
-              case x: KafkaException =>
-                logError("Error closing oldest Kafka consumer", x)
-            }
-            true
-          } else {
-            false
-          }
-        }
-      }
-    }
-  }
-
-  /**
-   * Get a cached consumer for groupId, assigned to topic and partition.
-   * If matching consumer doesn't already exist, will be created using kafkaParams.
-   */
-  def get[K, V](
-      groupId: String,
-      topic: String,
-      partition: Int,
-      kafkaParams: ju.Map[String, Object]): CachedKafkaConsumer[K, V] =
-    CachedKafkaConsumer.synchronized {
-      val k = CacheKey(groupId, topic, partition)
-      val v = cache.get(k)
-      if (null == v) {
-        logInfo(s"Cache miss for $k")
-        logDebug(cache.keySet.toString)
-        val c = new CachedKafkaConsumer[K, V](groupId, topic, partition, kafkaParams)
-        cache.put(k, c)
-        c
-      } else {
-        // any given topicpartition should have a consistent key and value type
-        v.asInstanceOf[CachedKafkaConsumer[K, V]]
-      }
-    }
-
-  /**
-   * Get a fresh new instance, unassociated with the global cache.
-   * Caller is responsible for closing
-   */
-  def getUncached[K, V](
-      groupId: String,
-      topic: String,
-      partition: Int,
-      kafkaParams: ju.Map[String, Object]): CachedKafkaConsumer[K, V] =
-    new CachedKafkaConsumer[K, V](groupId, topic, partition, kafkaParams)
-
-  /** remove consumer for given groupId, topic, and partition, if it exists */
-  def remove(groupId: String, topic: String, partition: Int): Unit = {
-    val k = CacheKey(groupId, topic, partition)
-    logInfo(s"Removing $k from cache")
-    val v = CachedKafkaConsumer.synchronized {
-      cache.remove(k)
-    }
-    if (null != v) {
-      v.close()
-      logInfo(s"Removed $k from cache")
-    }
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/79e06faa/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumer.scala
----------------------------------------------------------------------
diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumer.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumer.scala
new file mode 100644
index 0000000..68c5fe9
--- /dev/null
+++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumer.scala
@@ -0,0 +1,359 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.kafka010
+
+import java.{util => ju}
+
+import org.apache.kafka.clients.consumer.{ConsumerConfig, ConsumerRecord, KafkaConsumer}
+import org.apache.kafka.common.{KafkaException, TopicPartition}
+
+import org.apache.spark.TaskContext
+import org.apache.spark.internal.Logging
+
+private[kafka010] sealed trait KafkaDataConsumer[K, V] {
+  /**
+   * Get the record for the given offset if available.
+   *
+   * @param offset         the offset to fetch.
+   * @param pollTimeoutMs  timeout in milliseconds to poll data from Kafka.
+   */
+  def get(offset: Long, pollTimeoutMs: Long): ConsumerRecord[K, V] = {
+    internalConsumer.get(offset, pollTimeoutMs)
+  }
+
+  /**
+   * Start a batch on a compacted topic
+   *
+   * @param offset         the offset to fetch.
+   * @param pollTimeoutMs  timeout in milliseconds to poll data from Kafka.
+   */
+  def compactedStart(offset: Long, pollTimeoutMs: Long): Unit = {
+    internalConsumer.compactedStart(offset, pollTimeoutMs)
+  }
+
+  /**
+   * Get the next record in the batch from a compacted topic.
+   * Assumes compactedStart has been called first, and ignores gaps.
+   *
+   * @param pollTimeoutMs  timeout in milliseconds to poll data from Kafka.
+   */
+  def compactedNext(pollTimeoutMs: Long): ConsumerRecord[K, V] = {
+    internalConsumer.compactedNext(pollTimeoutMs)
+  }
+
+  /**
+   * Rewind to previous record in the batch from a compacted topic.
+   *
+   * @throws NoSuchElementException if no previous element
+   */
+  def compactedPrevious(): ConsumerRecord[K, V] = {
+    internalConsumer.compactedPrevious()
+  }
+
+  /**
+   * Release this consumer from being further used. Depending on its implementation,
+   * this consumer will be either finalized, or reset for reuse later.
+   */
+  def release(): Unit
+
+  /** Reference to the internal implementation that this wrapper delegates to */
+  def internalConsumer: InternalKafkaConsumer[K, V]
+}
+
+
+/**
+ * A wrapper around Kafka's KafkaConsumer.
+ * This is not for direct use outside this file.
+ */
+private[kafka010] class InternalKafkaConsumer[K, V](
+    val topicPartition: TopicPartition,
+    val kafkaParams: ju.Map[String, Object]) extends Logging {
+
+  private[kafka010] val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG)
+    .asInstanceOf[String]
+
+  private val consumer = createConsumer
+
+  /** indicates whether this consumer is in use or not */
+  var inUse = true
+
+  /** indicate whether this consumer is going to be stopped in the next release */
+  var markedForClose = false
+
+  // TODO if the buffer was kept around as a random-access structure,
+  // could possibly optimize re-calculating of an RDD in the same batch
+  @volatile private var buffer = ju.Collections.emptyListIterator[ConsumerRecord[K, V]]()
+  @volatile private var nextOffset = InternalKafkaConsumer.UNKNOWN_OFFSET
+
+  override def toString: String = {
+    "InternalKafkaConsumer(" +
+      s"hash=${Integer.toHexString(hashCode)}, " +
+      s"groupId=$groupId, " +
+      s"topicPartition=$topicPartition)"
+  }
+
+  /** Create a KafkaConsumer to fetch records for `topicPartition` */
+  private def createConsumer: KafkaConsumer[K, V] = {
+    val c = new KafkaConsumer[K, V](kafkaParams)
+    val topics = ju.Arrays.asList(topicPartition)
+    c.assign(topics)
+    c
+  }
+
+  def close(): Unit = consumer.close()
+
+  /**
+   * Get the record for the given offset, waiting up to timeout ms if IO is necessary.
+   * Sequential forward access will use buffers, but random access will be horribly inefficient.
+   */
+  def get(offset: Long, timeout: Long): ConsumerRecord[K, V] = {
+    logDebug(s"Get $groupId $topicPartition nextOffset $nextOffset requested $offset")
+    if (offset != nextOffset) {
+      logInfo(s"Initial fetch for $groupId $topicPartition $offset")
+      seek(offset)
+      poll(timeout)
+    }
+
+    if (!buffer.hasNext()) {
+      poll(timeout)
+    }
+    require(buffer.hasNext(),
+      s"Failed to get records for $groupId $topicPartition $offset after polling for $timeout")
+    var record = buffer.next()
+
+    if (record.offset != offset) {
+      logInfo(s"Buffer miss for $groupId $topicPartition $offset")
+      seek(offset)
+      poll(timeout)
+      require(buffer.hasNext(),
+        s"Failed to get records for $groupId $topicPartition $offset after polling for $timeout")
+      record = buffer.next()
+      require(record.offset == offset,
+        s"Got wrong record for $groupId $topicPartition even after seeking to offset $offset " +
+          s"got offset ${record.offset} instead. If this is a compacted topic, consider enabling " +
+          "spark.streaming.kafka.allowNonConsecutiveOffsets"
+      )
+    }
+
+    nextOffset = offset + 1
+    record
+  }
+
+  /**
+   * Start a batch on a compacted topic
+   */
+  def compactedStart(offset: Long, pollTimeoutMs: Long): Unit = {
+    logDebug(s"compacted start $groupId $topicPartition starting $offset")
+    // This seek may not be necessary, but it's hard to tell due to gaps in compacted topics
+    if (offset != nextOffset) {
+      logInfo(s"Initial fetch for compacted $groupId $topicPartition $offset")
+      seek(offset)
+      poll(pollTimeoutMs)
+    }
+  }
+
+  /**
+   * Get the next record in the batch from a compacted topic.
+   * Assumes compactedStart has been called first, and ignores gaps.
+   */
+  def compactedNext(pollTimeoutMs: Long): ConsumerRecord[K, V] = {
+    if (!buffer.hasNext()) {
+      poll(pollTimeoutMs)
+    }
+    require(buffer.hasNext(),
+      s"Failed to get records for compacted $groupId $topicPartition " +
+        s"after polling for $pollTimeoutMs")
+    val record = buffer.next()
+    nextOffset = record.offset + 1
+    record
+  }
+
+  /**
+   * Rewind to previous record in the batch from a compacted topic.
+   * @throws NoSuchElementException if no previous element
+   */
+  def compactedPrevious(): ConsumerRecord[K, V] = {
+    buffer.previous()
+  }
+
+  private def seek(offset: Long): Unit = {
+    logDebug(s"Seeking to $topicPartition $offset")
+    consumer.seek(topicPartition, offset)
+  }
+
+  private def poll(timeout: Long): Unit = {
+    val p = consumer.poll(timeout)
+    val r = p.records(topicPartition)
+    logDebug(s"Polled ${p.partitions()}  ${r.size}")
+    buffer = r.listIterator
+  }
+
+}
+
+private[kafka010] case class CacheKey(groupId: String, topicPartition: TopicPartition)
+
+private[kafka010] object KafkaDataConsumer extends Logging {
+
+  private case class CachedKafkaDataConsumer[K, V](internalConsumer: InternalKafkaConsumer[K, V])
+    extends KafkaDataConsumer[K, V] {
+    assert(internalConsumer.inUse)
+    override def release(): Unit = KafkaDataConsumer.release(internalConsumer)
+  }
+
+  private case class NonCachedKafkaDataConsumer[K, V](internalConsumer: InternalKafkaConsumer[K, V])
+    extends KafkaDataConsumer[K, V] {
+    override def release(): Unit = internalConsumer.close()
+  }
+
+  // Don't want to depend on guava, don't want a cleanup thread, use a simple LinkedHashMap
+  private[kafka010] var cache: ju.Map[CacheKey, InternalKafkaConsumer[_, _]] = null
+
+  /**
+   * Must be called before acquire, once per JVM, to configure the cache.
+   * Further calls are ignored.
+   */
+  def init(
+      initialCapacity: Int,
+      maxCapacity: Int,
+      loadFactor: Float): Unit = synchronized {
+    if (null == cache) {
+      logInfo(s"Initializing cache $initialCapacity $maxCapacity $loadFactor")
+      cache = new ju.LinkedHashMap[CacheKey, InternalKafkaConsumer[_, _]](
+        initialCapacity, loadFactor, true) {
+        override def removeEldestEntry(
+            entry: ju.Map.Entry[CacheKey, InternalKafkaConsumer[_, _]]): Boolean = {
+
+          // Try to remove the least recently accessed entry if it's currently not in use.
+          //
+          // If it cannot be removed, then the cache will keep growing. In the worst case,
+          // the cache will grow to the max number of concurrent tasks that can run in the executor,
+          // (that is, number of tasks slots) after which it will never reduce. This is unlikely to
+          // be a serious problem because an executor with more than 64 (default) tasks slots is
+          // likely running on a beefy machine that can handle a large number of simultaneously
+          // active consumers.
+
+          if (!entry.getValue.inUse && this.size > maxCapacity) {
+            logWarning(
+              s"KafkaConsumer cache hitting max capacity of $maxCapacity, " +
+                s"removing consumer for ${entry.getKey}")
+            try {
+              entry.getValue.close()
+            } catch {
+              case x: KafkaException =>
+                logError("Error closing oldest Kafka consumer", x)
+            }
+            true
+          } else {
+            false
+          }
+        }
+      }
+    }
+  }
+
+  /**
+   * Get a cached consumer for groupId, assigned to topic and partition.
+   * If matching consumer doesn't already exist, will be created using kafkaParams.
+   * The returned consumer must be released explicitly using [[KafkaDataConsumer.release()]].
+   *
+   * Note: This method guarantees that the consumer returned is not currently in use by anyone
+   * else. Within this guarantee, this method will make a best effort attempt to re-use consumers by
+   * caching them and tracking when they are in use.
+   */
+  def acquire[K, V](
+      topicPartition: TopicPartition,
+      kafkaParams: ju.Map[String, Object],
+      context: TaskContext,
+      useCache: Boolean): KafkaDataConsumer[K, V] = synchronized {
+    val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String]
+    val key = CacheKey(groupId, topicPartition)
+    val existingInternalConsumer = cache.get(key)
+
+    lazy val newInternalConsumer = new InternalKafkaConsumer[K, V](topicPartition, kafkaParams)
+
+    if (context != null && context.attemptNumber >= 1) {
+      // If this is reattempt at running the task, then invalidate cached consumers if any and
+      // start with a new one. If prior attempt failures were cache related then this way old
+      // problematic consumers can be removed.
+      if (existingInternalConsumer != null) {
+        logDebug(s"Reattempt detected, invalidating cached consumer $existingInternalConsumer")
+        // Consumer exists in cache. If it's in use, mark it for closing later, or close it now.
+        if (existingInternalConsumer.inUse) {
+          existingInternalConsumer.markedForClose = true
+        } else {
+          // Evict first so that a failing close() cannot leave a closed consumer cached.
+          // Consumers marked for close are removed from the cache in the release function.
+          cache.remove(key)
+          existingInternalConsumer.close()
+        }
+      }
+
+      logDebug("Reattempt detected, new non-cached consumer will be allocated " +
+        s"$newInternalConsumer")
+      NonCachedKafkaDataConsumer(newInternalConsumer)
+    } else if (!useCache) {
+      // If consumer reuse turned off, then do not use it, return a new consumer
+      logDebug("Cache usage turned off, new non-cached consumer will be allocated " +
+        s"$newInternalConsumer")
+      NonCachedKafkaDataConsumer(newInternalConsumer)
+    } else if (existingInternalConsumer == null) {
+      // If consumer is not already cached, then put a new in the cache and return it
+      logDebug("No cached consumer, new cached consumer will be allocated " +
+        s"$newInternalConsumer")
+      cache.put(key, newInternalConsumer)
+      CachedKafkaDataConsumer(newInternalConsumer)
+    } else if (existingInternalConsumer.inUse) {
+      // If consumer is already cached but is currently in use, then return a new consumer
+      logDebug("Used cached consumer found, new non-cached consumer will be allocated " +
+        s"$newInternalConsumer")
+      NonCachedKafkaDataConsumer(newInternalConsumer)
+    } else {
+      // If consumer is already cached and is currently not in use, then return that consumer
+      logDebug(s"Not used cached consumer found, re-using it $existingInternalConsumer")
+      existingInternalConsumer.inUse = true
+      // Any given TopicPartition should have a consistent key and value type
+      CachedKafkaDataConsumer(existingInternalConsumer.asInstanceOf[InternalKafkaConsumer[K, V]])
+    }
+  }
+
+  private def release(internalConsumer: InternalKafkaConsumer[_, _]): Unit = synchronized {
+    // Hand a consumer back: keep it cached and idle, or close it if it must not be reused.
+    val key = CacheKey(internalConsumer.groupId, internalConsumer.topicPartition)
+    val cachedInternalConsumer = cache.get(key)
+    if (internalConsumer.eq(cachedInternalConsumer)) {
+      // The released consumer is the same object as the cached one.
+      if (internalConsumer.markedForClose) {
+        // Evict first so that a failing close() cannot leave a closed consumer cached.
+        cache.remove(key)
+        internalConsumer.close()
+      } else {
+        internalConsumer.inUse = false
+      }
+    } else {
+      // The released consumer is either not the same one as in the cache, or not in the cache
+      // at all. This may happen if the cache was invalidated while this consumer was being used.
+      // Just close this consumer.
+      internalConsumer.close()
+      logInfo(s"Released a supposedly cached consumer that was not found in the cache " +
+        s"$internalConsumer")
+    }
+  }
+}
+
+private[kafka010] object InternalKafkaConsumer {
+  private val UNKNOWN_OFFSET = -2L
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/79e06faa/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala
----------------------------------------------------------------------
diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala
index 07239ed..81abc98 100644
--- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala
+++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala
@@ -19,8 +19,6 @@ package org.apache.spark.streaming.kafka010
 
 import java.{ util => ju }
 
-import scala.collection.mutable.ArrayBuffer
-
 import org.apache.kafka.clients.consumer.{ ConsumerConfig, ConsumerRecord }
 import org.apache.kafka.common.TopicPartition
 
@@ -239,26 +237,18 @@ private class KafkaRDDIterator[K, V](
   cacheLoadFactor: Float
 ) extends Iterator[ConsumerRecord[K, V]] {
 
-  val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String]
-
   context.addTaskCompletionListener(_ => closeIfNeeded())
 
-  val consumer = if (useConsumerCache) {
-    CachedKafkaConsumer.init(cacheInitialCapacity, cacheMaxCapacity, cacheLoadFactor)
-    if (context.attemptNumber >= 1) {
-      // just in case the prior attempt failures were cache related
-      CachedKafkaConsumer.remove(groupId, part.topic, part.partition)
-    }
-    CachedKafkaConsumer.get[K, V](groupId, part.topic, part.partition, kafkaParams)
-  } else {
-    CachedKafkaConsumer.getUncached[K, V](groupId, part.topic, part.partition, kafkaParams)
+  val consumer = {
+    KafkaDataConsumer.init(cacheInitialCapacity, cacheMaxCapacity, cacheLoadFactor)
+    KafkaDataConsumer.acquire[K, V](part.topicPartition(), kafkaParams, context, useConsumerCache)
   }
 
   var requestOffset = part.fromOffset
 
   def closeIfNeeded(): Unit = {
-    if (!useConsumerCache && consumer != null) {
-      consumer.close()
+    if (consumer != null) {
+      consumer.release()
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/79e06faa/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumerSuite.scala
----------------------------------------------------------------------
diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumerSuite.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumerSuite.scala
new file mode 100644
index 0000000..d934c64
--- /dev/null
+++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumerSuite.scala
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.kafka010
+
+import java.util.concurrent.{Executors, TimeUnit}
+
+import scala.collection.JavaConverters._
+import scala.util.Random
+
+import org.apache.kafka.clients.consumer.ConsumerConfig._
+import org.apache.kafka.common.TopicPartition
+import org.apache.kafka.common.serialization.ByteArrayDeserializer
+import org.scalatest.BeforeAndAfterAll
+
+import org.apache.spark._
+
+class KafkaDataConsumerSuite extends SparkFunSuite with BeforeAndAfterAll {
+  private var testUtils: KafkaTestUtils = _
+  private val topic = "topic" + Random.nextInt()
+  private val topicPartition = new TopicPartition(topic, 0)
+  private val groupId = "groupId"
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    testUtils = new KafkaTestUtils
+    testUtils.setup()
+    KafkaDataConsumer.init(16, 64, 0.75f)
+  }
+
+  override def afterAll(): Unit = {
+    if (testUtils != null) {
+      testUtils.teardown()
+      testUtils = null
+    }
+    super.afterAll()
+  }
+
+  private def getKafkaParams() = Map[String, Object](
+    GROUP_ID_CONFIG -> groupId,
+    BOOTSTRAP_SERVERS_CONFIG -> testUtils.brokerAddress,
+    KEY_DESERIALIZER_CLASS_CONFIG -> classOf[ByteArrayDeserializer].getName,
+    VALUE_DESERIALIZER_CLASS_CONFIG -> classOf[ByteArrayDeserializer].getName,
+    AUTO_OFFSET_RESET_CONFIG -> "earliest",
+    ENABLE_AUTO_COMMIT_CONFIG -> "false"
+  ).asJava
+
+  test("KafkaDataConsumer reuse in case of same groupId and TopicPartition") {
+    KafkaDataConsumer.cache.clear()
+
+    val kafkaParams = getKafkaParams()
+
+    val consumer1 = KafkaDataConsumer.acquire[Array[Byte], Array[Byte]](
+      topicPartition, kafkaParams, null, true)
+    consumer1.release()
+
+    val consumer2 = KafkaDataConsumer.acquire[Array[Byte], Array[Byte]](
+      topicPartition, kafkaParams, null, true)
+    consumer2.release()
+
+    assert(KafkaDataConsumer.cache.size() == 1)
+    val key = new CacheKey(groupId, topicPartition)
+    val existingInternalConsumer = KafkaDataConsumer.cache.get(key)
+    assert(existingInternalConsumer.eq(consumer1.internalConsumer))
+    assert(existingInternalConsumer.eq(consumer2.internalConsumer))
+  }
+
+  test("concurrent use of KafkaDataConsumer") {
+    val data = (1 to 1000).map(_.toString)
+    testUtils.createTopic(topic)
+    testUtils.sendMessages(topic, data.toArray)
+
+    val kafkaParams = getKafkaParams()
+
+    val numThreads = 100
+    val numConsumerUsages = 500
+
+    @volatile var error: Throwable = null
+
+    def consume(i: Int): Unit = {
+      val useCache = Random.nextBoolean
+      val taskContext = if (Random.nextBoolean) {
+        new TaskContextImpl(0, 0, 0, 0, attemptNumber = Random.nextInt(2), null, null, null)
+      } else {
+        null
+      }
+      val consumer = KafkaDataConsumer.acquire[Array[Byte], Array[Byte]](
+        topicPartition, kafkaParams, taskContext, useCache)
+      try {
+        val rcvd = (0 until data.length).map { offset =>
+          val bytes = consumer.get(offset, 10000).value()
+          new String(bytes)
+        }
+        assert(rcvd == data)
+      } catch {
+        case e: Throwable =>
+          error = e
+          throw e
+      } finally {
+        consumer.release()
+      }
+    }
+
+    val threadPool = Executors.newFixedThreadPool(numThreads)
+    try {
+      val futures = (1 to numConsumerUsages).map { i =>
+        threadPool.submit(new Runnable {
+          override def run(): Unit = { consume(i) }
+        })
+      }
+      futures.foreach(_.get(1, TimeUnit.MINUTES))
+      assert(error == null)
+    } finally {
+      threadPool.shutdown()
+    }
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org