You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by gu...@apache.org on 2019/05/27 12:40:31 UTC

[spark] branch master updated: [SPARK-27803][SQL][PYTHON] Fix column pruning for Python UDF

This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 6506616  [SPARK-27803][SQL][PYTHON] Fix column pruning for Python UDF
6506616 is described below

commit 6506616b978066a078ad866b3251cdf8fb95af42
Author: Wenchen Fan <we...@databricks.com>
AuthorDate: Mon May 27 21:39:59 2019 +0900

    [SPARK-27803][SQL][PYTHON] Fix column pruning for Python UDF
    
    ## What changes were proposed in this pull request?
    
    In https://github.com/apache/spark/pull/22104 , we create the python-eval nodes at the end of the optimization phase, which causes a problem.
    
    After the main optimization batch, Filter and Project nodes are usually pushed to the bottom, near the scan node. However, if we extract Python UDFs from Filter/Project, and create a python-eval node under Filter/Project, it will break column pruning/filter pushdown of the scan node.
    
    There are some hacks in the `ExtractPythonUDFs` rule that duplicate the column pruning and filter pushdown logic. However, they have some bugs, as demonstrated in the new test case (only column pruning is broken). This PR removes the hacks and re-applies the column pruning and filter pushdown rules explicitly.
    
    **Before:**
    
    ```
    ...
    == Analyzed Logical Plan ==
    a: bigint
    Project [a#168L]
    +- Filter dummyUDF(a#168L)
       +- Relation[a#168L,b#169L] parquet
    
    == Optimized Logical Plan ==
    Project [a#168L]
    +- Project [a#168L, b#169L]
       +- Filter pythonUDF0#174: boolean
          +- BatchEvalPython [dummyUDF(a#168L)], [a#168L, b#169L, pythonUDF0#174]
             +- Relation[a#168L,b#169L] parquet
    
    == Physical Plan ==
    *(2) Project [a#168L]
    +- *(2) Project [a#168L, b#169L]
       +- *(2) Filter pythonUDF0#174: boolean
          +- BatchEvalPython [dummyUDF(a#168L)], [a#168L, b#169L, pythonUDF0#174]
             +- *(1) FileScan parquet [a#168L,b#169L] Batched: true, DataFilters: [], Format: Parquet, Location: InMemoryFileIndex[file:/private/var/folders/_1/bzcp960d0hlb988k90654z2w0000gp/T/spark-798bae3c-a2..., PartitionFilters: [], PushedFilters: [], ReadSchema: struct<a:bigint,b:bigint>
    ```
    
    **After:**
    
    ```
    ...
    == Analyzed Logical Plan ==
    a: bigint
    Project [a#168L]
    +- Filter dummyUDF(a#168L)
       +- Relation[a#168L,b#169L] parquet
    
    == Optimized Logical Plan ==
    Project [a#168L]
    +- Filter pythonUDF0#174: boolean
       +- BatchEvalPython [dummyUDF(a#168L)], [pythonUDF0#174]
          +- Project [a#168L]
             +- Relation[a#168L,b#169L] parquet
    
    == Physical Plan ==
    *(2) Project [a#168L]
    +- *(2) Filter pythonUDF0#174: boolean
       +- BatchEvalPython [dummyUDF(a#168L)], [pythonUDF0#174]
          +- *(1) FileScan parquet [a#168L] Batched: true, DataFilters: [], Format: Parquet, Location: InMemoryFileIndex[file:/private/var/folders/_1/bzcp960d0hlb988k90654z2w0000gp/T/spark-9500cafb-78..., PartitionFilters: [], PushedFilters: [], ReadSchema: struct<a:bigint>
    ```
    
    ## How was this patch tested?
    
    new test
    
    Closes #24675 from cloud-fan/python.
    
    Authored-by: Wenchen Fan <we...@databricks.com>
    Signed-off-by: HyukjinKwon <gu...@apache.org>
