You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by da...@apache.org on 2016/09/12 23:35:46 UTC

spark git commit: [SPARK-17474] [SQL] fix python udf in TakeOrderedAndProjectExec

Repository: spark
Updated Branches:
  refs/heads/master f9c580f11 -> a91ab705e


[SPARK-17474] [SQL] fix python udf in TakeOrderedAndProjectExec

## What changes were proposed in this pull request?

When there is a Python UDF in the Project between Sort and Limit, the Project is collected into TakeOrderedAndProjectExec, and ExtractPythonUDFs then fails to pull the Python UDFs out because QueryPlan.expressions does not include the expressions inside Option[Seq[Expression]].

Ideally, we should fix `QueryPlan.expressions`, but I tried with no luck (it always runs into an infinite loop). In this PR, I changed TakeOrderedAndProjectExec to not use Option[Seq[Expression]] to work around it. cc JoshRosen

## How was this patch tested?

Added regression test.

Author: Davies Liu <da...@databricks.com>

Closes #15030 from davies/all_expr.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/a91ab705
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/a91ab705
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/a91ab705

Branch: refs/heads/master
Commit: a91ab705e8c124aa116c3e5b1f3ba88ce832dcde
Parents: f9c580f
Author: Davies Liu <da...@databricks.com>
Authored: Mon Sep 12 16:35:42 2016 -0700
Committer: Davies Liu <da...@gmail.com>
Committed: Mon Sep 12 16:35:42 2016 -0700

----------------------------------------------------------------------
 python/pyspark/sql/tests.py                             |  8 ++++++++
 .../apache/spark/sql/execution/SparkStrategies.scala    |  8 ++++----
 .../scala/org/apache/spark/sql/execution/limit.scala    | 12 ++++++------
 .../sql/execution/TakeOrderedAndProjectSuite.scala      |  4 ++--
 4 files changed, 20 insertions(+), 12 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/a91ab705/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index fd8e9ce..769e454 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -376,6 +376,14 @@ class SQLTests(ReusedPySparkTestCase):
         row = df.select(explode(f(*df))).groupBy().sum().first()
         self.assertEqual(row[0], 10)
 
+    def test_udf_with_order_by_and_limit(self):
+        from pyspark.sql.functions import udf
+        my_copy = udf(lambda x: x, IntegerType())
+        df = self.spark.range(10).orderBy("id")
+        res = df.select(df.id, my_copy(df.id).alias("copy")).limit(1)
+        res.explain(True)
+        self.assertEqual(res.collect(), [Row(id=0, copy=0)])
+
     def test_basic_functions(self):
         rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
         df = self.spark.read.json(rdd)

http://git-wip-us.apache.org/repos/asf/spark/blob/a91ab705/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index c389593..3441ccf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -66,22 +66,22 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
     override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       case logical.ReturnAnswer(rootPlan) => rootPlan match {
         case logical.Limit(IntegerLiteral(limit), logical.Sort(order, true, child)) =>
-          execution.TakeOrderedAndProjectExec(limit, order, None, planLater(child)) :: Nil
+          execution.TakeOrderedAndProjectExec(limit, order, child.output, planLater(child)) :: Nil
         case logical.Limit(
             IntegerLiteral(limit),
             logical.Project(projectList, logical.Sort(order, true, child))) =>
           execution.TakeOrderedAndProjectExec(
-            limit, order, Some(projectList), planLater(child)) :: Nil
+            limit, order, projectList, planLater(child)) :: Nil
         case logical.Limit(IntegerLiteral(limit), child) =>
           execution.CollectLimitExec(limit, planLater(child)) :: Nil
         case other => planLater(other) :: Nil
       }
       case logical.Limit(IntegerLiteral(limit), logical.Sort(order, true, child)) =>
-        execution.TakeOrderedAndProjectExec(limit, order, None, planLater(child)) :: Nil
+        execution.TakeOrderedAndProjectExec(limit, order, child.output, planLater(child)) :: Nil
       case logical.Limit(
           IntegerLiteral(limit), logical.Project(projectList, logical.Sort(order, true, child))) =>
         execution.TakeOrderedAndProjectExec(
-          limit, order, Some(projectList), planLater(child)) :: Nil
+          limit, order, projectList, planLater(child)) :: Nil
       case _ => Nil
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/a91ab705/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
index 781c016..01fbe5b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
@@ -114,11 +114,11 @@ case class GlobalLimitExec(limit: Int, child: SparkPlan) extends BaseLimitExec {
 case class TakeOrderedAndProjectExec(
     limit: Int,
     sortOrder: Seq[SortOrder],
-    projectList: Option[Seq[NamedExpression]],
+    projectList: Seq[NamedExpression],
     child: SparkPlan) extends UnaryExecNode {
 
   override def output: Seq[Attribute] = {
-    projectList.map(_.map(_.toAttribute)).getOrElse(child.output)
+    projectList.map(_.toAttribute)
   }
 
   override def outputPartitioning: Partitioning = SinglePartition
@@ -126,8 +126,8 @@ case class TakeOrderedAndProjectExec(
   override def executeCollect(): Array[InternalRow] = {
     val ord = new LazilyGeneratedOrdering(sortOrder, child.output)
     val data = child.execute().map(_.copy()).takeOrdered(limit)(ord)
-    if (projectList.isDefined) {
-      val proj = UnsafeProjection.create(projectList.get, child.output)
+    if (projectList != child.output) {
+      val proj = UnsafeProjection.create(projectList, child.output)
       data.map(r => proj(r).copy())
     } else {
       data
@@ -148,8 +148,8 @@ case class TakeOrderedAndProjectExec(
         localTopK, child.output, SinglePartition, serializer))
     shuffled.mapPartitions { iter =>
       val topK = org.apache.spark.util.collection.Utils.takeOrdered(iter.map(_.copy()), limit)(ord)
-      if (projectList.isDefined) {
-        val proj = UnsafeProjection.create(projectList.get, child.output)
+      if (projectList != child.output) {
+        val proj = UnsafeProjection.create(projectList, child.output)
         topK.map(r => proj(r))
       } else {
         topK

http://git-wip-us.apache.org/repos/asf/spark/blob/a91ab705/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala
index 3217e34..7e317a4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala
@@ -59,7 +59,7 @@ class TakeOrderedAndProjectSuite extends SparkPlanTest with SharedSQLContext {
       checkThatPlansAgree(
         generateRandomInputData(),
         input =>
-          noOpFilter(TakeOrderedAndProjectExec(limit, sortOrder, None, input)),
+          noOpFilter(TakeOrderedAndProjectExec(limit, sortOrder, input.output, input)),
         input =>
           GlobalLimitExec(limit,
             LocalLimitExec(limit,
@@ -74,7 +74,7 @@ class TakeOrderedAndProjectSuite extends SparkPlanTest with SharedSQLContext {
         generateRandomInputData(),
         input =>
           noOpFilter(
-            TakeOrderedAndProjectExec(limit, sortOrder, Some(Seq(input.output.last)), input)),
+            TakeOrderedAndProjectExec(limit, sortOrder, Seq(input.output.last), input)),
         input =>
           GlobalLimitExec(limit,
             LocalLimitExec(limit,


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org