You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by gu...@apache.org on 2020/04/06 00:37:49 UTC

[spark] branch branch-3.0 updated: [SPARK-30921][PYSPARK] Predicates on python udf should not be pushdown through Aggregate

This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.0 by this push:
     new 2a5f0ab  [SPARK-30921][PYSPARK] Predicates on python udf should not be pushdown through Aggregate
2a5f0ab is described below

commit 2a5f0aba73fa7a933b605a4108001dca51a91eb5
Author: Liang-Chi Hsieh <vi...@gmail.com>
AuthorDate: Mon Apr 6 09:36:20 2020 +0900

    [SPARK-30921][PYSPARK] Predicates on python udf should not be pushdown through Aggregate
    
    ### What changes were proposed in this pull request?
    
    This patch prevents predicates on PythonUDFs (grouped aggregate Pandas UDFs) from being pushed down through Aggregate.
    
    ### Why are the changes needed?
    
    Predicates on PythonUDFs cannot be pushed down through Aggregate. Pushed-down predicates cannot be evaluated, because PythonUDFs cannot be evaluated in a Filter; doing so causes errors like:
    
    ```
    Caused by: java.lang.UnsupportedOperationException: Cannot generate code for expression: mean(input[1, struct<bar:bigint>, true].bar)
            at org.apache.spark.sql.catalyst.expressions.Unevaluable.doGenCode(Expression.scala:304)
            at org.apache.spark.sql.catalyst.expressions.Unevaluable.doGenCode$(Expression.scala:303)
            at org.apache.spark.sql.catalyst.expressions.PythonUDF.doGenCode(PythonUDF.scala:52)
            at org.apache.spark.sql.catalyst.expressions.Expression.$anonfun$genCode$3(Expression.scala:146)
            at scala.Option.getOrElse(Option.scala:189)
            at org.apache.spark.sql.catalyst.expressions.Expression.genCode(Expression.scala:141)
            at org.apache.spark.sql.catalyst.expressions.CastBase.doGenCode(Cast.scala:821)
            at org.apache.spark.sql.catalyst.expressions.Expression.$anonfun$genCode$3(Expression.scala:146)
            at scala.Option.getOrElse(Option.scala:189)
    ```
    
    ### Does this PR introduce any user-facing change?
    
    Yes. Previously, predicates on PythonUDFs were pushed down through Aggregate, which could cause errors. After this change, such queries work.
    
    ### How was this patch tested?
    
    Unit test.
    
    Closes #28089 from viirya/SPARK-30921.
    
    Authored-by: Liang-Chi Hsieh <vi...@gmail.com>
    Signed-off-by: HyukjinKwon <gu...@apache.org>
    (cherry picked from commit 1f0287148977adb416001cb0988e919a2698c8e0)
    Signed-off-by: HyukjinKwon <gu...@apache.org>
---
 python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py | 17 +++++++++++++++++
 .../apache/spark/sql/catalyst/optimizer/Optimizer.scala |  5 +++--
 2 files changed, 20 insertions(+), 2 deletions(-)

diff --git a/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py b/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py
index 2167978..224c8ce 100644
--- a/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py
+++ b/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py
@@ -491,6 +491,23 @@ class GroupedAggPandasUDFTests(ReusedSQLTestCase):
             agg2 = self.spark.sql("select max_udf(id) from table")
             assert_frame_equal(agg1.toPandas(), agg2.toPandas())
 
+    def test_no_predicate_pushdown_through(self):
+        # SPARK-30921: We should not pushdown predicates of PythonUDFs through Aggregate.
+        import numpy as np
+
+        @pandas_udf('float', PandasUDFType.GROUPED_AGG)
+        def mean(x):
+            return np.mean(x)
+
+        df = self.spark.createDataFrame([
+            Row(id=1, foo=42), Row(id=2, foo=1), Row(id=2, foo=2)
+        ])
+
+        agg = df.groupBy('id').agg(mean('foo').alias("mean"))
+        filtered = agg.filter(agg['mean'] > 40.0)
+
+        assert(filtered.collect()[0]["mean"] == 42.0)
+
 
 if __name__ == "__main__":
     from pyspark.sql.tests.test_pandas_udf_grouped_agg import *
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 30ad6bfe..d93c4a5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -1204,9 +1204,10 @@ object PushPredicateThroughNonJoin extends Rule[LogicalPlan] with PredicateHelpe
 
   def getAliasMap(plan: Aggregate): AttributeMap[Expression] = {
     // Find all the aliased expressions in the aggregate list that don't include any actual
-    // AggregateExpression, and create a map from the alias to the expression
+    // AggregateExpression or PythonUDF, and create a map from the alias to the expression
     val aliasMap = plan.aggregateExpressions.collect {
-      case a: Alias if a.child.find(_.isInstanceOf[AggregateExpression]).isEmpty =>
+      case a: Alias if a.child.find(e => e.isInstanceOf[AggregateExpression] ||
+          PythonUDF.isGroupedAggPandasUDF(e)).isEmpty =>
         (a.toAttribute, a.child)
     }
     AttributeMap(aliasMap)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org