You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by rx...@apache.org on 2016/02/05 03:38:03 UTC

spark git commit: [SPARK-12850][SQL] Support Bucket Pruning (Predicate Pushdown for Bucketed Tables)

Repository: spark
Updated Branches:
  refs/heads/master 6dbfc4077 -> e3c75c639


[SPARK-12850][SQL] Support Bucket Pruning (Predicate Pushdown for Bucketed Tables)

JIRA: https://issues.apache.org/jira/browse/SPARK-12850

This PR adds support for bucket pruning when the predicates are `EqualTo`, `EqualNullSafe`, `IsNull`, `In`, and `InSet`.

As in Hive, bucket pruning in this PR works only when the bucketing key consists of exactly one column.

So far, I have not found a way to verify how many buckets are actually scanned. However, I did verify it manually while debugging. Could you suggest how to do this properly? Thank you! cloud-fan yhuai rxin marmbrus

BTW, we can add more cases to support complex predicates, including `Or` and `And`. Please let me know if I should do it in this PR.

Maybe we also need to add test cases to verify that bucket pruning works correctly for each data type.

Author: gatorsmile <ga...@gmail.com>

Closes #10942 from gatorsmile/pruningBuckets.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e3c75c63
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e3c75c63
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e3c75c63

Branch: refs/heads/master
Commit: e3c75c6398b1241500343ff237e9bcf78b5396f9
Parents: 6dbfc40
Author: gatorsmile <ga...@gmail.com>
Authored: Thu Feb 4 18:37:58 2016 -0800
Committer: Reynold Xin <rx...@databricks.com>
Committed: Thu Feb 4 18:37:58 2016 -0800

----------------------------------------------------------------------
 .../datasources/DataSourceStrategy.scala        |  78 ++++++++-
 .../apache/spark/sql/sources/interfaces.scala   |  15 +-
 .../spark/sql/sources/BucketedReadSuite.scala   | 162 ++++++++++++++++++-
 3 files changed, 245 insertions(+), 10 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/e3c75c63/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index da9320f..c24967a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -29,12 +29,14 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.planning.PhysicalOperation
 import org.apache.spark.sql.catalyst.plans.logical
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
 import org.apache.spark.sql.execution.PhysicalRDD.{INPUT_PATHS, PUSHED_FILTERS}
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types.{StringType, StructType}
 import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.{SerializableConfiguration, Utils}
