Posted to commits@spark.apache.org by li...@apache.org on 2018/02/02 04:44:50 UTC
spark git commit: [SPARK-23301][SQL] data source column pruning should work for arbitrary expressions
Repository: spark
Updated Branches:
refs/heads/master b3a04283f -> 19c7c7ebd
[SPARK-23301][SQL] data source column pruning should work for arbitrary expressions
## What changes were proposed in this pull request?
This PR fixes a mistake in the `PushDownOperatorsToDataSource` rule: its column pruning logic handled `Project` incorrectly, so required columns could be dropped when the project list contained arbitrary expressions rather than bare column references.
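As a rough illustration (a sketch only, assuming the `DataSourceV2Suite` test context where a `spark` session exists and the `AdvancedDataSourceV2` test source is in scope), the old logic filtered the project list against `requiredByParent`, so an aliased expression such as `'i + 1` never matched and its reference to column `i` was not reported to `pruneColumns`:

```scala
// Sketch of the query shape the old rule mishandled (assumes the SharedSQLContext
// test setup, i.e. a SparkSession named `spark` and the AdvancedDataSourceV2
// test source from this patch on the classpath).
import spark.implicits._

val df = spark.read.format(classOf[AdvancedDataSourceV2].getName).load()

// The project list entry is an Alias over 'i + 1, not a bare Attribute, so the old
// `projectList.filter(requiredByParent.contains)` dropped it and the reference to
// column `i` was never handed to SupportsPushDownRequiredColumns.pruneColumns.
val q = df.select('i + 1)
q.explain()
```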
## How was this patch tested?
Added a new test case for column pruning with arbitrary expressions, and improved the existing tests to make sure `PushDownOperatorsToDataSource` really works.
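For reference, the improved checks follow the pattern below (condensed from the test changes in this commit, and assuming the suite's imports and `checkAnswer` helper): collect the physical `DataSourceV2ScanExec` node, grab the reader it wraps, and assert on the pruned schema and pushed filters.

```scala
// Verification pattern used by the new assertions: find the DataSourceV2ScanExec
// in the executed plan and inspect the reader exposed by the test data source.
def getReader(query: DataFrame): AdvancedDataSourceV2#Reader = {
  query.queryExecution.executedPlan.collect {
    case d: DataSourceV2ScanExec => d.reader.asInstanceOf[AdvancedDataSourceV2#Reader]
  }.head
}

val q1 = df.select('i + 1)
checkAnswer(q1, (1 until 11).map(i => Row(i)))
assert(getReader(q1).requiredSchema.fieldNames === Seq("i"))  // only `i` is read
```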
Author: Wenchen Fan <we...@databricks.com>
Closes #20476 from cloud-fan/push-down.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/19c7c7eb
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/19c7c7eb
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/19c7c7eb
Branch: refs/heads/master
Commit: 19c7c7ebdef6c1c7a02ebac9af6a24f521b52c37
Parents: b3a0428
Author: Wenchen Fan <we...@databricks.com>
Authored: Thu Feb 1 20:44:46 2018 -0800
Committer: gatorsmile <ga...@gmail.com>
Committed: Thu Feb 1 20:44:46 2018 -0800
----------------------------------------------------------------------
.../v2/PushDownOperatorsToDataSource.scala | 53 +++++----
.../sources/v2/JavaAdvancedDataSourceV2.java | 29 ++++-
.../sql/sources/v2/DataSourceV2Suite.scala | 113 +++++++++++++++++--
3 files changed, 155 insertions(+), 40 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/19c7c7eb/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala
index df034ad..566a483 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.execution.datasources.v2
-import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeMap, Expression, NamedExpression, PredicateHelper}
+import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeMap, AttributeSet, Expression, NamedExpression, PredicateHelper}
import org.apache.spark.sql.catalyst.optimizer.RemoveRedundantProject
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.Rule
@@ -81,35 +81,34 @@ object PushDownOperatorsToDataSource extends Rule[LogicalPlan] with PredicateHel
// TODO: add more push down rules.
- // TODO: nested fields pruning
- def pushDownRequiredColumns(plan: LogicalPlan, requiredByParent: Seq[Attribute]): Unit = {
- plan match {
- case Project(projectList, child) =>
- val required = projectList.filter(requiredByParent.contains).flatMap(_.references)
- pushDownRequiredColumns(child, required)
-
- case Filter(condition, child) =>
- val required = requiredByParent ++ condition.references
- pushDownRequiredColumns(child, required)
-
- case DataSourceV2Relation(fullOutput, reader) => reader match {
- case r: SupportsPushDownRequiredColumns =>
- // Match original case of attributes.
- val attrMap = AttributeMap(fullOutput.zip(fullOutput))
- val requiredColumns = requiredByParent.map(attrMap)
- r.pruneColumns(requiredColumns.toStructType)
- case _ =>
- }
+ pushDownRequiredColumns(filterPushed, filterPushed.outputSet)
+ // After column pruning, we may have redundant PROJECT nodes in the query plan, remove them.
+ RemoveRedundantProject(filterPushed)
+ }
+
+ // TODO: nested fields pruning
+ private def pushDownRequiredColumns(plan: LogicalPlan, requiredByParent: AttributeSet): Unit = {
+ plan match {
+ case Project(projectList, child) =>
+ val required = projectList.flatMap(_.references)
+ pushDownRequiredColumns(child, AttributeSet(required))
+
+ case Filter(condition, child) =>
+ val required = requiredByParent ++ condition.references
+ pushDownRequiredColumns(child, required)
- // TODO: there may be more operators can be used to calculate required columns, we can add
- // more and more in the future.
- case _ => plan.children.foreach(child => pushDownRequiredColumns(child, child.output))
+ case relation: DataSourceV2Relation => relation.reader match {
+ case reader: SupportsPushDownRequiredColumns =>
+ val requiredColumns = relation.output.filter(requiredByParent.contains)
+ reader.pruneColumns(requiredColumns.toStructType)
+
+ case _ =>
}
- }
- pushDownRequiredColumns(filterPushed, filterPushed.output)
- // After column pruning, we may have redundant PROJECT nodes in the query plan, remove them.
- RemoveRedundantProject(filterPushed)
+ // TODO: there may be more operators that can be used to calculate the required columns. We
+ // can add more and more in the future.
+ case _ => plan.children.foreach(child => pushDownRequiredColumns(child, child.outputSet))
+ }
}
/**
http://git-wip-us.apache.org/repos/asf/spark/blob/19c7c7eb/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java
----------------------------------------------------------------------
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java
index d421f7d..172e5d5 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java
@@ -32,11 +32,12 @@ import org.apache.spark.sql.types.StructType;
public class JavaAdvancedDataSourceV2 implements DataSourceV2, ReadSupport {
- class Reader implements DataSourceReader, SupportsPushDownRequiredColumns,
+ public class Reader implements DataSourceReader, SupportsPushDownRequiredColumns,
SupportsPushDownFilters {
- private StructType requiredSchema = new StructType().add("i", "int").add("j", "int");
- private Filter[] filters = new Filter[0];
+ // Exposed for testing.
+ public StructType requiredSchema = new StructType().add("i", "int").add("j", "int");
+ public Filter[] filters = new Filter[0];
@Override
public StructType readSchema() {
@@ -50,8 +51,26 @@ public class JavaAdvancedDataSourceV2 implements DataSourceV2, ReadSupport {
@Override
public Filter[] pushFilters(Filter[] filters) {
- this.filters = filters;
- return new Filter[0];
+ Filter[] supported = Arrays.stream(filters).filter(f -> {
+ if (f instanceof GreaterThan) {
+ GreaterThan gt = (GreaterThan) f;
+ return gt.attribute().equals("i") && gt.value() instanceof Integer;
+ } else {
+ return false;
+ }
+ }).toArray(Filter[]::new);
+
+ Filter[] unsupported = Arrays.stream(filters).filter(f -> {
+ if (f instanceof GreaterThan) {
+ GreaterThan gt = (GreaterThan) f;
+ return !gt.attribute().equals("i") || !(gt.value() instanceof Integer);
+ } else {
+ return true;
+ }
+ }).toArray(Filter[]::new);
+
+ this.filters = supported;
+ return unsupported;
}
@Override
http://git-wip-us.apache.org/repos/asf/spark/blob/19c7c7eb/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala
index 23147ff..eccd454 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala
@@ -21,11 +21,13 @@ import java.util.{ArrayList, List => JList}
import test.org.apache.spark.sql.sources.v2._
-import org.apache.spark.{SparkConf, SparkException}
-import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
+import org.apache.spark.SparkException
+import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row}
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.sources.{Filter, GreaterThan}
import org.apache.spark.sql.sources.v2.reader._
import org.apache.spark.sql.sources.v2.reader.partitioning.{ClusteredDistribution, Distribution, Partitioning}
@@ -48,14 +50,72 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext {
}
test("advanced implementation") {
+ def getReader(query: DataFrame): AdvancedDataSourceV2#Reader = {
+ query.queryExecution.executedPlan.collect {
+ case d: DataSourceV2ScanExec => d.reader.asInstanceOf[AdvancedDataSourceV2#Reader]
+ }.head
+ }
+
+ def getJavaReader(query: DataFrame): JavaAdvancedDataSourceV2#Reader = {
+ query.queryExecution.executedPlan.collect {
+ case d: DataSourceV2ScanExec => d.reader.asInstanceOf[JavaAdvancedDataSourceV2#Reader]
+ }.head
+ }
+
Seq(classOf[AdvancedDataSourceV2], classOf[JavaAdvancedDataSourceV2]).foreach { cls =>
withClue(cls.getName) {
val df = spark.read.format(cls.getName).load()
checkAnswer(df, (0 until 10).map(i => Row(i, -i)))
- checkAnswer(df.select('j), (0 until 10).map(i => Row(-i)))
- checkAnswer(df.filter('i > 3), (4 until 10).map(i => Row(i, -i)))
- checkAnswer(df.select('j).filter('i > 6), (7 until 10).map(i => Row(-i)))
- checkAnswer(df.select('i).filter('i > 10), Nil)
+
+ val q1 = df.select('j)
+ checkAnswer(q1, (0 until 10).map(i => Row(-i)))
+ if (cls == classOf[AdvancedDataSourceV2]) {
+ val reader = getReader(q1)
+ assert(reader.filters.isEmpty)
+ assert(reader.requiredSchema.fieldNames === Seq("j"))
+ } else {
+ val reader = getJavaReader(q1)
+ assert(reader.filters.isEmpty)
+ assert(reader.requiredSchema.fieldNames === Seq("j"))
+ }
+
+ val q2 = df.filter('i > 3)
+ checkAnswer(q2, (4 until 10).map(i => Row(i, -i)))
+ if (cls == classOf[AdvancedDataSourceV2]) {
+ val reader = getReader(q2)
+ assert(reader.filters.flatMap(_.references).toSet == Set("i"))
+ assert(reader.requiredSchema.fieldNames === Seq("i", "j"))
+ } else {
+ val reader = getJavaReader(q2)
+ assert(reader.filters.flatMap(_.references).toSet == Set("i"))
+ assert(reader.requiredSchema.fieldNames === Seq("i", "j"))
+ }
+
+ val q3 = df.select('i).filter('i > 6)
+ checkAnswer(q3, (7 until 10).map(i => Row(i)))
+ if (cls == classOf[AdvancedDataSourceV2]) {
+ val reader = getReader(q3)
+ assert(reader.filters.flatMap(_.references).toSet == Set("i"))
+ assert(reader.requiredSchema.fieldNames === Seq("i"))
+ } else {
+ val reader = getJavaReader(q3)
+ assert(reader.filters.flatMap(_.references).toSet == Set("i"))
+ assert(reader.requiredSchema.fieldNames === Seq("i"))
+ }
+
+ val q4 = df.select('j).filter('j < -10)
+ checkAnswer(q4, Nil)
+ if (cls == classOf[AdvancedDataSourceV2]) {
+ val reader = getReader(q4)
+ // 'j < -10 is not supported by the testing data source.
+ assert(reader.filters.isEmpty)
+ assert(reader.requiredSchema.fieldNames === Seq("j"))
+ } else {
+ val reader = getJavaReader(q4)
+ // 'j < -10 is not supported by the testing data source.
+ assert(reader.filters.isEmpty)
+ assert(reader.requiredSchema.fieldNames === Seq("j"))
+ }
}
}
}
@@ -223,6 +283,39 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext {
val df2 = df.select(($"i" + 1).as("k"), $"j")
checkAnswer(df.join(df2, "j"), (0 until 10).map(i => Row(-i, i, i + 1)))
}
+
+ test("SPARK-23301: column pruning with arbitrary expressions") {
+ def getReader(query: DataFrame): AdvancedDataSourceV2#Reader = {
+ query.queryExecution.executedPlan.collect {
+ case d: DataSourceV2ScanExec => d.reader.asInstanceOf[AdvancedDataSourceV2#Reader]
+ }.head
+ }
+
+ val df = spark.read.format(classOf[AdvancedDataSourceV2].getName).load()
+
+ val q1 = df.select('i + 1)
+ checkAnswer(q1, (1 until 11).map(i => Row(i)))
+ val reader1 = getReader(q1)
+ assert(reader1.requiredSchema.fieldNames === Seq("i"))
+
+ val q2 = df.select(lit(1))
+ checkAnswer(q2, (0 until 10).map(i => Row(1)))
+ val reader2 = getReader(q2)
+ assert(reader2.requiredSchema.isEmpty)
+
+ // 'j === -1 can't be pushed down, but we should still be able to do column pruning
+ val q3 = df.filter('j === -1).select('j * 2)
+ checkAnswer(q3, Row(-2))
+ val reader3 = getReader(q3)
+ assert(reader3.filters.isEmpty)
+ assert(reader3.requiredSchema.fieldNames === Seq("j"))
+
+ // column pruning should work with other operators.
+ val q4 = df.sort('i).limit(1).select('i + 1)
+ checkAnswer(q4, Row(1))
+ val reader4 = getReader(q4)
+ assert(reader4.requiredSchema.fieldNames === Seq("i"))
+ }
}
class SimpleDataSourceV2 extends DataSourceV2 with ReadSupport {
@@ -270,8 +363,12 @@ class AdvancedDataSourceV2 extends DataSourceV2 with ReadSupport {
}
override def pushFilters(filters: Array[Filter]): Array[Filter] = {
- this.filters = filters
- Array.empty
+ val (supported, unsupported) = filters.partition {
+ case GreaterThan("i", _: Int) => true
+ case _ => false
+ }
+ this.filters = supported
+ unsupported
}
override def pushedFilters(): Array[Filter] = filters
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org