You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2018/07/30 05:20:10 UTC
spark git commit: [SPARK-24934][SQL] Explicitly whitelist supported types in upper/lower bounds for in-memory partition pruning
Repository: spark
Updated Branches:
refs/heads/master 65a4bc143 -> bfe60fcdb
[SPARK-24934][SQL] Explicitly whitelist supported types in upper/lower bounds for in-memory partition pruning
## What changes were proposed in this pull request?
It looks like we intentionally set `null` for the upper/lower bounds of complex types and do not use them. However, these bounds are still used by in-memory partition pruning, which produces incorrect results (partitions that contain matching rows are pruned away).
This PR proposes to explicitly whitelist the supported types.
```scala
val df = Seq(Array("a", "b"), Array("c", "d")).toDF("arrayCol")
df.cache().filter("arrayCol > array('a', 'b')").show()
```
```scala
val df = sql("select cast('a' as binary) as a")
df.cache().filter("a == cast('a' as binary)").show()
```
**Before:**
```
+--------+
|arrayCol|
+--------+
+--------+
```
```
+---+
|  a|
+---+
+---+
```
**After:**
```
+--------+
|arrayCol|
+--------+
|  [c, d]|
+--------+
```
```
+----+
|   a|
+----+
|[61]|
+----+
```
## How was this patch tested?
Unit tests were added, and the change was also tested manually.
Author: hyukjinkwon <gu...@apache.org>
Closes #21882 from HyukjinKwon/stats-filter.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/bfe60fcd
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/bfe60fcd
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/bfe60fcd
Branch: refs/heads/master
Commit: bfe60fcdb49aa48534060c38e36e06119900140d
Parents: 65a4bc1
Author: hyukjinkwon <gu...@apache.org>
Authored: Mon Jul 30 13:20:03 2018 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Mon Jul 30 13:20:03 2018 +0800
----------------------------------------------------------------------
.../columnar/InMemoryTableScanExec.scala | 42 ++++++++++++++------
.../columnar/PartitionBatchPruningSuite.scala | 30 +++++++++++++-
2 files changed, 58 insertions(+), 14 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/bfe60fcd/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
index 997cf92..6012aba 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
@@ -183,6 +183,18 @@ case class InMemoryTableScanExec(
private val stats = relation.partitionStatistics
private def statsFor(a: Attribute) = stats.forAttribute(a)
+ // Currently, only use statistics from atomic types except binary type only.
+ private object ExtractableLiteral {
+ def unapply(expr: Expression): Option[Literal] = expr match {
+ case lit: Literal => lit.dataType match {
+ case BinaryType => None
+ case _: AtomicType => Some(lit)
+ case _ => None
+ }
+ case _ => None
+ }
+ }
+
// Returned filter predicate should return false iff it is impossible for the input expression
// to evaluate to `true' based on statistics collected about this partition batch.
@transient lazy val buildFilter: PartialFunction[Expression, Expression] = {
@@ -194,33 +206,37 @@ case class InMemoryTableScanExec(
if buildFilter.isDefinedAt(lhs) && buildFilter.isDefinedAt(rhs) =>
buildFilter(lhs) || buildFilter(rhs)
- case EqualTo(a: AttributeReference, l: Literal) =>
+ case EqualTo(a: AttributeReference, ExtractableLiteral(l)) =>
statsFor(a).lowerBound <= l && l <= statsFor(a).upperBound
- case EqualTo(l: Literal, a: AttributeReference) =>
+ case EqualTo(ExtractableLiteral(l), a: AttributeReference) =>
statsFor(a).lowerBound <= l && l <= statsFor(a).upperBound
- case EqualNullSafe(a: AttributeReference, l: Literal) =>
+ case EqualNullSafe(a: AttributeReference, ExtractableLiteral(l)) =>
statsFor(a).lowerBound <= l && l <= statsFor(a).upperBound
- case EqualNullSafe(l: Literal, a: AttributeReference) =>
+ case EqualNullSafe(ExtractableLiteral(l), a: AttributeReference) =>
statsFor(a).lowerBound <= l && l <= statsFor(a).upperBound
- case LessThan(a: AttributeReference, l: Literal) => statsFor(a).lowerBound < l
- case LessThan(l: Literal, a: AttributeReference) => l < statsFor(a).upperBound
+ case LessThan(a: AttributeReference, ExtractableLiteral(l)) => statsFor(a).lowerBound < l
+ case LessThan(ExtractableLiteral(l), a: AttributeReference) => l < statsFor(a).upperBound
- case LessThanOrEqual(a: AttributeReference, l: Literal) => statsFor(a).lowerBound <= l
- case LessThanOrEqual(l: Literal, a: AttributeReference) => l <= statsFor(a).upperBound
+ case LessThanOrEqual(a: AttributeReference, ExtractableLiteral(l)) =>
+ statsFor(a).lowerBound <= l
+ case LessThanOrEqual(ExtractableLiteral(l), a: AttributeReference) =>
+ l <= statsFor(a).upperBound
- case GreaterThan(a: AttributeReference, l: Literal) => l < statsFor(a).upperBound
- case GreaterThan(l: Literal, a: AttributeReference) => statsFor(a).lowerBound < l
+ case GreaterThan(a: AttributeReference, ExtractableLiteral(l)) => l < statsFor(a).upperBound
+ case GreaterThan(ExtractableLiteral(l), a: AttributeReference) => statsFor(a).lowerBound < l
- case GreaterThanOrEqual(a: AttributeReference, l: Literal) => l <= statsFor(a).upperBound
- case GreaterThanOrEqual(l: Literal, a: AttributeReference) => statsFor(a).lowerBound <= l
+ case GreaterThanOrEqual(a: AttributeReference, ExtractableLiteral(l)) =>
+ l <= statsFor(a).upperBound
+ case GreaterThanOrEqual(ExtractableLiteral(l), a: AttributeReference) =>
+ statsFor(a).lowerBound <= l
case IsNull(a: Attribute) => statsFor(a).nullCount > 0
case IsNotNull(a: Attribute) => statsFor(a).count - statsFor(a).nullCount > 0
case In(a: AttributeReference, list: Seq[Expression])
- if list.forall(_.isInstanceOf[Literal]) && list.nonEmpty =>
+ if list.forall(ExtractableLiteral.unapply(_).isDefined) && list.nonEmpty =>
list.map(l => statsFor(a).lowerBound <= l.asInstanceOf[Literal] &&
l.asInstanceOf[Literal] <= statsFor(a).upperBound).reduce(_ || _)
}
http://git-wip-us.apache.org/repos/asf/spark/blob/bfe60fcd/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala
index 9d862cf..af493e9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.columnar
import org.scalatest.BeforeAndAfterEach
import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.test.SQLTestData._
@@ -35,6 +36,12 @@ class PartitionBatchPruningSuite
private lazy val originalColumnBatchSize = spark.conf.get(SQLConf.COLUMN_BATCH_SIZE)
private lazy val originalInMemoryPartitionPruning =
spark.conf.get(SQLConf.IN_MEMORY_PARTITION_PRUNING)
+ private val testArrayData = (1 to 100).map { key =>
+ Tuple1(Array.fill(key)(key))
+ }
+ private val testBinaryData = (1 to 100).map { key =>
+ Tuple1(Array.fill(key)(key.toByte))
+ }
override protected def beforeAll(): Unit = {
super.beforeAll()
@@ -71,12 +78,22 @@ class PartitionBatchPruningSuite
}, 5).toDF()
pruningStringData.createOrReplaceTempView("pruningStringData")
spark.catalog.cacheTable("pruningStringData")
+
+ val pruningArrayData = sparkContext.makeRDD(testArrayData, 5).toDF()
+ pruningArrayData.createOrReplaceTempView("pruningArrayData")
+ spark.catalog.cacheTable("pruningArrayData")
+
+ val pruningBinaryData = sparkContext.makeRDD(testBinaryData, 5).toDF()
+ pruningBinaryData.createOrReplaceTempView("pruningBinaryData")
+ spark.catalog.cacheTable("pruningBinaryData")
}
override protected def afterEach(): Unit = {
try {
spark.catalog.uncacheTable("pruningData")
spark.catalog.uncacheTable("pruningStringData")
+ spark.catalog.uncacheTable("pruningArrayData")
+ spark.catalog.uncacheTable("pruningBinaryData")
} finally {
super.afterEach()
}
@@ -95,6 +112,14 @@ class PartitionBatchPruningSuite
checkBatchPruning("SELECT key FROM pruningData WHERE 11 >= key", 1, 2)(1 to 11)
checkBatchPruning("SELECT key FROM pruningData WHERE 88 < key", 1, 2)(89 to 100)
checkBatchPruning("SELECT key FROM pruningData WHERE 89 <= key", 1, 2)(89 to 100)
+ // Do not filter on array type
+ checkBatchPruning("SELECT _1 FROM pruningArrayData WHERE _1 = array(1)", 5, 10)(Seq(Array(1)))
+ checkBatchPruning("SELECT _1 FROM pruningArrayData WHERE _1 <= array(1)", 5, 10)(Seq(Array(1)))
+ checkBatchPruning("SELECT _1 FROM pruningArrayData WHERE _1 >= array(1)", 5, 10)(
+ testArrayData.map(_._1))
+ // Do not filter on binary type
+ checkBatchPruning(
+ "SELECT _1 FROM pruningBinaryData WHERE _1 == binary(chr(1))", 5, 10)(Seq(Array(1.toByte)))
// IS NULL
checkBatchPruning("SELECT key FROM pruningData WHERE value IS NULL", 5, 5) {
@@ -131,6 +156,9 @@ class PartitionBatchPruningSuite
checkBatchPruning(
"SELECT CAST(s AS INT) FROM pruningStringData WHERE s IN ('99', '150', '201')", 1, 1)(
Seq(150))
+ // Do not filter on array type
+ checkBatchPruning("SELECT _1 FROM pruningArrayData WHERE _1 IN (array(1), array(2, 2))", 5, 10)(
+ Seq(Array(1), Array(2, 2)))
// With unsupported `InSet` predicate
{
@@ -161,7 +189,7 @@ class PartitionBatchPruningSuite
query: String,
expectedReadPartitions: Int,
expectedReadBatches: Int)(
- expectedQueryResult: => Seq[Int]): Unit = {
+ expectedQueryResult: => Seq[Any]): Unit = {
test(query) {
val df = sql(query)
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org