+import org.apache.spark.util.collection.BitSet
 
 /**
  * A Strategy for planning scans over data sources defined using the sources API.
@@ -97,10 +99,15 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
         (partitionAndNormalColumnAttrs ++ projects).toSeq
       }
 
+      // Prune the buckets based on the pushed filters that do not contain partitioning key
+      // since the bucketing key is not allowed to use the columns in partitioning key
+      val bucketSet = getBuckets(pushedFilters, t.getBucketSpec)
+
       val scan = buildPartitionedTableScan(
         l,
         partitionAndNormalColumnProjs,
         pushedFilters,
+        bucketSet,
         t.partitionSpec.partitionColumns,
         selectedPartitions)
 
@@ -124,11 +131,14 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
       val sharedHadoopConf = SparkHadoopUtil.get.conf
       val confBroadcast =
         t.sqlContext.sparkContext.broadcast(new SerializableConfiguration(sharedHadoopConf))
+      // Prune the buckets based on the filters
+      val bucketSet = getBuckets(filters, t.getBucketSpec)
       pruneFilterProject(
         l,
         projects,
         filters,
-        (a, f) => t.buildInternalScan(a.map(_.name).toArray, f, t.paths, confBroadcast)) :: Nil
+        (a, f) =>
+          t.buildInternalScan(a.map(_.name).toArray, f, bucketSet, t.paths, confBroadcast)) :: Nil
 
     case l @ LogicalRelation(baseRelation: TableScan, _, _) =>
       execution.PhysicalRDD.createFromDataSource(
@@ -150,6 +160,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
       logicalRelation: LogicalRelation,
       projections: Seq[NamedExpression],
       filters: Seq[Expression],
+      buckets: Option[BitSet],
       partitionColumns: StructType,
       partitions: Array[Partition]): SparkPlan = {
     val relation = logicalRelation.relation.asInstanceOf[HadoopFsRelation]
@@ -174,7 +185,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
           // assuming partition columns data stored in data files are always consistent with those
           // partition values encoded in partition directory paths.
           val dataRows = relation.buildInternalScan(
-            requiredDataColumns.map(_.name).toArray, filters, Array(dir), confBroadcast)
+            requiredDataColumns.map(_.name).toArray, filters, buckets, Array(dir), confBroadcast)
 
           // Merges data values with partition values.
           mergeWithPartitionValues(
@@ -251,6 +262,69 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
     }
   }
 
+  // Get the bucket ID based on the bucketing values.
+  // Restriction: Bucket pruning works iff the bucketing key has one and only one column.
+  def getBucketId(bucketColumn: Attribute, numBuckets: Int, value: Any): Int = {
+    val mutableRow = new SpecificMutableRow(Seq(bucketColumn.dataType))
+    mutableRow(0) = Cast(Literal(value), bucketColumn.dataType).eval(null)
+    val bucketIdGeneration = UnsafeProjection.create(
+      HashPartitioning(bucketColumn :: Nil, numBuckets).partitionIdExpression :: Nil,
+      bucketColumn :: Nil)
+
+    bucketIdGeneration(mutableRow).getInt(0)
+  }
+
+  // Get the bucket BitSet by reading the filters that only contain bucketing keys.
+  // Note: When the returned BitSet is None, no pruning is possible.
+  // Restriction: Bucket pruning works iff the bucketing key has one and only one column.
+  private def getBuckets(
+      filters: Seq[Expression],
+      bucketSpec: Option[BucketSpec]): Option[BitSet] = {
+
+    if (bucketSpec.isEmpty ||
+      bucketSpec.get.numBuckets == 1 ||
+      bucketSpec.get.bucketColumnNames.length != 1) {
+      // None means all the buckets need to be scanned
+      return None
+    }
+
+    // Just get the first because bucket pruning only works when the bucketing key has one column
+    val bucketColumnName = bucketSpec.get.bucketColumnNames.head
+    val numBuckets = bucketSpec.get.numBuckets
+    val matchedBuckets = new BitSet(numBuckets)
+    matchedBuckets.clear()
+
+    filters.foreach {
+      case expressions.EqualTo(a: Attribute, Literal(v, _)) if a.name == bucketColumnName =>
+        matchedBuckets.set(getBucketId(a, numBuckets, v))
+      case expressions.EqualTo(Literal(v, _), a: Attribute) if a.name == bucketColumnName =>
+        matchedBuckets.set(getBucketId(a, numBuckets, v))
+      case expressions.EqualNullSafe(a: Attribute, Literal(v, _)) if a.name == bucketColumnName =>
+        matchedBuckets.set(getBucketId(a, numBuckets, v))
+      case expressions.EqualNullSafe(Literal(v, _), a: Attribute) if a.name == bucketColumnName =>
+        matchedBuckets.set(getBucketId(a, numBuckets, v))
+      // Because we only convert In to InSet in Optimizer when there are more than certain
+      // items. So it is possible we still get an In expression here that needs to be pushed
+      // down.
+      case expressions.In(a: Attribute, list)
+          if list.forall(_.isInstanceOf[Literal]) && a.name == bucketColumnName =>
+        val hSet = list.map(e => e.eval(EmptyRow))
+        hSet.foreach(e => matchedBuckets.set(getBucketId(a, numBuckets, e)))
+      case expressions.IsNull(a: Attribute) if a.name == bucketColumnName =>
+        matchedBuckets.set(getBucketId(a, numBuckets, null))
+      case _ =>
+    }
+
+    logInfo {
+      val selected = matchedBuckets.cardinality()
+      val percentPruned = (1 - selected.toDouble / numBuckets.toDouble) * 100
+      s"Selected $selected buckets out of $numBuckets, pruned $percentPruned% partitions."
+    }
+
+    // None means all the buckets need to be scanned
+    if (matchedBuckets.cardinality() == 0) None else Some(matchedBuckets)
+  }
+
   protected def prunePartitions(
       predicates: Seq[Expression],
       partitionSpec: PartitionSpec): Seq[Partition] = {

http://git-wip-us.apache.org/repos/asf/spark/blob/e3c75c63/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
index 299fc6e..737be7d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
@@ -38,6 +38,7 @@ import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.streaming.{Sink, Source}
 import org.apache.spark.sql.types.{StringType, StructType}
 import org.apache.spark.util.SerializableConfiguration
+import org.apache.spark.util.collection.BitSet
 
 /**
  * ::DeveloperApi::
@@ -722,6 +723,7 @@ abstract class HadoopFsRelation private[sql](
   final private[sql] def buildInternalScan(
       requiredColumns: Array[String],
       filters: Array[Filter],
+      bucketSet: Option[BitSet],
       inputPaths: Array[String],
       broadcastedConf: Broadcast[SerializableConfiguration]): RDD[InternalRow] = {
     val inputStatuses = inputPaths.flatMap { input =>
@@ -743,9 +745,16 @@ abstract class HadoopFsRelation private[sql](
       // id from file name. Then read these files into a RDD(use one-partition empty RDD for empty
       // bucket), and coalesce it to one partition. Finally union all bucket RDDs to one result.
       val perBucketRows = (0 until maybeBucketSpec.get.numBuckets).map { bucketId =>
-        groupedBucketFiles.get(bucketId).map { inputStatuses =>
-          buildInternalScan(requiredColumns, filters, inputStatuses, broadcastedConf).coalesce(1)
-        }.getOrElse(sqlContext.emptyResult)
+        // If the current bucketId is not set in the bucket bitSet, skip scanning it.
+        if (bucketSet.nonEmpty && !bucketSet.get.get(bucketId)){
+          sqlContext.emptyResult
+        } else {
+          // When all the buckets need a scan (i.e., bucketSet is equal to None)
+          // or when the current bucket need a scan (i.e., the bit of bucketId is set to true)
+          groupedBucketFiles.get(bucketId).map { inputStatuses =>
+            buildInternalScan(requiredColumns, filters, inputStatuses, broadcastedConf).coalesce(1)
+          }.getOrElse(sqlContext.emptyResult)
+        }
       }
 
       new UnionRDD(sqlContext.sparkContext, perBucketRows)

http://git-wip-us.apache.org/repos/asf/spark/blob/e3c75c63/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
index 150d0c7..9ba6456 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
@@ -19,22 +19,28 @@ package org.apache.spark.sql.sources
 
 import java.io.File
 
-import org.apache.spark.sql.{Column, DataFrame, DataFrameWriter, QueryTest, SQLConf}
-import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
-import org.apache.spark.sql.execution.Exchange
-import org.apache.spark.sql.execution.datasources.BucketSpec
+import org.apache.spark.sql.execution.{Exchange, PhysicalRDD}
+import org.apache.spark.sql.execution.datasources.{BucketSpec, DataSourceStrategy}
 import org.apache.spark.sql.execution.joins.SortMergeJoin
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.hive.test.TestHiveSingleton
 import org.apache.spark.sql.test.SQLTestUtils
 import org.apache.spark.util.Utils
+import org.apache.spark.util.collection.BitSet
 
 class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
   import testImplicits._
 
+  private val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k")
+  private val nullDF = (for {
+    i <- 0 to 50
+    s <- Seq(null, "a", "b", "c", "d", "e", "f", null, "g")
+  } yield (i % 5, s, i % 13)).toDF("i", "j", "k")
+
   test("read bucketed data") {
-    val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k")
     withTable("bucketed_table") {
       df.write
         .format("parquet")
@@ -59,6 +65,152 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet
     }
   }
 
+  // To verify if the bucket pruning works, this function checks two conditions:
+  //   1) Check if the pruned buckets (before filtering) are empty.
+  //   2) Verify the final result is the same as the expected one
+  private def checkPrunedAnswers(
+      bucketSpec: BucketSpec,
+      bucketValues: Seq[Integer],
+      filterCondition: Column,
+      originalDataFrame: DataFrame): Unit = {
+
+    val bucketedDataFrame = hiveContext.table("bucketed_table").select("i", "j", "k")
+    val BucketSpec(numBuckets, bucketColumnNames, _) = bucketSpec
+    // Limit: bucket pruning only works when the bucket column has one and only one column
+    assert(bucketColumnNames.length == 1)
+    val bucketColumnIndex = bucketedDataFrame.schema.fieldIndex(bucketColumnNames.head)
+    val bucketColumn = bucketedDataFrame.schema.toAttributes(bucketColumnIndex)
+    val matchedBuckets = new BitSet(numBuckets)
+    bucketValues.foreach { value =>
+      matchedBuckets.set(DataSourceStrategy.getBucketId(bucketColumn, numBuckets, value))
+    }
+
+    // Filter could hide the bug in bucket pruning. Thus, skipping all the filters
+    val rdd = bucketedDataFrame.filter(filterCondition).queryExecution.executedPlan
+      .find(_.isInstanceOf[PhysicalRDD])
+    assert(rdd.isDefined)
+
+    val checkedResult = rdd.get.execute().mapPartitionsWithIndex { case (index, iter) =>
+      if (matchedBuckets.get(index % numBuckets)) Iterator(true) else Iterator(iter.isEmpty)
+    }
+    // checking if all the pruned buckets are empty
+    assert(checkedResult.collect().forall(_ == true))
+
+    checkAnswer(
+      bucketedDataFrame.filter(filterCondition).orderBy("i", "j", "k"),
+      originalDataFrame.filter(filterCondition).orderBy("i", "j", "k"))
+  }
+
+  test("read partitioning bucketed tables with bucket pruning filters") {
+    withTable("bucketed_table") {
+      val numBuckets = 8
+      val bucketSpec = BucketSpec(numBuckets, Seq("j"), Nil)
+      // json does not support predicate push-down, and thus json is used here
+      df.write
+        .format("json")
+        .partitionBy("i")
+        .bucketBy(numBuckets, "j")
+        .saveAsTable("bucketed_table")
+
+      for (j <- 0 until 13) {
+        // Case 1: EqualTo
+        checkPrunedAnswers(
+          bucketSpec,
+          bucketValues = j :: Nil,
+          filterCondition = $"j" === j,
+          df)
+
+        // Case 2: EqualNullSafe
+        checkPrunedAnswers(
+          bucketSpec,
+          bucketValues = j :: Nil,
+          filterCondition = $"j" <=> j,
+          df)
+
+        // Case 3: In
+        checkPrunedAnswers(
+          bucketSpec,
+          bucketValues = Seq(j, j + 1, j + 2, j + 3),
+          filterCondition = $"j".isin(j, j + 1, j + 2, j + 3),
+          df)
+      }
+    }
+  }
+
+  test("read non-partitioning bucketed tables with bucket pruning filters") {
+    withTable("bucketed_table") {
+      val numBuckets = 8
+      val bucketSpec = BucketSpec(numBuckets, Seq("j"), Nil)
+      // json does not support predicate push-down, and thus json is used here
+      df.write
+        .format("json")
+        .bucketBy(numBuckets, "j")
+        .saveAsTable("bucketed_table")
+
+      for (j <- 0 until 13) {
+        checkPrunedAnswers(
+          bucketSpec,
+          bucketValues = j :: Nil,
+          filterCondition = $"j" === j,
+          df)
+      }
+    }
+  }
+
+  test("read partitioning bucketed tables having null in bucketing key") {
+    withTable("bucketed_table") {
+      val numBuckets = 8
+      val bucketSpec = BucketSpec(numBuckets, Seq("j"), Nil)
+      // json does not support predicate push-down, and thus json is used here
+      nullDF.write
+        .format("json")
+        .partitionBy("i")
+        .bucketBy(numBuckets, "j")
+        .saveAsTable("bucketed_table")
+
+      // Case 1: isNull
+      checkPrunedAnswers(
+        bucketSpec,
+        bucketValues = null :: Nil,
+        filterCondition = $"j".isNull,
+        nullDF)
+
+      // Case 2: <=> null
+      checkPrunedAnswers(
+        bucketSpec,
+        bucketValues = null :: Nil,
+        filterCondition = $"j" <=> null,
+        nullDF)
+    }
+  }
+
+  test("read partitioning bucketed tables having composite filters") {
+    withTable("bucketed_table") {
+      val numBuckets = 8
+      val bucketSpec = BucketSpec(numBuckets, Seq("j"), Nil)
+      // json does not support predicate push-down, and thus json is used here
+      df.write
+        .format("json")
+        .partitionBy("i")
+        .bucketBy(numBuckets, "j")
+        .saveAsTable("bucketed_table")
+
+      for (j <- 0 until 13) {
+        checkPrunedAnswers(
+          bucketSpec,
+          bucketValues = j :: Nil,
+          filterCondition = $"j" === j && $"k" > $"j",
+          df)
+
+        checkPrunedAnswers(
+          bucketSpec,
+          bucketValues = j :: Nil,
+          filterCondition = $"j" === j && $"i" > j % 5,
+          df)
+      }
+    }
+  }
+
   private val df1 = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k").as("df1")
   private val df2 = (0 until 50).map(i => (i % 7, i % 11, i.toString)).toDF("i", "j", "k").as("df2")
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org