You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by gu...@apache.org on 2020/03/20 03:17:20 UTC

[spark] branch branch-2.4 updated: [SPARK-26293][SQL][2.4] Cast exception when having python udf in subquery

This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-2.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-2.4 by this push:
     new 244405f  [SPARK-26293][SQL][2.4] Cast exception when having python udf in subquery
244405f is described below

commit 244405fe57d7737d81c34ba9e8917df6285889eb
Author: Wenchen Fan <we...@databricks.com>
AuthorDate: Fri Mar 20 12:16:38 2020 +0900

    [SPARK-26293][SQL][2.4] Cast exception when having python udf in subquery
    
    ## What changes were proposed in this pull request?
    
    This PR backports https://github.com/apache/spark/pull/23248, which appears to have been mistakenly left out of branch-2.4.
    
    This is a regression introduced by https://github.com/apache/spark/pull/22104 at Spark 2.4.0.
    
    When a Python UDF is used in a subquery, we hit the following exception:
    ```
    Caused by: java.lang.ClassCastException: org.apache.spark.sql.catalyst.expressions.AttributeReference cannot be cast to org.apache.spark.sql.catalyst.expressions.PythonUDF
    	at scala.collection.immutable.Stream.map(Stream.scala:414)
    	at org.apache.spark.sql.execution.python.EvalPythonExec.$anonfun$doExecute$2(EvalPythonExec.scala:98)
    	at org.apache.spark.rdd.RDD.$anonfun$mapPartitions$2(RDD.scala:815)
    ...
    ```
    
    https://github.com/apache/spark/pull/22104 turned `ExtractPythonUDFs` from a physical rule into an optimizer rule. However, the two kinds of rules differ: a physical rule always runs exactly once, whereas an optimizer rule may be applied twice to the same query tree, even if the rule is located in a batch that only runs once.
    
    For a subquery, the `OptimizeSubqueries` rule executes the entire optimizer on the query plan inside the subquery. Later on, the subquery is rewritten into joins, and the optimizer rules are applied to it again.
    
    Unfortunately, the `ExtractPythonUDFs` rule is not idempotent. When it is applied twice to a query plan inside a subquery, it produces a malformed plan, because on the second pass it tries to extract Python UDFs from the Python exec plans created by the first pass.
    
    This PR proposes two complementary changes, for extra safety:
    1. `ExtractPythonUDFs` should skip python exec plans, to make the rule idempotent
    2. `ExtractPythonUDFs` should skip subqueries
    
    ## How was this patch tested?
    
    Added a new test.
    
    Closes #27960 from HyukjinKwon/backport-SPARK-26293.
    
    Lead-authored-by: Wenchen Fan <we...@databricks.com>
    Co-authored-by: HyukjinKwon <gu...@apache.org>
    Signed-off-by: HyukjinKwon <gu...@apache.org>
---
 python/pyspark/sql/tests.py                        | 54 +++++++++-------------
 .../sql/execution/python/ArrowEvalPythonExec.scala |  8 +++-
 .../sql/execution/python/BatchEvalPythonExec.scala |  8 +++-
 .../sql/execution/python/ExtractPythonUDFs.scala   | 18 ++++++--
 4 files changed, 49 insertions(+), 39 deletions(-)

diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 949ce3b..0284267 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -88,7 +88,7 @@ from pyspark.sql.types import _array_signed_int_typecode_ctype_mappings, _array_
 from pyspark.sql.types import _array_unsigned_int_typecode_ctype_mappings
 from pyspark.sql.types import _merge_type
 from pyspark.tests import QuietTest, ReusedPySparkTestCase, PySparkTestCase, SparkSubmitTests
-from pyspark.sql.functions import UserDefinedFunction, sha2, lit, input_file_name
+from pyspark.sql.functions import UserDefinedFunction, sha2, lit, input_file_name, udf
 from pyspark.sql.window import Window
 from pyspark.sql.utils import AnalysisException, ParseException, IllegalArgumentException
 
@@ -457,7 +457,6 @@ class SQLTests(ReusedSQLTestCase):
 
     def test_nondeterministic_udf(self):
         # Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations
-        from pyspark.sql.functions import udf
         import random
         udf_random_col = udf(lambda: int(100 * random.random()), IntegerType()).asNondeterministic()
         self.assertEqual(udf_random_col.deterministic, False)