---
 .../spark/sql/catalyst/expressions/PythonUDF.scala |  3 +-
 .../spark/sql/catalyst/optimizer/Optimizer.scala   |  2 ++
 .../plans/logical/pythonLogicalOperators.scala     | 29 +++++++++++++++-
 .../spark/sql/execution/SparkOptimizer.scala       | 15 +++++---
 .../sql/execution/python/ArrowEvalPythonExec.scala | 15 ++------
 .../sql/execution/python/BatchEvalPythonExec.scala | 15 ++------
 .../sql/execution/python/EvalPythonExec.scala      |  6 ++--
 .../sql/execution/python/ExtractPythonUDFs.scala   | 38 +++-----------------
 .../execution/python/ExtractPythonUDFsSuite.scala  | 40 ++++++++++++++++++++--
 9 files changed, 92 insertions(+), 71 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala
index 6530b17..2d82355 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala
@@ -45,7 +45,8 @@ object PythonUDF {
 }
 
 /**
- * A serialized version of a Python lambda function.
+ * A serialized version of a Python lambda function. This is a special expression, which needs a
+ * dedicated physical operator to execute it, and thus can't be pushed down to data sources.
  */
 case class PythonUDF(
     name: String,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index f32f2c7..5b59ac7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -1143,6 +1143,8 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper {
     case _: Repartition => true
     case _: ScriptTransformation => true
     case _: Sort => true
+    case _: BatchEvalPython => true
+    case _: ArrowEvalPython => true
     case _ => false
   }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
index 254687e..2df30a1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.catalyst.plans.logical
 
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, PythonUDF}
 
 /**
  * FlatMap groups using an udf: pandas.Dataframe -> pandas.DataFrame.
@@ -38,3 +38,30 @@ case class FlatMapGroupsInPandas(
    */
   override val producedAttributes = AttributeSet(output)
 }
+
+trait BaseEvalPython extends UnaryNode {
+
+  def udfs: Seq[PythonUDF]
+
+  def resultAttrs: Seq[Attribute]
+
+  override def output: Seq[Attribute] = child.output ++ resultAttrs
+
+  override def references: AttributeSet = AttributeSet(udfs.flatMap(_.references))
+}
+
+/**
+ * A logical plan that evaluates a [[PythonUDF]]
+ */
+case class BatchEvalPython(
+    udfs: Seq[PythonUDF],
+    resultAttrs: Seq[Attribute],
+    child: LogicalPlan) extends BaseEvalPython
+
+/**
+ * A logical plan that evaluates a [[PythonUDF]] with Apache Arrow.
+ */
+case class ArrowEvalPython(
+    udfs: Seq[PythonUDF],
+    resultAttrs: Seq[Attribute],
+    child: LogicalPlan) extends BaseEvalPython
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
index 31540e8..c35e5de 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution
 
 import org.apache.spark.sql.ExperimentalMethods
 import org.apache.spark.sql.catalyst.catalog.SessionCatalog
-import org.apache.spark.sql.catalyst.optimizer.Optimizer
+import org.apache.spark.sql.catalyst.optimizer.{ColumnPruning, Optimizer, PushDownPredicate, RemoveNoopOperators}
 import org.apache.spark.sql.execution.datasources.PruneFileSourcePartitions
 import org.apache.spark.sql.execution.datasources.SchemaPruning
 import org.apache.spark.sql.execution.python.{ExtractPythonUDFFromAggregate, ExtractPythonUDFs}
