You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by li...@apache.org on 2018/02/02 04:44:50 UTC

spark git commit: [SPARK-23301][SQL] data source column pruning should work for arbitrary expressions

Repository: spark
Updated Branches:
  refs/heads/master b3a04283f -> 19c7c7ebd


[SPARK-23301][SQL] data source column pruning should work for arbitrary expressions

## What changes were proposed in this pull request?

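This PR fixes a mistake in the `PushDownOperatorsToDataSource` rule: its column pruning logic handles `Project` incorrectly, so required columns referenced only inside arbitrary expressions were not propagated to the data source.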

## How was this patch tested?

Added a new test case for column pruning with arbitrary expressions, and improved the existing tests to make sure `PushDownOperatorsToDataSource` really works.

Author: Wenchen Fan <we...@databricks.com>

Closes #20476 from cloud-fan/push-down.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/19c7c7eb
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/19c7c7eb
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/19c7c7eb

Branch: refs/heads/master
Commit: 19c7c7ebdef6c1c7a02ebac9af6a24f521b52c37
Parents: b3a0428
Author: Wenchen Fan <we...@databricks.com>
Authored: Thu Feb 1 20:44:46 2018 -0800
Committer: gatorsmile <ga...@gmail.com>
Committed: Thu Feb 1 20:44:46 2018 -0800

----------------------------------------------------------------------
 .../v2/PushDownOperatorsToDataSource.scala      |  53 +++++----
 .../sources/v2/JavaAdvancedDataSourceV2.java    |  29 ++++-
 .../sql/sources/v2/DataSourceV2Suite.scala      | 113 +++++++++++++++++--
 3 files changed, 155 insertions(+), 40 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/19c7c7eb/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala
index df034ad..566a483 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.execution.datasources.v2
 
-import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeMap, Expression, NamedExpression, PredicateHelper}
+import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeMap, AttributeSet, Expression, NamedExpression, PredicateHelper}
 import org.apache.spark.sql.catalyst.optimizer.RemoveRedundantProject
 import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project}
 import org.apache.spark.sql.catalyst.rules.Rule
@@ -81,35 +81,34 @@ object PushDownOperatorsToDataSource extends Rule[LogicalPlan] with PredicateHel
 
     // TODO: add more push down rules.
 
-    // TODO: nested fields pruning
-    def pushDownRequiredColumns(plan: LogicalPlan, requiredByParent: Seq[Attribute]): Unit = {
-      plan match {
-        case Project(projectList, child) =>
-          val required = projectList.filter(requiredByParent.contains).flatMap(_.references)
-          pushDownRequiredColumns(child, required)
-
-        case Filter(condition, child) =>
-          val required = requiredByParent ++ condition.references
-          pushDownRequiredColumns(child, required)
-
-        case DataSourceV2Relation(fullOutput, reader) => reader match {
-          case r: SupportsPushDownRequiredColumns =>
-            // Match original case of attributes.
-            val attrMap = AttributeMap(fullOutput.zip(fullOutput))
-            val requiredColumns = requiredByParent.map(attrMap)
-            r.pruneColumns(requiredColumns.toStructType)
-          case _ =>
-        }
+    pushDownRequiredColumns(filterPushed, filterPushed.outputSet)
+    // After column pruning, we may have redundant PROJECT nodes in the query plan, remove them.
+    RemoveRedundantProject(filterPushed)
+  }
+
+  // TODO: nested fields pruning
+  private def pushDownRequiredColumns(plan: LogicalPlan, requiredByParent: AttributeSet): Unit = {
+    plan match {
+      case Project(projectList, child) =>
+        val required = projectList.flatMap(_.references)
+        pushDownRequiredColumns(child, AttributeSet(required))
+
+      case Filter(condition, child) =>
+        val required = requiredByParent ++ condition.references
+        pushDownRequiredColumns(child, required)
 
-        // TODO: there may be more operators can be used to calculate required columns, we can add
-        // more and more in the future.
-        case _ => plan.children.foreach(child => pushDownRequiredColumns(child, child.output))
+      case relation: DataSourceV2Relation => relation.reader match {
+        case reader: SupportsPushDownRequiredColumns =>
+          val requiredColumns = relation.output.filter(requiredByParent.contains)
+          reader.pruneColumns(requiredColumns.toStructType)
+
+        case _ =>
       }
-    }
 
-    pushDownRequiredColumns(filterPushed, filterPushed.output)
-    // After column pruning, we may have redundant PROJECT nodes in the query plan, remove them.
-    RemoveRedundantProject(filterPushed)
+      // TODO: there may be more operators that can be used to calculate the required columns. We
+      // can add more and more in the future.
+      case _ => plan.children.foreach(child => pushDownRequiredColumns(child, child.outputSet))
+    }
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/19c7c7eb/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java
----------------------------------------------------------------------
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java
index d421f7d..172e5d5 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java
@@ -32,11 +32,12 @@ import org.apache.spark.sql.types.StructType;
 
 public class JavaAdvancedDataSourceV2 implements DataSourceV2, ReadSupport {
 
-  class Reader implements DataSourceReader, SupportsPushDownRequiredColumns,
+  public class Reader implements DataSourceReader, SupportsPushDownRequiredColumns,
       SupportsPushDownFilters {
 
-    private StructType requiredSchema = new StructType().add("i", "int").add("j", "int");
-    private Filter[] filters = new Filter[0];
+    // Exposed for testing.
+    public StructType requiredSchema = new StructType().add("i", "int").add("j", "int");
+    public Filter[] filters = new Filter[0];
 
     @Override
     public StructType readSchema() {
@@ -50,8 +51,26 @@ public class JavaAdvancedDataSourceV2 implements DataSourceV2, ReadSupport {
 
     @Override
     public Filter[] pushFilters(Filter[] filters) {
-      this.filters = filters;
-      return new Filter[0];
+      Filter[] supported = Arrays.stream(filters).filter(f -> {
+        if (f instanceof GreaterThan) {
+          GreaterThan gt = (GreaterThan) f;
+          return gt.attribute().equals("i") && gt.value() instanceof Integer;
+        } else {
+          return false;
+        }
+      }).toArray(Filter[]::new);
+
+      Filter[] unsupported = Arrays.stream(filters).filter(f -> {
+        if (f instanceof GreaterThan) {
+          GreaterThan gt = (GreaterThan) f;
+          return !gt.attribute().equals("i") || !(gt.value() instanceof Integer);
+        } else {
+          return true;
+        }
+      }).toArray(Filter[]::new);
+
+      this.filters = supported;
+      return unsupported;
     }
 
     @Override

http://git-wip-us.apache.org/repos/asf/spark/blob/19c7c7eb/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala
index 23147ff..eccd454 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala
@@ -21,11 +21,13 @@ import java.util.{ArrayList, List => JList}
 
 import test.org.apache.spark.sql.sources.v2._
 
-import org.apache.spark.{SparkConf, SparkException}
-import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
+import org.apache.spark.SparkException
+import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row}
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec
 import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
 import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
+import org.apache.spark.sql.functions._
 import org.apache.spark.sql.sources.{Filter, GreaterThan}
 import org.apache.spark.sql.sources.v2.reader._
 import org.apache.spark.sql.sources.v2.reader.partitioning.{ClusteredDistribution, Distribution, Partitioning}
@@ -48,14 +50,72 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext {
   }
 
   test("advanced implementation") {
+    def getReader(query: DataFrame): AdvancedDataSourceV2#Reader = {
+      query.queryExecution.executedPlan.collect {
+        case d: DataSourceV2ScanExec => d.reader.asInstanceOf[AdvancedDataSourceV2#Reader]
+      }.head
+    }
+
+    def getJavaReader(query: DataFrame): JavaAdvancedDataSourceV2#Reader = {
+      query.queryExecution.executedPlan.collect {
+        case d: DataSourceV2ScanExec => d.reader.asInstanceOf[JavaAdvancedDataSourceV2#Reader]
+      }.head
+    }
+
     Seq(classOf[AdvancedDataSourceV2], classOf[JavaAdvancedDataSourceV2]).foreach { cls =>
       withClue(cls.getName) {
         val df = spark.read.format(cls.getName).load()
         checkAnswer(df, (0 until 10).map(i => Row(i, -i)))
-        checkAnswer(df.select('j), (0 until 10).map(i => Row(-i)))
-        checkAnswer(df.filter('i > 3), (4 until 10).map(i => Row(i, -i)))
-        checkAnswer(df.select('j).filter('i > 6), (7 until 10).map(i => Row(-i)))
-        checkAnswer(df.select('i).filter('i > 10), Nil)
+
+        val q1 = df.select('j)
+        checkAnswer(q1, (0 until 10).map(i => Row(-i)))
+        if (cls == classOf[AdvancedDataSourceV2]) {
+          val reader = getReader(q1)
+          assert(reader.filters.isEmpty)
+          assert(reader.requiredSchema.fieldNames === Seq("j"))
+        } else {
+          val reader = getJavaReader(q1)
+          assert(reader.filters.isEmpty)
+          assert(reader.requiredSchema.fieldNames === Seq("j"))
+        }
+
+        val q2 = df.filter('i > 3)
+        checkAnswer(q2, (4 until 10).map(i => Row(i, -i)))
+        if (cls == classOf[AdvancedDataSourceV2]) {
+          val reader = getReader(q2)
+          assert(reader.filters.flatMap(_.references).toSet == Set("i"))
+          assert(reader.requiredSchema.fieldNames === Seq("i", "j"))
+        } else {
+          val reader = getJavaReader(q2)
+          assert(reader.filters.flatMap(_.references).toSet == Set("i"))
+          assert(reader.requiredSchema.fieldNames === Seq("i", "j"))
+        }
+
+        val q3 = df.select('i).filter('i > 6)
+        checkAnswer(q3, (7 until 10).map(i => Row(i)))
+        if (cls == classOf[AdvancedDataSourceV2]) {
+          val reader = getReader(q3)
+          assert(reader.filters.flatMap(_.references).toSet == Set("i"))
+          assert(reader.requiredSchema.fieldNames === Seq("i"))
+        } else {
+          val reader = getJavaReader(q3)
+          assert(reader.filters.flatMap(_.references).toSet == Set("i"))
+          assert(reader.requiredSchema.fieldNames === Seq("i"))
+        }
+
+        val q4 = df.select('j).filter('j < -10)
+        checkAnswer(q4, Nil)
+        if (cls == classOf[AdvancedDataSourceV2]) {
+          val reader = getReader(q4)
+          // 'j < -10 is not supported by the testing data source.
+          assert(reader.filters.isEmpty)
+          assert(reader.requiredSchema.fieldNames === Seq("j"))
+        } else {
+          val reader = getJavaReader(q4)
+          // 'j < -10 is not supported by the testing data source.
+          assert(reader.filters.isEmpty)
+          assert(reader.requiredSchema.fieldNames === Seq("j"))
+        }
       }
     }
   }
@@ -223,6 +283,39 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext {
     val df2 = df.select(($"i" + 1).as("k"), $"j")
     checkAnswer(df.join(df2, "j"), (0 until 10).map(i => Row(-i, i, i + 1)))
   }
+
+  test("SPARK-23301: column pruning with arbitrary expressions") {
+    def getReader(query: DataFrame): AdvancedDataSourceV2#Reader = {
+      query.queryExecution.executedPlan.collect {
+        case d: DataSourceV2ScanExec => d.reader.asInstanceOf[AdvancedDataSourceV2#Reader]
+      }.head
+    }
+
+    val df = spark.read.format(classOf[AdvancedDataSourceV2].getName).load()
+
+    val q1 = df.select('i + 1)
+    checkAnswer(q1, (1 until 11).map(i => Row(i)))
+    val reader1 = getReader(q1)
+    assert(reader1.requiredSchema.fieldNames === Seq("i"))
+
+    val q2 = df.select(lit(1))
+    checkAnswer(q2, (0 until 10).map(i => Row(1)))
+    val reader2 = getReader(q2)
+    assert(reader2.requiredSchema.isEmpty)
+
+    // 'j === -1 can't be pushed down, but we should still be able to do column pruning
+    val q3 = df.filter('j === -1).select('j * 2)
+    checkAnswer(q3, Row(-2))
+    val reader3 = getReader(q3)
+    assert(reader3.filters.isEmpty)
+    assert(reader3.requiredSchema.fieldNames === Seq("j"))
+
+    // column pruning should work with other operators.
+    val q4 = df.sort('i).limit(1).select('i + 1)
+    checkAnswer(q4, Row(1))
+    val reader4 = getReader(q4)
+    assert(reader4.requiredSchema.fieldNames === Seq("i"))
+  }
 }
 
 class SimpleDataSourceV2 extends DataSourceV2 with ReadSupport {
@@ -270,8 +363,12 @@ class AdvancedDataSourceV2 extends DataSourceV2 with ReadSupport {
     }
 
     override def pushFilters(filters: Array[Filter]): Array[Filter] = {
-      this.filters = filters
-      Array.empty
+      val (supported, unsupported) = filters.partition {
+        case GreaterThan("i", _: Int) => true
+        case _ => false
+      }
+      this.filters = supported
+      unsupported
     }
 
     override def pushedFilters(): Array[Filter] = filters


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org