@@ -468,7 +467,6 @@ class SQLTests(ReusedSQLTestCase):
 
     def test_nondeterministic_udf2(self):
         import random
-        from pyspark.sql.functions import udf
         random_udf = udf(lambda: random.randint(6, 6), IntegerType()).asNondeterministic()
         self.assertEqual(random_udf.deterministic, False)
         random_udf1 = self.spark.catalog.registerFunction("randInt", random_udf)
@@ -487,7 +485,6 @@ class SQLTests(ReusedSQLTestCase):
 
     def test_nondeterministic_udf3(self):
         # regression test for SPARK-23233
-        from pyspark.sql.functions import udf
         f = udf(lambda x: x)
         # Here we cache the JVM UDF instance.
         self.spark.range(1).select(f("id"))
@@ -499,7 +496,7 @@ class SQLTests(ReusedSQLTestCase):
         self.assertFalse(deterministic)
 
     def test_nondeterministic_udf_in_aggregate(self):
-        from pyspark.sql.functions import udf, sum
+        from pyspark.sql.functions import sum
         import random
         udf_random_col = udf(lambda: int(100 * random.random()), 'int').asNondeterministic()
         df = self.spark.range(10)
@@ -536,7 +533,6 @@ class SQLTests(ReusedSQLTestCase):
         self.assertEqual(tuple(row), (6, 5))
 
     def test_udf_in_filter_on_top_of_outer_join(self):
-        from pyspark.sql.functions import udf
         left = self.spark.createDataFrame([Row(a=1)])
         right = self.spark.createDataFrame([Row(a=1)])
         df = left.join(right, on='a', how='left_outer')
@@ -545,7 +541,6 @@ class SQLTests(ReusedSQLTestCase):
 
     def test_udf_in_filter_on_top_of_join(self):
         # regression test for SPARK-18589
-        from pyspark.sql.functions import udf
         left = self.spark.createDataFrame([Row(a=1)])
         right = self.spark.createDataFrame([Row(b=1)])
         f = udf(lambda a, b: a == b, BooleanType())
@@ -554,7 +549,6 @@ class SQLTests(ReusedSQLTestCase):
 
     def test_udf_in_join_condition(self):
         # regression test for SPARK-25314
-        from pyspark.sql.functions import udf
         left = self.spark.createDataFrame([Row(a=1)])
         right = self.spark.createDataFrame([Row(b=1)])
         f = udf(lambda a, b: a == b, BooleanType())
@@ -566,7 +560,7 @@ class SQLTests(ReusedSQLTestCase):
 
     def test_udf_in_left_outer_join_condition(self):
         # regression test for SPARK-26147
-        from pyspark.sql.functions import udf, col
+        from pyspark.sql.functions import col
         left = self.spark.createDataFrame([Row(a=1)])
         right = self.spark.createDataFrame([Row(b=1)])
         f = udf(lambda a: str(a), StringType())
@@ -579,7 +573,6 @@ class SQLTests(ReusedSQLTestCase):
     def test_udf_and_common_filter_in_join_condition(self):
         # regression test for SPARK-25314
         # test the complex scenario with both udf and common filter
-        from pyspark.sql.functions import udf
         left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
         right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)])
         f = udf(lambda a, b: a == b, BooleanType())
@@ -590,7 +583,6 @@ class SQLTests(ReusedSQLTestCase):
     def test_udf_not_supported_in_join_condition(self):
         # regression test for SPARK-25314
         # test python udf is not supported in join type except inner join.
-        from pyspark.sql.functions import udf
         left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
         right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)])
         f = udf(lambda a, b: a == b, BooleanType())
@@ -632,7 +624,7 @@ class SQLTests(ReusedSQLTestCase):
 
     def test_udf_with_filter_function(self):
         df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
-        from pyspark.sql.functions import udf, col
+        from pyspark.sql.functions import col
         from pyspark.sql.types import BooleanType
 
         my_filter = udf(lambda a: a < 2, BooleanType())
@@ -641,7 +633,7 @@ class SQLTests(ReusedSQLTestCase):
 
     def test_udf_with_aggregate_function(self):
         df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