@@ -32,14 +32,21 @@ class SparkOptimizer(
   override def defaultBatches: Seq[Batch] = (preOptimizationBatches ++ super.defaultBatches :+
     Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog)) :+
     Batch("Extract Python UDFs", Once,
-      Seq(ExtractPythonUDFFromAggregate, ExtractPythonUDFs): _*) :+
+      ExtractPythonUDFFromAggregate,
+      ExtractPythonUDFs,
+      // The eval-python node may be between Project/Filter and the scan node, which breaks
+      // column pruning and filter push-down. Here we rerun the related optimizer rules.
+      ColumnPruning,
+      PushDownPredicate,
+      RemoveNoopOperators) :+
     Batch("Prune File Source Table Partitions", Once, PruneFileSourcePartitions) :+
     Batch("Schema Pruning", Once, SchemaPruning)) ++
     postHocOptimizationBatches :+
     Batch("User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*)
 
-  override def nonExcludableRules: Seq[String] =
-    super.nonExcludableRules :+ ExtractPythonUDFFromAggregate.ruleName
+  override def nonExcludableRules: Seq[String] = super.nonExcludableRules :+
+    ExtractPythonUDFFromAggregate.ruleName :+
+    ExtractPythonUDFs.ruleName
 
   /**
    * Optimization batches that are executed before the regular optimization batches (also before
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
index 61b167f..c114b1c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
@@ -23,7 +23,6 @@ import org.apache.spark.TaskContext
 import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode}
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.arrow.ArrowUtils
 import org.apache.spark.sql.types.StructType
@@ -58,20 +57,10 @@ private[spark] class BatchIterator[T](iter: Iterator[T], batchSize: Int)
 }
 
 /**
- * A logical plan that evaluates a [[PythonUDF]].
- */
-case class ArrowEvalPython(
-    udfs: Seq[PythonUDF],
-    output: Seq[Attribute],
-    child: LogicalPlan) extends UnaryNode {
-  override def producedAttributes: AttributeSet = AttributeSet(output.drop(child.output.length))
-}
-
-/**
  * A physical plan that evaluates a [[PythonUDF]].
  */
-case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan)
-  extends EvalPythonExec(udfs, output, child) {
+case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute], child: SparkPlan)
+  extends EvalPythonExec(udfs, resultAttrs, child) {
 
   private val batchSize = conf.arrowMaxRecordsPerBatch
   private val sessionLocalTimeZone = conf.sessionLocalTimeZone
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
index eff709e..4f35278 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
@@ -25,25 +25,14 @@ import org.apache.spark.TaskContext
 import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode}
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.types.{StructField, StructType}
 
 /**
- * A logical plan that evaluates a [[PythonUDF]]
- */
-case class BatchEvalPython(
-    udfs: Seq[PythonUDF],
-    output: Seq[Attribute],
-    child: LogicalPlan) extends UnaryNode {
-  override def producedAttributes: AttributeSet = AttributeSet(output.drop(child.output.length))
-}
-
-/**
  * A physical plan that evaluates a [[PythonUDF]]
  */
