You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by gu...@apache.org on 2018/01/24 11:29:51 UTC

spark git commit: [SPARK-23177][SQL][PYSPARK][BACKPORT-2.3] Extract zero-parameter UDFs from aggregate

Repository: spark
Updated Branches:
  refs/heads/branch-2.3 d656be74b -> 84a189a34


[SPARK-23177][SQL][PYSPARK][BACKPORT-2.3] Extract zero-parameter UDFs from aggregate

## What changes were proposed in this pull request?

The `ExtractPythonUDFFromAggregate` rule extracts Python UDFs in a logical aggregate that depend on an aggregate expression or a grouping key. However, Python UDFs that don't depend on any of those expressions should also be extracted, to avoid the issue reported in the JIRA.

A small code snippet to reproduce that issue looks like:
```python
import pyspark.sql.functions as f

df = spark.createDataFrame([(1,2), (3,4)])
f_udf = f.udf(lambda: str("const_str"))
df2 = df.distinct().withColumn("a", f_udf())
df2.show()
```

The following exception is raised:
```
: org.apache.spark.sql.catalyst.errors.package$TreeNodeException: Binding attribute, tree: pythonUDF0#50
        at org.apache.spark.sql.catalyst.errors.package$.attachTree(package.scala:56)
        at org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1.applyOrElse(BoundAttribute.scala:91)
        at org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1.applyOrElse(BoundAttribute.scala:90)
        at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$2.apply(TreeNode.scala:267)
        at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$2.apply(TreeNode.scala:267)
        at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:70)
        at org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:266)
        at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformDown$1.apply(TreeNode.scala:272)
        at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformDown$1.apply(TreeNode.scala:272)
        at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$4.apply(TreeNode.scala:306)
        at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:187)
        at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:304)
        at org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:272)
        at org.apache.spark.sql.catalyst.trees.TreeNode.transform(TreeNode.scala:256)
        at org.apache.spark.sql.catalyst.expressions.BindReferences$.bindReference(BoundAttribute.scala:90)
        at org.apache.spark.sql.execution.aggregate.HashAggregateExec$$anonfun$38.apply(HashAggregateExec.scala:514)
        at org.apache.spark.sql.execution.aggregate.HashAggregateExec$$anonfun$38.apply(HashAggregateExec.scala:513)
```

This exception is raised because `HashAggregateExec` tries to bind the aliased Python UDF expression (e.g., `pythonUDF0#50 AS a#44`) to the grouping key.

## How was this patch tested?

Added a test (`test_nonparam_udf_with_aggregate` in `python/pyspark/sql/tests.py`).

Author: Liang-Chi Hsieh <vi...@gmail.com>

Closes #20379 from viirya/SPARK-23177-backport-2.3.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/84a189a3
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/84a189a3
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/84a189a3

Branch: refs/heads/branch-2.3
Commit: 84a189a3429c64eeabc7b4e8fc0488ec16742002
Parents: d656be7
Author: Liang-Chi Hsieh <vi...@gmail.com>
Authored: Wed Jan 24 20:29:47 2018 +0900
Committer: hyukjinkwon <gu...@gmail.com>
Committed: Wed Jan 24 20:29:47 2018 +0900

----------------------------------------------------------------------
 python/pyspark/sql/tests.py                                  | 8 ++++++++
 .../spark/sql/execution/python/ExtractPythonUDFs.scala       | 5 +++--
 2 files changed, 11 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/84a189a3/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 4fee2ec..fbb18c4 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -1100,6 +1100,14 @@ class SQLTests(ReusedSQLTestCase):
         rows = [r[0] for r in df.selectExpr("udf(id)").take(2)]
         self.assertEqual(rows, [None, PythonOnlyPoint(1, 1)])
 
+    def test_nonparam_udf_with_aggregate(self):
+        import pyspark.sql.functions as f
+
+        df = self.spark.createDataFrame([(1, 2), (1, 2)])
+        f_udf = f.udf(lambda: "const_str")
+        rows = df.distinct().withColumn("a", f_udf()).collect()
+        self.assertEqual(rows, [Row(_1=1, _2=2, a=u'const_str')])
+
     def test_infer_schema_with_udt(self):
         from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
         row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))

http://git-wip-us.apache.org/repos/asf/spark/blob/84a189a3/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
index 2f53fe7..e6616d0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan}
 
 /**
  * Extracts all the Python UDFs in logical aggregate, which depends on aggregate expression or
- * grouping key, evaluate them after aggregate.
+ * grouping key, or doesn't depend on any above expressions, evaluate them after aggregate.
  */
 object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] {
 
@@ -44,7 +44,8 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] {
 
   private def hasPythonUdfOverAggregate(expr: Expression, agg: Aggregate): Boolean = {
     expr.find {
-      e => e.isInstanceOf[PythonUDF] && e.find(belongAggregate(_, agg)).isDefined
+      e => e.isInstanceOf[PythonUDF] &&
+        (e.references.isEmpty || e.find(belongAggregate(_, agg)).isDefined)
     }.isDefined
   }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org