-        from pyspark.sql.functions import udf, col, sum
+        from pyspark.sql.functions import col, sum
         from pyspark.sql.types import BooleanType
 
         my_filter = udf(lambda a: a == 1, BooleanType())
@@ -657,7 +649,7 @@ class SQLTests(ReusedSQLTestCase):
         self.assertEqual(sel.collect(), [Row(t=4), Row(t=3)])
 
     def test_udf_in_generate(self):
-        from pyspark.sql.functions import udf, explode
+        from pyspark.sql.functions import explode
         df = self.spark.range(5)
         f = udf(lambda x: list(range(x)), ArrayType(LongType()))
         row = df.select(explode(f(*df))).groupBy().sum().first()
@@ -684,7 +676,6 @@ class SQLTests(ReusedSQLTestCase):
         self.assertEqual(res[3][1], 1)
 
     def test_udf_with_order_by_and_limit(self):
-        from pyspark.sql.functions import udf
         my_copy = udf(lambda x: x, IntegerType())
         df = self.spark.range(10).orderBy("id")
         res = df.select(df.id, my_copy(df.id).alias("copy")).limit(1)
@@ -803,14 +794,14 @@ class SQLTests(ReusedSQLTestCase):
         self.assertEqual(2, df.count())
 
     def test_udf_with_input_file_name(self):
-        from pyspark.sql.functions import udf, input_file_name
+        from pyspark.sql.functions import input_file_name
         sourceFile = udf(lambda path: path, StringType())
         filePath = "python/test_support/sql/people1.json"
         row = self.spark.read.json(filePath).select(sourceFile(input_file_name())).first()
         self.assertTrue(row[0].find("people1.json") != -1)
 
     def test_udf_with_input_file_name_for_hadooprdd(self):
-        from pyspark.sql.functions import udf, input_file_name
+        from pyspark.sql.functions import input_file_name
 
         def filename(path):
             return path
@@ -859,9 +850,6 @@ class SQLTests(ReusedSQLTestCase):
         # This is separate of  UDFInitializationTests
         # to avoid context initialization
         # when udf is called
-
-        from pyspark.sql.functions import UserDefinedFunction
-
         f = UserDefinedFunction(lambda x: x, StringType())
 
         self.assertIsNone(
@@ -877,8 +865,6 @@ class SQLTests(ReusedSQLTestCase):
         )
 
     def test_udf_with_string_return_type(self):
-        from pyspark.sql.functions import UserDefinedFunction
-
         add_one = UserDefinedFunction(lambda x: x + 1, "integer")
         make_pair = UserDefinedFunction(lambda x: (-x, x), "struct<x:integer,y:integer>")
         make_array = UserDefinedFunction(
@@ -892,13 +878,11 @@ class SQLTests(ReusedSQLTestCase):
         self.assertTupleEqual(expected, actual)
 
     def test_udf_shouldnt_accept_noncallable_object(self):
-        from pyspark.sql.functions import UserDefinedFunction
-
         non_callable = None
         self.assertRaises(TypeError, UserDefinedFunction, non_callable, StringType())
 
     def test_udf_with_decorator(self):
-        from pyspark.sql.functions import lit, udf
+        from pyspark.sql.functions import lit
         from pyspark.sql.types import IntegerType, DoubleType
 
         @udf(IntegerType())
@@ -955,7 +939,6 @@ class SQLTests(ReusedSQLTestCase):
         )
 
     def test_udf_wrapper(self):
-        from pyspark.sql.functions import udf
         from pyspark.sql.types import IntegerType
 
         def f(x):
@@ -991,7 +974,7 @@ class SQLTests(ReusedSQLTestCase):
         self.assertEqual(return_type, f_.returnType)
 
     def test_validate_column_types(self):
-        from pyspark.sql.functions import udf, to_json
+        from pyspark.sql.functions import to_json
         from pyspark.sql.column import _to_java_column
 
         self.assertTrue("Column" in _to_java_column("a").getClass().toString())
@@ -3459,7 +3442,7 @@ class SQLTests(ReusedSQLTestCase):
     # SPARK-24721
     @unittest.skipIf(not _test_compiled, _test_not_compiled_message)
     def test_datasource_with_udf(self):
