Posted to commits@spark.apache.org by td...@apache.org on 2018/03/03 02:14:17 UTC

spark git commit: [SPARK-23541][SS] Allow Kafka source to read data with greater parallelism than the number of topic-partitions

Repository: spark
Updated Branches:
  refs/heads/master dea381dfa -> 486f99eef


[SPARK-23541][SS] Allow Kafka source to read data with greater parallelism than the number of topic-partitions

## What changes were proposed in this pull request?

Currently, when the Kafka source reads from Kafka, it generates as many tasks as there are partitions in the topic(s) to be read. In some cases, it may be beneficial to read the data with greater parallelism, that is, with a larger number of partitions/tasks. That means offset ranges must be divided up into smaller ranges such that the number of records per partition ~= total records in batch / desired number of partitions. This would also balance out any data skew between topic-partitions.
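
As a rough illustration of that arithmetic (the numbers below are made up; the actual splitting logic lives in the new `KafkaOffsetRangeCalculator` added by this patch):

```scala
// Hypothetical batch: two topic-partitions with 4 and 20 new records, and a
// desired parallelism of 6 tasks.
val rangeSizes = Seq(4L, 20L)
val desiredPartitions = 6
val idealRangeSize = rangeSizes.sum.toDouble / desiredPartitions // 24 / 6 = 4.0 records per task
val splitsPerRange = rangeSizes.map(s => math.round(s / idealRangeSize).toInt)
// splitsPerRange == Seq(1, 5): the small range stays a single task, while the
// skewed range is split into 5 subranges of ~4 records each.
```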

In this patch, I have added a new option called `minPartitions`, which allows the user to specify the desired level of parallelism.
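
For example, a query could opt in like this (a minimal sketch; the broker address and topic name are placeholders):

```scala
val df = spark.readStream
  .format("kafka")
  .option("kafka.bootstrap.servers", "host1:9092") // placeholder brokers
  .option("subscribe", "topic1")                   // placeholder topic
  .option("minPartitions", "10")                   // request ~10 Spark tasks per batch
  .load()
```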

## How was this patch tested?
New tests in KafkaMicroBatchV2SourceSuite.

Author: Tathagata Das <ta...@gmail.com>

Closes #20698 from tdas/SPARK-23541.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/486f99ee
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/486f99ee
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/486f99ee

Branch: refs/heads/master
Commit: 486f99eefead4e664a30a861eca65cab8568e70b
Parents: dea381d
Author: Tathagata Das <ta...@gmail.com>
Authored: Fri Mar 2 18:14:13 2018 -0800
Committer: Tathagata Das <ta...@gmail.com>
Committed: Fri Mar 2 18:14:13 2018 -0800

----------------------------------------------------------------------
 .../sql/kafka010/KafkaMicroBatchReader.scala    | 109 +++++++-------
 .../kafka010/KafkaOffsetRangeCalculator.scala   | 105 +++++++++++++
 .../sql/kafka010/KafkaSourceProvider.scala      |   7 +
 .../org/apache/spark/sql/kafka010/package.scala |  24 +++
 .../kafka010/KafkaMicroBatchSourceSuite.scala   |  56 ++++++-
 .../KafkaOffsetRangeCalculatorSuite.scala       | 147 +++++++++++++++++++
 6 files changed, 388 insertions(+), 60 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/486f99ee/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala
----------------------------------------------------------------------
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala
index fb647ca..8a5f3a2 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala
@@ -24,7 +24,6 @@ import java.nio.charset.StandardCharsets
 import scala.collection.JavaConverters._
 
 import org.apache.commons.io.IOUtils
-import org.apache.kafka.common.TopicPartition
 
 import org.apache.spark.SparkEnv
 import org.apache.spark.internal.Logging
@@ -64,8 +63,6 @@ private[kafka010] class KafkaMicroBatchReader(
     failOnDataLoss: Boolean)
   extends MicroBatchReader with SupportsScanUnsafeRow with Logging {
 
-  type PartitionOffsetMap = Map[TopicPartition, Long]
-
   private var startPartitionOffsets: PartitionOffsetMap = _
   private var endPartitionOffsets: PartitionOffsetMap = _
 
@@ -76,6 +73,7 @@ private[kafka010] class KafkaMicroBatchReader(
   private val maxOffsetsPerTrigger =
     Option(options.get("maxOffsetsPerTrigger").orElse(null)).map(_.toLong)
 
+  private val rangeCalculator = KafkaOffsetRangeCalculator(options)
   /**
    * Lazily initialize `initialPartitionOffsets` to make sure that `KafkaConsumer.poll` is only
    * called in StreamExecutionThread. Otherwise, interrupting a thread while running
@@ -106,15 +104,15 @@ private[kafka010] class KafkaMicroBatchReader(
   override def createUnsafeRowReaderFactories(): ju.List[DataReaderFactory[UnsafeRow]] = {
     // Find the new partitions, and get their earliest offsets
     val newPartitions = endPartitionOffsets.keySet.diff(startPartitionOffsets.keySet)
-    val newPartitionOffsets = kafkaOffsetReader.fetchEarliestOffsets(newPartitions.toSeq)
-    if (newPartitionOffsets.keySet != newPartitions) {
+    val newPartitionInitialOffsets = kafkaOffsetReader.fetchEarliestOffsets(newPartitions.toSeq)
+    if (newPartitionInitialOffsets.keySet != newPartitions) {
       // We cannot get from offsets for some partitions. It means they got deleted.
-      val deletedPartitions = newPartitions.diff(newPartitionOffsets.keySet)
+      val deletedPartitions = newPartitions.diff(newPartitionInitialOffsets.keySet)
       reportDataLoss(
         s"Cannot find earliest offsets of ${deletedPartitions}. Some data may have been missed")
     }
-    logInfo(s"Partitions added: $newPartitionOffsets")
-    newPartitionOffsets.filter(_._2 != 0).foreach { case (p, o) =>
+    logInfo(s"Partitions added: $newPartitionInitialOffsets")
+    newPartitionInitialOffsets.filter(_._2 != 0).foreach { case (p, o) =>
       reportDataLoss(
         s"Added partition $p starts from $o instead of 0. Some data may have been missed")
     }
@@ -125,46 +123,28 @@ private[kafka010] class KafkaMicroBatchReader(
       reportDataLoss(s"$deletedPartitions are gone. Some data may have been missed")
     }
 
-    // Use the until partitions to calculate offset ranges to ignore partitions that have
+    // Use the end partitions to calculate offset ranges to ignore partitions that have
     // been deleted
     val topicPartitions = endPartitionOffsets.keySet.filter { tp =>
       // Ignore partitions that we don't know the from offsets.
-      newPartitionOffsets.contains(tp) || startPartitionOffsets.contains(tp)
+      newPartitionInitialOffsets.contains(tp) || startPartitionOffsets.contains(tp)
     }.toSeq
     logDebug("TopicPartitions: " + topicPartitions.mkString(", "))
 
-    val sortedExecutors = getSortedExecutorList()
-    val numExecutors = sortedExecutors.length
-    logDebug("Sorted executors: " + sortedExecutors.mkString(", "))
-
     // Calculate offset ranges
-    val factories = topicPartitions.flatMap { tp =>
-      val fromOffset = startPartitionOffsets.get(tp).getOrElse {
-        newPartitionOffsets.getOrElse(
-        tp, {
-          // This should not happen since newPartitionOffsets contains all partitions not in
-          // fromPartitionOffsets
-          throw new IllegalStateException(s"$tp doesn't have a from offset")
-        })
-      }
-      val untilOffset = endPartitionOffsets(tp)
-
-      if (untilOffset >= fromOffset) {
-        // This allows cached KafkaConsumers in the executors to be re-used to read the same
-        // partition in every batch.
-        val preferredLoc = if (numExecutors > 0) {
-          Some(sortedExecutors(Math.floorMod(tp.hashCode, numExecutors)))
-        } else None
-        val range = KafkaOffsetRange(tp, fromOffset, untilOffset)
-        Some(
-          new KafkaMicroBatchDataReaderFactory(
-            range, preferredLoc, executorKafkaParams, pollTimeoutMs, failOnDataLoss))
-      } else {
-        reportDataLoss(
-          s"Partition $tp's offset was changed from " +
-            s"$fromOffset to $untilOffset, some data may have been missed")
-        None
-      }
+    val offsetRanges = rangeCalculator.getRanges(
+      fromOffsets = startPartitionOffsets ++ newPartitionInitialOffsets,
+      untilOffsets = endPartitionOffsets,
+      executorLocations = getSortedExecutorList())
+
+    // Reuse Kafka consumers only when all the offset ranges have distinct TopicPartitions,
+    // that is, concurrent tasks will not read the same TopicPartitions.
+    val reuseKafkaConsumer = offsetRanges.map(_.topicPartition).toSet.size == offsetRanges.size
+
+    // Generate factories based on the offset ranges
+    val factories = offsetRanges.map { range =>
+      new KafkaMicroBatchDataReaderFactory(
+        range, executorKafkaParams, pollTimeoutMs, failOnDataLoss, reuseKafkaConsumer)
     }
     factories.map(_.asInstanceOf[DataReaderFactory[UnsafeRow]]).asJava
   }
@@ -320,28 +300,39 @@ private[kafka010] class KafkaMicroBatchReader(
 }
 
 /** A [[DataReaderFactory]] for reading Kafka data in a micro-batch streaming query. */
-private[kafka010] class KafkaMicroBatchDataReaderFactory(
-    range: KafkaOffsetRange,
-    preferredLoc: Option[String],
+private[kafka010] case class KafkaMicroBatchDataReaderFactory(
+    offsetRange: KafkaOffsetRange,
     executorKafkaParams: ju.Map[String, Object],
     pollTimeoutMs: Long,
-    failOnDataLoss: Boolean) extends DataReaderFactory[UnsafeRow] {
+    failOnDataLoss: Boolean,
+    reuseKafkaConsumer: Boolean) extends DataReaderFactory[UnsafeRow] {
 
-  override def preferredLocations(): Array[String] = preferredLoc.toArray
+  override def preferredLocations(): Array[String] = offsetRange.preferredLoc.toArray
 
   override def createDataReader(): DataReader[UnsafeRow] = new KafkaMicroBatchDataReader(
-    range, executorKafkaParams, pollTimeoutMs, failOnDataLoss)
+    offsetRange, executorKafkaParams, pollTimeoutMs, failOnDataLoss, reuseKafkaConsumer)
 }
 
 /** A [[DataReader]] for reading Kafka data in a micro-batch streaming query. */
-private[kafka010] class KafkaMicroBatchDataReader(
+private[kafka010] case class KafkaMicroBatchDataReader(
     offsetRange: KafkaOffsetRange,
     executorKafkaParams: ju.Map[String, Object],
     pollTimeoutMs: Long,
-    failOnDataLoss: Boolean) extends DataReader[UnsafeRow] with Logging {
+    failOnDataLoss: Boolean,
+    reuseKafkaConsumer: Boolean) extends DataReader[UnsafeRow] with Logging {
+
+  private val consumer = {
+    if (!reuseKafkaConsumer) {
+      // If we can't reuse CachedKafkaConsumers, create a new CachedKafkaConsumer. We
+      // use `assign` here, hence we don't need to worry about "group.id" conflicts.
+      CachedKafkaConsumer.createUncached(
+        offsetRange.topicPartition.topic, offsetRange.topicPartition.partition, executorKafkaParams)
+    } else {
+      CachedKafkaConsumer.getOrCreate(
+        offsetRange.topicPartition.topic, offsetRange.topicPartition.partition, executorKafkaParams)
+    }
+  }
 
-  private val consumer = CachedKafkaConsumer.getOrCreate(
-    offsetRange.topicPartition.topic, offsetRange.topicPartition.partition, executorKafkaParams)
   private val rangeToRead = resolveRange(offsetRange)
   private val converter = new KafkaRecordToUnsafeRowConverter
 
@@ -369,9 +360,14 @@ private[kafka010] class KafkaMicroBatchDataReader(
   }
 
   override def close(): Unit = {
-    // Indicate that we're no longer using this consumer
-    CachedKafkaConsumer.releaseKafkaConsumer(
-      offsetRange.topicPartition.topic, offsetRange.topicPartition.partition, executorKafkaParams)
+    if (!reuseKafkaConsumer) {
+      // Don't forget to close non-reused KafkaConsumers. You may take down your cluster!
+      consumer.close()
+    } else {
+      // Indicate that we're no longer using this consumer
+      CachedKafkaConsumer.releaseKafkaConsumer(
+        offsetRange.topicPartition.topic, offsetRange.topicPartition.partition, executorKafkaParams)
+    }
   }
 
   private def resolveRange(range: KafkaOffsetRange): KafkaOffsetRange = {
@@ -392,12 +388,9 @@ private[kafka010] class KafkaMicroBatchDataReader(
       } else {
         range.untilOffset
       }
-      KafkaOffsetRange(range.topicPartition, fromOffset, untilOffset)
+      KafkaOffsetRange(range.topicPartition, fromOffset, untilOffset, None)
     } else {
       range
     }
   }
 }
-
-private[kafka010] case class KafkaOffsetRange(
-  topicPartition: TopicPartition, fromOffset: Long, untilOffset: Long)

http://git-wip-us.apache.org/repos/asf/spark/blob/486f99ee/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculator.scala
----------------------------------------------------------------------
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculator.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculator.scala
new file mode 100644
index 0000000..6631ae8
--- /dev/null
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculator.scala
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.kafka010
+
+import org.apache.kafka.common.TopicPartition
+
+import org.apache.spark.sql.sources.v2.DataSourceOptions
+
+
+/**
+ * Class to calculate offset ranges to process based on the from and until offsets, and
+ * the configured `minPartitions`.
+ */
+private[kafka010] class KafkaOffsetRangeCalculator(val minPartitions: Option[Int]) {
+  require(minPartitions.isEmpty || minPartitions.get > 0)
+
+  import KafkaOffsetRangeCalculator._
+  /**
+   * Calculate the offset ranges that we are going to process this batch. If `minPartitions`
+   * is not set or is set less than or equal to the number of `topicPartitions` that we're going to
+   * consume, then we fall back to a 1-1 mapping of Spark tasks to Kafka partitions. If
+   * `minPartitions` is set higher than the number of our `topicPartitions`, then we will split up
+   * the read tasks of the skewed partitions into multiple Spark tasks.
+   * The number of Spark tasks will be *approximately* `minPartitions`. It can be less or more
+   * depending on rounding errors or Kafka partitions that didn't receive any new data.
+   */
+  def getRanges(
+      fromOffsets: PartitionOffsetMap,
+      untilOffsets: PartitionOffsetMap,
+      executorLocations: Seq[String] = Seq.empty): Seq[KafkaOffsetRange] = {
+    val partitionsToRead = untilOffsets.keySet.intersect(fromOffsets.keySet)
+
+    val offsetRanges = partitionsToRead.toSeq.map { tp =>
+      KafkaOffsetRange(tp, fromOffsets(tp), untilOffsets(tp), preferredLoc = None)
+    }.filter(_.size > 0)
+
+    // If minPartitions not set or there are enough partitions to satisfy minPartitions
+    if (minPartitions.isEmpty || offsetRanges.size > minPartitions.get) {
+      // Assign preferred executor locations to each range such that the same topic-partition is
+      // preferentially read from the same executor and the KafkaConsumer can be reused.
+      offsetRanges.map { range =>
+        range.copy(preferredLoc = getLocation(range.topicPartition, executorLocations))
+      }
+    } else {
+
+      // Split offset ranges with relatively large amounts of data into smaller ones.
+      val totalSize = offsetRanges.map(_.size).sum
+      val idealRangeSize = totalSize.toDouble / minPartitions.get
+
+      offsetRanges.flatMap { range =>
+        // Split the current range into subranges as close to the ideal range size as possible.
+        val numSplitsInRange = math.round(range.size.toDouble / idealRangeSize).toInt
+
+        (0 until numSplitsInRange).map { i =>
+          val splitStart = range.fromOffset + range.size * (i.toDouble / numSplitsInRange)
+          val splitEnd = range.fromOffset + range.size * ((i.toDouble + 1) / numSplitsInRange)
+          KafkaOffsetRange(
+            range.topicPartition, splitStart.toLong, splitEnd.toLong, preferredLoc = None)
+        }
+      }
+    }
+  }
+
+  private def getLocation(tp: TopicPartition, executorLocations: Seq[String]): Option[String] = {
+    def floorMod(a: Long, b: Int): Int = ((a % b).toInt + b) % b
+
+    val numExecutors = executorLocations.length
+    if (numExecutors > 0) {
+      // This allows cached KafkaConsumers in the executors to be re-used to read the same
+      // partition in every batch.
+      Some(executorLocations(floorMod(tp.hashCode, numExecutors)))
+    } else None
+  }
+}
+
+private[kafka010] object KafkaOffsetRangeCalculator {
+
+  def apply(options: DataSourceOptions): KafkaOffsetRangeCalculator = {
+    val optionalValue = Option(options.get("minPartitions").orElse(null)).map(_.toInt)
+    new KafkaOffsetRangeCalculator(optionalValue)
+  }
+}
+
+private[kafka010] case class KafkaOffsetRange(
+    topicPartition: TopicPartition,
+    fromOffset: Long,
+    untilOffset: Long,
+    preferredLoc: Option[String]) {
+  lazy val size: Long = untilOffset - fromOffset
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/486f99ee/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala
----------------------------------------------------------------------
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala
index 0aa64a6..36b9f04 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala
@@ -348,6 +348,12 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
         throw new IllegalArgumentException("Unknown option")
     }
 
+    // Validate minPartitions value if present
+    if (caseInsensitiveParams.contains(MIN_PARTITIONS_OPTION_KEY)) {
+      val p = caseInsensitiveParams(MIN_PARTITIONS_OPTION_KEY).toInt
+      if (p <= 0) throw new IllegalArgumentException("minPartitions must be positive")
+    }
+
     // Validate user-specified Kafka options
 
     if (caseInsensitiveParams.contains(s"kafka.${ConsumerConfig.GROUP_ID_CONFIG}")) {
@@ -455,6 +461,7 @@ private[kafka010] object KafkaSourceProvider extends Logging {
   private[kafka010] val STARTING_OFFSETS_OPTION_KEY = "startingoffsets"
   private[kafka010] val ENDING_OFFSETS_OPTION_KEY = "endingoffsets"
   private val FAIL_ON_DATA_LOSS_OPTION_KEY = "failondataloss"
+  private val MIN_PARTITIONS_OPTION_KEY = "minpartitions"
 
   val TOPIC_OPTION_KEY = "topic"
 

http://git-wip-us.apache.org/repos/asf/spark/blob/486f99ee/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/package.scala
----------------------------------------------------------------------
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/package.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/package.scala
new file mode 100644
index 0000000..43acd6a
--- /dev/null
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/package.scala
@@ -0,0 +1,24 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql
+
+import org.apache.kafka.common.TopicPartition
+
+package object kafka010 {   // scalastyle:ignore
+  // ^^ scalastyle:ignore is for ignoring warnings about digits in package name
+  type PartitionOffsetMap = Map[TopicPartition, Long]
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/486f99ee/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala
----------------------------------------------------------------------
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala
index 89c9ef4..f2b3ff7 100644
--- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala
@@ -20,10 +20,11 @@ package org.apache.spark.sql.kafka010
 import java.io._
 import java.nio.charset.StandardCharsets.UTF_8
 import java.nio.file.{Files, Paths}
-import java.util.{Locale, Properties}
+import java.util.{Locale, Optional, Properties}
 import java.util.concurrent.ConcurrentLinkedQueue
 import java.util.concurrent.atomic.AtomicInteger
 
+import scala.collection.JavaConverters._
 import scala.collection.mutable
 import scala.io.Source
 import scala.util.Random
@@ -34,15 +35,19 @@ import org.scalatest.concurrent.PatienceConfiguration.Timeout
 import org.scalatest.time.SpanSugar._
 
 import org.apache.spark.SparkContext
-import org.apache.spark.sql.{Dataset, ForeachWriter}
+import org.apache.spark.sql.{Dataset, ForeachWriter, SparkSession}
+import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.Update
 import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution
 import org.apache.spark.sql.functions.{count, window}
 import org.apache.spark.sql.kafka010.KafkaSourceProvider._
+import org.apache.spark.sql.sources.v2.DataSourceOptions
+import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2}
 import org.apache.spark.sql.streaming.{ProcessingTime, StreamTest}
 import org.apache.spark.sql.streaming.util.StreamManualClock
 import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession}
+import org.apache.spark.sql.types.StructType
 
 abstract class KafkaSourceTest extends StreamTest with SharedSQLContext {
 
@@ -642,6 +647,53 @@ class KafkaMicroBatchV2SourceSuite extends KafkaMicroBatchSourceSuiteBase {
       }
     )
   }
+
+  testWithUninterruptibleThread("minPartitions is supported") {
+    import testImplicits._
+
+    val topic = newTopic()
+    val tp = new TopicPartition(topic, 0)
+    testUtils.createTopic(topic, partitions = 1)
+
+    def test(
+        minPartitions: String,
+        numPartitionsGenerated: Int,
+        reusesConsumers: Boolean): Unit = {
+
+      SparkSession.setActiveSession(spark)
+      withTempDir { dir =>
+        val provider = new KafkaSourceProvider()
+        val options = Map(
+          "kafka.bootstrap.servers" -> testUtils.brokerAddress,
+          "subscribe" -> topic
+        ) ++ Option(minPartitions).map { p => "minPartitions" -> p}
+        val reader = provider.createMicroBatchReader(
+          Optional.empty[StructType], dir.getAbsolutePath, new DataSourceOptions(options.asJava))
+        reader.setOffsetRange(
+          Optional.of[OffsetV2](KafkaSourceOffset(Map(tp -> 0L))),
+          Optional.of[OffsetV2](KafkaSourceOffset(Map(tp -> 100L)))
+        )
+        val factories = reader.createUnsafeRowReaderFactories().asScala
+          .map(_.asInstanceOf[KafkaMicroBatchDataReaderFactory])
+        withClue(s"minPartitions = $minPartitions generated factories $factories\n\t") {
+          assert(factories.size == numPartitionsGenerated)
+          factories.foreach { f => assert(f.reuseKafkaConsumer == reusesConsumers) }
+        }
+      }
+    }
+
+    // Test cases when minPartitions is used and not used
+    test(minPartitions = null, numPartitionsGenerated = 1, reusesConsumers = true)
+    test(minPartitions = "1", numPartitionsGenerated = 1, reusesConsumers = true)
+    test(minPartitions = "4", numPartitionsGenerated = 4, reusesConsumers = false)
+
+    // Test illegal minPartitions values
+    intercept[IllegalArgumentException] { test(minPartitions = "a", 1, true) }
+    intercept[IllegalArgumentException] { test(minPartitions = "1.0", 1, true) }
+    intercept[IllegalArgumentException] { test(minPartitions = "0", 1, true) }
+    intercept[IllegalArgumentException] { test(minPartitions = "-1", 1, true) }
+  }
+
 }
 
 abstract class KafkaSourceSuiteBase extends KafkaSourceTest {

http://git-wip-us.apache.org/repos/asf/spark/blob/486f99ee/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculatorSuite.scala
----------------------------------------------------------------------
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculatorSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculatorSuite.scala
new file mode 100644
index 0000000..2ccf3e2
--- /dev/null
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculatorSuite.scala
@@ -0,0 +1,147 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.kafka010
+
+import scala.collection.JavaConverters._
+
+import org.apache.kafka.common.TopicPartition
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.sources.v2.DataSourceOptions
+
+class KafkaOffsetRangeCalculatorSuite extends SparkFunSuite {
+
+  def testWithMinPartitions(name: String, minPartition: Int)
+      (f: KafkaOffsetRangeCalculator => Unit): Unit = {
+    val options = new DataSourceOptions(Map("minPartitions" -> minPartition.toString).asJava)
+    test(s"with minPartition = $minPartition: $name") {
+      f(KafkaOffsetRangeCalculator(options))
+    }
+  }
+
+
+  test("with no minPartition: N TopicPartitions to N offset ranges") {
+    val calc = KafkaOffsetRangeCalculator(DataSourceOptions.empty())
+    assert(
+      calc.getRanges(
+        fromOffsets = Map(tp1 -> 1),
+        untilOffsets = Map(tp1 -> 2)) ==
+      Seq(KafkaOffsetRange(tp1, 1, 2, None)))
+
+    assert(
+      calc.getRanges(
+        fromOffsets = Map(tp1 -> 1),
+        untilOffsets = Map(tp1 -> 2, tp2 -> 1), Seq.empty) ==
+      Seq(KafkaOffsetRange(tp1, 1, 2, None)))
+
+    assert(
+      calc.getRanges(
+        fromOffsets = Map(tp1 -> 1, tp2 -> 1),
+        untilOffsets = Map(tp1 -> 2)) ==
+      Seq(KafkaOffsetRange(tp1, 1, 2, None)))
+
+    assert(
+      calc.getRanges(
+        fromOffsets = Map(tp1 -> 1, tp2 -> 1),
+        untilOffsets = Map(tp1 -> 2),
+        executorLocations = Seq("location")) ==
+      Seq(KafkaOffsetRange(tp1, 1, 2, Some("location"))))
+  }
+
+  test("with no minPartition: empty ranges ignored") {
+    val calc = KafkaOffsetRangeCalculator(DataSourceOptions.empty())
+    assert(
+      calc.getRanges(
+        fromOffsets = Map(tp1 -> 1, tp2 -> 1),
+        untilOffsets = Map(tp1 -> 2, tp2 -> 1)) ==
+      Seq(KafkaOffsetRange(tp1, 1, 2, None)))
+  }
+
+  testWithMinPartitions("N TopicPartitions to N offset ranges", 3) { calc =>
+    assert(
+      calc.getRanges(
+        fromOffsets = Map(tp1 -> 1, tp2 -> 1, tp3 -> 1),
+        untilOffsets = Map(tp1 -> 2, tp2 -> 2, tp3 -> 2)) ==
+      Seq(
+        KafkaOffsetRange(tp1, 1, 2, None),
+        KafkaOffsetRange(tp2, 1, 2, None),
+        KafkaOffsetRange(tp3, 1, 2, None)))
+  }
+
+  testWithMinPartitions("1 TopicPartition to N offset ranges", 4) { calc =>
+    assert(
+      calc.getRanges(
+        fromOffsets = Map(tp1 -> 1),
+        untilOffsets = Map(tp1 -> 5)) ==
+      Seq(
+        KafkaOffsetRange(tp1, 1, 2, None),
+        KafkaOffsetRange(tp1, 2, 3, None),
+        KafkaOffsetRange(tp1, 3, 4, None),
+        KafkaOffsetRange(tp1, 4, 5, None)))
+
+    assert(
+      calc.getRanges(
+        fromOffsets = Map(tp1 -> 1),
+        untilOffsets = Map(tp1 -> 5),
+        executorLocations = Seq("location")) ==
+        Seq(
+          KafkaOffsetRange(tp1, 1, 2, None),
+          KafkaOffsetRange(tp1, 2, 3, None),
+          KafkaOffsetRange(tp1, 3, 4, None),
+          KafkaOffsetRange(tp1, 4, 5, None))) // location pref not set when minPartition is set
+  }
+
+  testWithMinPartitions("N skewed TopicPartitions to M offset ranges", 3) { calc =>
+    assert(
+      calc.getRanges(
+        fromOffsets = Map(tp1 -> 1, tp2 -> 1),
+        untilOffsets = Map(tp1 -> 5, tp2 -> 21)) ==
+        Seq(
+          KafkaOffsetRange(tp1, 1, 5, None),
+          KafkaOffsetRange(tp2, 1, 7, None),
+          KafkaOffsetRange(tp2, 7, 14, None),
+          KafkaOffsetRange(tp2, 14, 21, None)))
+  }
+
+  testWithMinPartitions("range inexact multiple of minPartitions", 3) { calc =>
+    assert(
+      calc.getRanges(
+        fromOffsets = Map(tp1 -> 1),
+        untilOffsets = Map(tp1 -> 11)) ==
+        Seq(
+          KafkaOffsetRange(tp1, 1, 4, None),
+          KafkaOffsetRange(tp1, 4, 7, None),
+          KafkaOffsetRange(tp1, 7, 11, None)))
+  }
+
+  testWithMinPartitions("empty ranges ignored", 3) { calc =>
+    assert(
+      calc.getRanges(
+        fromOffsets = Map(tp1 -> 1, tp2 -> 1, tp3 -> 1),
+        untilOffsets = Map(tp1 -> 5, tp2 -> 21, tp3 -> 1)) ==
+        Seq(
+          KafkaOffsetRange(tp1, 1, 5, None),
+          KafkaOffsetRange(tp2, 1, 7, None),
+          KafkaOffsetRange(tp2, 7, 14, None),
+          KafkaOffsetRange(tp2, 14, 21, None)))
+  }
+
+  private val tp1 = new TopicPartition("t1", 1)
+  private val tp2 = new TopicPartition("t2", 1)
+  private val tp3 = new TopicPartition("t3", 1)
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org