Posted to commits@spark.apache.org by we...@apache.org on 2018/07/30 05:20:10 UTC

spark git commit: [SPARK-24934][SQL] Explicitly whitelist supported types in upper/lower bounds for in-memory partition pruning

Repository: spark
Updated Branches:
  refs/heads/master 65a4bc143 -> bfe60fcdb


[SPARK-24934][SQL] Explicitly whitelist supported types in upper/lower bounds for in-memory partition pruning

## What changes were proposed in this pull request?

It looks like we intentionally set `null` for the upper/lower bounds of complex types and never use them. However, these bounds do appear to be used in in-memory partition pruning, which ends up producing incorrect results.

This PR proposes to explicitly whitelist the supported types.

```scala
val df = Seq(Array("a", "b"), Array("c", "d")).toDF("arrayCol")
df.cache().filter("arrayCol > array('a', 'b')").show()
```

```scala
val df = sql("select cast('a' as binary) as a")
df.cache().filter("a == cast('a' as binary)").show()
```

**Before:**

```
+--------+
|arrayCol|
+--------+
+--------+
```

```
+---+
|  a|
+---+
+---+
```

**After:**

```
+--------+
|arrayCol|
+--------+
|  [c, d]|
+--------+
```

```
+----+
|   a|
+----+
|[61]|
+----+
```

## How was this patch tested?

Unit tests were added, and the change was also verified manually.
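
For reference, a quick manual check can be run in `spark-shell` with the reproducer from the description above; the config key below is assumed to be the in-memory pruning flag exercised by `PartitionBatchPruningSuite` and is shown only as a sketch, not as part of the patch.

```scala
// Manual sanity check in spark-shell; mirrors the array reproducer above.
import spark.implicits._

// In-memory partition pruning is what consults the (previously bogus) bounds.
spark.conf.set("spark.sql.inMemoryColumnarStorage.partitionPruning", "true")

val df = Seq(Array("a", "b"), Array("c", "d")).toDF("arrayCol")
df.cache().filter("arrayCol > array('a', 'b')").show()
// Expected after this change: a single row [c, d]; before, an empty result.
```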

Author: hyukjinkwon <gu...@apache.org>

Closes #21882 from HyukjinKwon/stats-filter.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/bfe60fcd
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/bfe60fcd
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/bfe60fcd

Branch: refs/heads/master
Commit: bfe60fcdb49aa48534060c38e36e06119900140d
Parents: 65a4bc1
Author: hyukjinkwon <gu...@apache.org>
Authored: Mon Jul 30 13:20:03 2018 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Mon Jul 30 13:20:03 2018 +0800

----------------------------------------------------------------------
 .../columnar/InMemoryTableScanExec.scala        | 42 ++++++++++++++------
 .../columnar/PartitionBatchPruningSuite.scala   | 30 +++++++++++++-
 2 files changed, 58 insertions(+), 14 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/bfe60fcd/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
index 997cf92..6012aba 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
@@ -183,6 +183,18 @@ case class InMemoryTableScanExec(
   private val stats = relation.partitionStatistics
   private def statsFor(a: Attribute) = stats.forAttribute(a)
 
+  // Currently, only use statistics from atomic types, excluding binary type.
+  private object ExtractableLiteral {
+    def unapply(expr: Expression): Option[Literal] = expr match {
+      case lit: Literal => lit.dataType match {
+        case BinaryType => None
+        case _: AtomicType => Some(lit)
+        case _ => None
+      }
+      case _ => None
+    }
+  }
+
   // Returned filter predicate should return false iff it is impossible for the input expression
   // to evaluate to `true' based on statistics collected about this partition batch.
   @transient lazy val buildFilter: PartialFunction[Expression, Expression] = {
@@ -194,33 +206,37 @@ case class InMemoryTableScanExec(
       if buildFilter.isDefinedAt(lhs) && buildFilter.isDefinedAt(rhs) =>
       buildFilter(lhs) || buildFilter(rhs)
 
-    case EqualTo(a: AttributeReference, l: Literal) =>
+    case EqualTo(a: AttributeReference, ExtractableLiteral(l)) =>
       statsFor(a).lowerBound <= l && l <= statsFor(a).upperBound
-    case EqualTo(l: Literal, a: AttributeReference) =>
+    case EqualTo(ExtractableLiteral(l), a: AttributeReference) =>
       statsFor(a).lowerBound <= l && l <= statsFor(a).upperBound
 
-    case EqualNullSafe(a: AttributeReference, l: Literal) =>
+    case EqualNullSafe(a: AttributeReference, ExtractableLiteral(l)) =>
       statsFor(a).lowerBound <= l && l <= statsFor(a).upperBound
-    case EqualNullSafe(l: Literal, a: AttributeReference) =>
+    case EqualNullSafe(ExtractableLiteral(l), a: AttributeReference) =>
       statsFor(a).lowerBound <= l && l <= statsFor(a).upperBound
 
-    case LessThan(a: AttributeReference, l: Literal) => statsFor(a).lowerBound < l
-    case LessThan(l: Literal, a: AttributeReference) => l < statsFor(a).upperBound
+    case LessThan(a: AttributeReference, ExtractableLiteral(l)) => statsFor(a).lowerBound < l
+    case LessThan(ExtractableLiteral(l), a: AttributeReference) => l < statsFor(a).upperBound
 
-    case LessThanOrEqual(a: AttributeReference, l: Literal) => statsFor(a).lowerBound <= l
-    case LessThanOrEqual(l: Literal, a: AttributeReference) => l <= statsFor(a).upperBound
+    case LessThanOrEqual(a: AttributeReference, ExtractableLiteral(l)) =>
+      statsFor(a).lowerBound <= l
+    case LessThanOrEqual(ExtractableLiteral(l), a: AttributeReference) =>
+      l <= statsFor(a).upperBound
 
-    case GreaterThan(a: AttributeReference, l: Literal) => l < statsFor(a).upperBound
-    case GreaterThan(l: Literal, a: AttributeReference) => statsFor(a).lowerBound < l
+    case GreaterThan(a: AttributeReference, ExtractableLiteral(l)) => l < statsFor(a).upperBound
+    case GreaterThan(ExtractableLiteral(l), a: AttributeReference) => statsFor(a).lowerBound < l
 
-    case GreaterThanOrEqual(a: AttributeReference, l: Literal) => l <= statsFor(a).upperBound
-    case GreaterThanOrEqual(l: Literal, a: AttributeReference) => statsFor(a).lowerBound <= l
+    case GreaterThanOrEqual(a: AttributeReference, ExtractableLiteral(l)) =>
+      l <= statsFor(a).upperBound
+    case GreaterThanOrEqual(ExtractableLiteral(l), a: AttributeReference) =>
+      statsFor(a).lowerBound <= l
 
     case IsNull(a: Attribute) => statsFor(a).nullCount > 0
     case IsNotNull(a: Attribute) => statsFor(a).count - statsFor(a).nullCount > 0
 
     case In(a: AttributeReference, list: Seq[Expression])
-      if list.forall(_.isInstanceOf[Literal]) && list.nonEmpty =>
+      if list.forall(ExtractableLiteral.unapply(_).isDefined) && list.nonEmpty =>
       list.map(l => statsFor(a).lowerBound <= l.asInstanceOf[Literal] &&
         l.asInstanceOf[Literal] <= statsFor(a).upperBound).reduce(_ || _)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/bfe60fcd/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala
index 9d862cf..af493e9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.columnar
 import org.scalatest.BeforeAndAfterEach
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.test.SQLTestData._
@@ -35,6 +36,12 @@ class PartitionBatchPruningSuite
   private lazy val originalColumnBatchSize = spark.conf.get(SQLConf.COLUMN_BATCH_SIZE)
   private lazy val originalInMemoryPartitionPruning =
     spark.conf.get(SQLConf.IN_MEMORY_PARTITION_PRUNING)
+  private val testArrayData = (1 to 100).map { key =>
+    Tuple1(Array.fill(key)(key))
+  }
+  private val testBinaryData = (1 to 100).map { key =>
+    Tuple1(Array.fill(key)(key.toByte))
+  }
 
   override protected def beforeAll(): Unit = {
     super.beforeAll()
@@ -71,12 +78,22 @@ class PartitionBatchPruningSuite
     }, 5).toDF()
     pruningStringData.createOrReplaceTempView("pruningStringData")
     spark.catalog.cacheTable("pruningStringData")
+
+    val pruningArrayData = sparkContext.makeRDD(testArrayData, 5).toDF()
+    pruningArrayData.createOrReplaceTempView("pruningArrayData")
+    spark.catalog.cacheTable("pruningArrayData")
+
+    val pruningBinaryData = sparkContext.makeRDD(testBinaryData, 5).toDF()
+    pruningBinaryData.createOrReplaceTempView("pruningBinaryData")
+    spark.catalog.cacheTable("pruningBinaryData")
   }
 
   override protected def afterEach(): Unit = {
     try {
       spark.catalog.uncacheTable("pruningData")
       spark.catalog.uncacheTable("pruningStringData")
+      spark.catalog.uncacheTable("pruningArrayData")
+      spark.catalog.uncacheTable("pruningBinaryData")
     } finally {
       super.afterEach()
     }
@@ -95,6 +112,14 @@ class PartitionBatchPruningSuite
   checkBatchPruning("SELECT key FROM pruningData WHERE 11 >= key", 1, 2)(1 to 11)
   checkBatchPruning("SELECT key FROM pruningData WHERE 88 < key", 1, 2)(89 to 100)
   checkBatchPruning("SELECT key FROM pruningData WHERE 89 <= key", 1, 2)(89 to 100)
+  // Do not filter on array type
+  checkBatchPruning("SELECT _1 FROM pruningArrayData WHERE _1 = array(1)", 5, 10)(Seq(Array(1)))
+  checkBatchPruning("SELECT _1 FROM pruningArrayData WHERE _1 <= array(1)", 5, 10)(Seq(Array(1)))
+  checkBatchPruning("SELECT _1 FROM pruningArrayData WHERE _1 >= array(1)", 5, 10)(
+    testArrayData.map(_._1))
+  // Do not filter on binary type
+  checkBatchPruning(
+    "SELECT _1 FROM pruningBinaryData WHERE _1 == binary(chr(1))", 5, 10)(Seq(Array(1.toByte)))
 
   // IS NULL
   checkBatchPruning("SELECT key FROM pruningData WHERE value IS NULL", 5, 5) {
@@ -131,6 +156,9 @@ class PartitionBatchPruningSuite
   checkBatchPruning(
     "SELECT CAST(s AS INT) FROM pruningStringData WHERE s IN ('99', '150', '201')", 1, 1)(
       Seq(150))
+  // Do not filter on array type
+  checkBatchPruning("SELECT _1 FROM pruningArrayData WHERE _1 IN (array(1), array(2, 2))", 5, 10)(
+    Seq(Array(1), Array(2, 2)))
 
   // With unsupported `InSet` predicate
   {
@@ -161,7 +189,7 @@ class PartitionBatchPruningSuite
       query: String,
       expectedReadPartitions: Int,
       expectedReadBatches: Int)(
-      expectedQueryResult: => Seq[Int]): Unit = {
+      expectedQueryResult: => Seq[Any]): Unit = {
 
     test(query) {
       val df = sql(query)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org