-        from pyspark.sql.functions import udf, lit, col
+        from pyspark.sql.functions import lit, col
 
         path = tempfile.mkdtemp()
         shutil.rmtree(path)
@@ -3571,8 +3554,6 @@ class SQLTests(ReusedSQLTestCase):
 
     # SPARK-25591
     def test_same_accumulator_in_udfs(self):
-        from pyspark.sql.functions import udf
-
         data_schema = StructType([StructField("a", IntegerType(), True),
                                   StructField("b", IntegerType(), True)])
         data = self.spark.createDataFrame([[1, 2]], schema=data_schema)
@@ -3594,6 +3575,17 @@ class SQLTests(ReusedSQLTestCase):
         data.collect()
         self.assertEqual(test_accum.value, 101)
 
+    # SPARK-26293
+    def test_udf_in_subquery(self):
+        f = udf(lambda x: x, "long")
+        try:
+            self.spark.range(1).filter(f("id") >= 0).createTempView("v")
+            sql = self.spark.sql
+            result = sql("select i from values(0L) as data(i) where i in (select id from v)")
+            self.assertEqual(result.collect(), [Row(i=0)])
+        finally:
+            self.spark.catalog.dropTempView("v")
+
 
 class HiveSparkSubmitTests(SparkSubmitTests):
 
@@ -3771,8 +3763,6 @@ class UDFInitializationTests(unittest.TestCase):
             SparkContext._active_spark_context.stop()
 
     def test_udf_init_shouldnt_initialize_context(self):
-        from pyspark.sql.functions import UserDefinedFunction
-
         UserDefinedFunction(lambda x: x, StringType())
 
         self.assertIsNone(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
index 2b87796..a5203da 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
@@ -60,8 +60,12 @@ private class BatchIterator[T](iter: Iterator[T], batchSize: Int)
 /**
  * A logical plan that evaluates a [[PythonUDF]].
  */
-case class ArrowEvalPython(udfs: Seq[PythonUDF], output: Seq[Attribute], child: LogicalPlan)
-  extends UnaryNode
+case class ArrowEvalPython(
+    udfs: Seq[PythonUDF],
+    output: Seq[Attribute],
+    child: LogicalPlan) extends UnaryNode {
+  override def producedAttributes: AttributeSet = AttributeSet(output.drop(child.output.length))
+}
 
 /**
  * A physical plan that evaluates a [[PythonUDF]].
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
index b08b7e6..d3736d2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
@@ -32,8 +32,12 @@ import org.apache.spark.sql.types.{StructField, StructType}
 /**
  * A logical plan that evaluates a [[PythonUDF]]
  */
-case class BatchEvalPython(udfs: Seq[PythonUDF], output: Seq[Attribute], child: LogicalPlan)
-  extends UnaryNode
+case class BatchEvalPython(
+    udfs: Seq[PythonUDF],
+    output: Seq[Attribute],
+    child: LogicalPlan) extends UnaryNode {
+  override def producedAttributes: AttributeSet = AttributeSet(output.drop(child.output.length))
+}
 
 /**
  * A physical plan that evaluates a [[PythonUDF]]
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
index 90b5325..380c31b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
@@ -24,7 +24,7 @@ import org.apache.spark.api.python.PythonEvalType
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.Rule
 
 
@@ -131,8 +131,20 @@ object ExtractPythonUDFs extends Rule[LogicalPlan] with PredicateHelper {
     expressions.flatMap(collectEvaluableUDFs)
   }
 
-  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
-    case plan: LogicalPlan => extract(plan)
+  def apply(plan: LogicalPlan): LogicalPlan = plan match {
+    // SPARK-26293: A subquery will be rewritten into join later, and will go through this rule
+    // eventually. Here we skip subquery, as Python UDF only needs to be extracted once.
+    case _: Subquery => plan
+
+    case _ => plan transformUp {
+      // A safe guard. `ExtractPythonUDFs` only runs once, so we will not hit `BatchEvalPython` and
+      // `ArrowEvalPython` in the input plan. However if we hit them, we must skip them, as we can't
+      // extract Python UDFs from them.
+      case p: BatchEvalPython => p
+      case p: ArrowEvalPython => p
+
+      case plan: LogicalPlan => extract(plan)
+    }
   }
 
   /**


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org