-case class BatchEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan)
-  extends EvalPythonExec(udfs, output, child) {
+case class BatchEvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute], child: SparkPlan)
+  extends EvalPythonExec(udfs, resultAttrs, child) {
 
   protected override def evaluate(
       funcs: Seq[ChainedPythonFunctions],
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala
index 67dcdd3..10e1d5e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala
@@ -57,10 +57,12 @@ import org.apache.spark.util.Utils
  * there should be always some rows buffered in the socket or Python process, so the pulling from
  * RowQueue ALWAYS happened after pushing into it.
  */
-abstract class EvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan)
+abstract class EvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute], child: SparkPlan)
   extends UnaryExecNode {
 
-  override def producedAttributes: AttributeSet = AttributeSet(output.drop(child.output.length))
+  override def output: Seq[Attribute] = child.output ++ resultAttrs
+
+  override def references: AttributeSet = AttributeSet(udfs.flatMap(_.references))
 
   private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = {
     udf.children match {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
index 380c31b..7f59d74 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
@@ -158,21 +158,9 @@ object ExtractPythonUDFs extends Rule[LogicalPlan] with PredicateHelper {
       // If there aren't any, we are done.
       plan
     } else {
-      val inputsForPlan = plan.references ++ plan.outputSet
-      val prunedChildren = plan.children.map { child =>
-        val allNeededOutput = inputsForPlan.intersect(child.outputSet).toSeq
-        if (allNeededOutput.length != child.output.length) {
-          Project(allNeededOutput, child)
-        } else {
-          child
-        }
-      }
-      val planWithNewChildren = plan.withNewChildren(prunedChildren)
-
       val attributeMap = mutable.HashMap[PythonUDF, Expression]()
-      val splitFilter = trySplitFilter(planWithNewChildren)
       // Rewrite the child that has the input required for the UDF
-      val newChildren = splitFilter.children.map { child =>
+      val newChildren = plan.children.map { child =>
         // Pick the UDF we are going to evaluate
         val validUdfs = udfs.filter { udf =>
           // Check to make sure that the UDF can be evaluated with only the input of this child.
@@ -191,9 +179,9 @@ object ExtractPythonUDFs extends Rule[LogicalPlan] with PredicateHelper {
             _.evalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF
           ) match {
             case (vectorizedUdfs, plainUdfs) if plainUdfs.isEmpty =>
-              ArrowEvalPython(vectorizedUdfs, child.output ++ resultAttrs, child)
+              ArrowEvalPython(vectorizedUdfs, resultAttrs, child)
             case (vectorizedUdfs, plainUdfs) if vectorizedUdfs.isEmpty =>
-              BatchEvalPython(plainUdfs, child.output ++ resultAttrs, child)
+              BatchEvalPython(plainUdfs, resultAttrs, child)
             case _ =>
               throw new AnalysisException(
                 "Expected either Scalar Pandas UDFs or Batched UDFs but got both")
@@ -211,7 +199,7 @@ object ExtractPythonUDFs extends Rule[LogicalPlan] with PredicateHelper {
         sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.")
       }
 
-      val rewritten = splitFilter.withNewChildren(newChildren).transformExpressions {
+      val rewritten = plan.withNewChildren(newChildren).transformExpressions {
         case p: PythonUDF if attributeMap.contains(p) =>
           attributeMap(p)
       }
@@ -226,22 +214,4 @@ object ExtractPythonUDFs extends Rule[LogicalPlan] with PredicateHelper {
       }
     }
   }
-
-  // Split the original FilterExec to two FilterExecs. Only push down the first few predicates
-  // that are all deterministic.
-  private def trySplitFilter(plan: LogicalPlan): LogicalPlan = {
-    plan match {
-      case filter: Filter =>
-        val (candidates, nonDeterministic) =
-          splitConjunctivePredicates(filter.condition).partition(_.deterministic)
-        val (pushDown, rest) = candidates.partition(!hasScalarPythonUDF(_))
-        if (pushDown.nonEmpty) {
-          val newChild = Filter(pushDown.reduceLeft(And), filter.child)
-          Filter((rest ++ nonDeterministic).reduceLeft(And), newChild)
-        } else {
-          filter
-        }
-      case o => o
-    }
-  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFsSuite.scala
index 76b609d..60cb40f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFsSuite.scala
@@ -17,13 +17,12 @@
 
 package org.apache.spark.sql.execution.python
 
-import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest}
+import org.apache.spark.sql.execution.{FileSourceScanExec, SparkPlan, SparkPlanTest}
 import org.apache.spark.sql.functions.col
 import org.apache.spark.sql.test.SharedSQLContext
 
 class ExtractPythonUDFsSuite extends SparkPlanTest with SharedSQLContext {
-  import testImplicits.newProductEncoder
-  import testImplicits.localSeqToDatasetHolder
+  import testImplicits._
 
   val batchedPythonUDF = new MyDummyPythonUDF
   val scalarPandasUDF = new MyDummyScalarPandasUDF
@@ -88,5 +87,40 @@ class ExtractPythonUDFsSuite extends SparkPlanTest with SharedSQLContext {
     assert(pythonEvalNodes.size == 2)
     assert(arrowEvalNodes.size == 2)
   }
+
+  test("Python UDF should not break column pruning/filter pushdown") {
+    withTempPath { f =>
+      spark.range(10).select($"id".as("a"), $"id".as("b"))
+        .write.parquet(f.getCanonicalPath)
+      val df = spark.read.parquet(f.getCanonicalPath)
+
+      withClue("column pruning") {
+        val query = df.filter(batchedPythonUDF($"a")).select($"a")
+
+        val pythonEvalNodes = collectBatchExec(query.queryExecution.executedPlan)
+        assert(pythonEvalNodes.length == 1)
+
+        val scanNodes = query.queryExecution.executedPlan.collect {
+          case scan: FileSourceScanExec => scan
+        }
+        assert(scanNodes.length == 1)
+        assert(scanNodes.head.output.map(_.name) == Seq("a"))
+      }
+
+      withClue("filter pushdown") {
+        val query = df.filter($"a" > 1 && batchedPythonUDF($"a"))
+        val pythonEvalNodes = collectBatchExec(query.queryExecution.executedPlan)
+        assert(pythonEvalNodes.length == 1)
+
+        val scanNodes = query.queryExecution.executedPlan.collect {
+          case scan: FileSourceScanExec => scan
+        }
+        assert(scanNodes.length == 1)
+        // 'a is not null and 'a > 1
+        assert(scanNodes.head.dataFilters.length == 2)
+        assert(scanNodes.head.dataFilters.flatMap(_.references.map(_.name)).distinct == Seq("a"))
+      }
+    }
+  }
 }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org