You are viewing a plain text version of this content. The canonical link for it is here.
Posted to reviews@spark.apache.org by GitBox <gi...@apache.org> on 2018/12/11 06:21:40 UTC

[GitHub] asfgit closed pull request #23248: [SPARK-26293][SQL] Cast exception when having python udf in subquery

asfgit closed pull request #23248: [SPARK-26293][SQL] Cast exception when having python udf in subquery
URL: https://github.com/apache/spark/pull/23248
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below, because GitHub would otherwise not display it after the merge:

diff --git a/python/pyspark/sql/tests/test_udf.py b/python/pyspark/sql/tests/test_udf.py
index ed298f724d551..12cf8c7de1dad 100644
--- a/python/pyspark/sql/tests/test_udf.py
+++ b/python/pyspark/sql/tests/test_udf.py
@@ -23,7 +23,7 @@
 
 from pyspark import SparkContext
 from pyspark.sql import SparkSession, Column, Row
-from pyspark.sql.functions import UserDefinedFunction
+from pyspark.sql.functions import UserDefinedFunction, udf
 from pyspark.sql.types import *
 from pyspark.sql.utils import AnalysisException
 from pyspark.testing.sqlutils import ReusedSQLTestCase, test_compiled, test_not_compiled_message
@@ -102,7 +102,6 @@ def test_udf_registration_return_type_not_none(self):
 
     def test_nondeterministic_udf(self):
         # Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations
-        from pyspark.sql.functions import udf
         import random
         udf_random_col = udf(lambda: int(100 * random.random()), IntegerType()).asNondeterministic()
         self.assertEqual(udf_random_col.deterministic, False)
@@ -113,7 +112,6 @@ def test_nondeterministic_udf(self):
 
     def test_nondeterministic_udf2(self):
         import random
-        from pyspark.sql.functions import udf
         random_udf = udf(lambda: random.randint(6, 6), IntegerType()).asNondeterministic()
         self.assertEqual(random_udf.deterministic, False)
         random_udf1 = self.spark.catalog.registerFunction("randInt", random_udf)
@@ -132,7 +130,6 @@ def test_nondeterministic_udf2(self):
 
     def test_nondeterministic_udf3(self):
         # regression test for SPARK-23233
-        from pyspark.sql.functions import udf
         f = udf(lambda x: x)
         # Here we cache the JVM UDF instance.
         self.spark.range(1).select(f("id"))
@@ -144,7 +141,7 @@ def test_nondeterministic_udf3(self):
         self.assertFalse(deterministic)
 
     def test_nondeterministic_udf_in_aggregate(self):
-        from pyspark.sql.functions import udf, sum
+        from pyspark.sql.functions import sum
         import random
         udf_random_col = udf(lambda: int(100 * random.random()), 'int').asNondeterministic()
         df = self.spark.range(10)
@@ -181,7 +178,6 @@ def test_multiple_udfs(self):
         self.assertEqual(tuple(row), (6, 5))
 
     def test_udf_in_filter_on_top_of_outer_join(self):
-        from pyspark.sql.functions import udf
         left = self.spark.createDataFrame([Row(a=1)])
         right = self.spark.createDataFrame([Row(a=1)])
         df = left.join(right, on='a', how='left_outer')
@@ -190,7 +186,6 @@ def test_udf_in_filter_on_top_of_outer_join(self):
 
     def test_udf_in_filter_on_top_of_join(self):
         # regression test for SPARK-18589
-        from pyspark.sql.functions import udf
         left = self.spark.createDataFrame([Row(a=1)])
         right = self.spark.createDataFrame([Row(b=1)])
         f = udf(lambda a, b: a == b, BooleanType())
@@ -199,7 +194,6 @@ def test_udf_in_filter_on_top_of_join(self):
 
     def test_udf_in_join_condition(self):
         # regression test for SPARK-25314
-        from pyspark.sql.functions import udf
         left = self.spark.createDataFrame([Row(a=1)])
         right = self.spark.createDataFrame([Row(b=1)])
         f = udf(lambda a, b: a == b, BooleanType())
@@ -211,7 +205,7 @@ def test_udf_in_join_condition(self):
 
     def test_udf_in_left_outer_join_condition(self):
         # regression test for SPARK-26147
-        from pyspark.sql.functions import udf, col
+        from pyspark.sql.functions import col
         left = self.spark.createDataFrame([Row(a=1)])
         right = self.spark.createDataFrame([Row(b=1)])
         f = udf(lambda a: str(a), StringType())
@@ -223,7 +217,6 @@ def test_udf_in_left_outer_join_condition(self):
 
     def test_udf_in_left_semi_join_condition(self):
         # regression test for SPARK-25314
-        from pyspark.sql.functions import udf
         left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
         right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1)])
         f = udf(lambda a, b: a == b, BooleanType())
@@ -236,7 +229,6 @@ def test_udf_in_left_semi_join_condition(self):
     def test_udf_and_common_filter_in_join_condition(self):
         # regression test for SPARK-25314
         # test the complex scenario with both udf and common filter
-        from pyspark.sql.functions import udf
         left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
         right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)])
         f = udf(lambda a, b: a == b, BooleanType())
@@ -247,7 +239,6 @@ def test_udf_and_common_filter_in_join_condition(self):
     def test_udf_and_common_filter_in_left_semi_join_condition(self):
         # regression test for SPARK-25314
         # test the complex scenario with both udf and common filter
-        from pyspark.sql.functions import udf
         left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
         right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)])
         f = udf(lambda a, b: a == b, BooleanType())
@@ -258,7 +249,6 @@ def test_udf_and_common_filter_in_left_semi_join_condition(self):
     def test_udf_not_supported_in_join_condition(self):
         # regression test for SPARK-25314
         # test python udf is not supported in join type besides left_semi and inner join.
-        from pyspark.sql.functions import udf
         left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
         right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)])
         f = udf(lambda a, b: a == b, BooleanType())
@@ -301,7 +291,7 @@ def test_broadcast_in_udf(self):
 
     def test_udf_with_filter_function(self):
         df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
-        from pyspark.sql.functions import udf, col
+        from pyspark.sql.functions import col
         from pyspark.sql.types import BooleanType
 
         my_filter = udf(lambda a: a < 2, BooleanType())
@@ -310,7 +300,7 @@ def test_udf_with_filter_function(self):
 
     def test_udf_with_aggregate_function(self):
         df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
-        from pyspark.sql.functions import udf, col, sum
+        from pyspark.sql.functions import col, sum
         from pyspark.sql.types import BooleanType
 
         my_filter = udf(lambda a: a == 1, BooleanType())
@@ -326,7 +316,7 @@ def test_udf_with_aggregate_function(self):
         self.assertEqual(sel.collect(), [Row(t=4), Row(t=3)])
 
     def test_udf_in_generate(self):
-        from pyspark.sql.functions import udf, explode
+        from pyspark.sql.functions import explode
         df = self.spark.range(5)
         f = udf(lambda x: list(range(x)), ArrayType(LongType()))
         row = df.select(explode(f(*df))).groupBy().sum().first()
@@ -353,7 +343,6 @@ def test_udf_in_generate(self):
         self.assertEqual(res[3][1], 1)
 
     def test_udf_with_order_by_and_limit(self):
-        from pyspark.sql.functions import udf
         my_copy = udf(lambda x: x, IntegerType())
         df = self.spark.range(10).orderBy("id")
         res = df.select(df.id, my_copy(df.id).alias("copy")).limit(1)
@@ -394,14 +383,14 @@ def test_non_existed_udaf(self):
                                 lambda: spark.udf.registerJavaUDAF("udaf1", "non_existed_udaf"))
 
     def test_udf_with_input_file_name(self):
-        from pyspark.sql.functions import udf, input_file_name
+        from pyspark.sql.functions import input_file_name
         sourceFile = udf(lambda path: path, StringType())
         filePath = "python/test_support/sql/people1.json"
         row = self.spark.read.json(filePath).select(sourceFile(input_file_name())).first()
         self.assertTrue(row[0].find("people1.json") != -1)
 
     def test_udf_with_input_file_name_for_hadooprdd(self):
-        from pyspark.sql.functions import udf, input_file_name
+        from pyspark.sql.functions import input_file_name
 
         def filename(path):
             return path
@@ -427,9 +416,6 @@ def test_udf_defers_judf_initialization(self):
         # This is separate of  UDFInitializationTests
         # to avoid context initialization
         # when udf is called
-
-        from pyspark.sql.functions import UserDefinedFunction
-
         f = UserDefinedFunction(lambda x: x, StringType())
 
         self.assertIsNone(
@@ -445,8 +431,6 @@ def test_udf_defers_judf_initialization(self):
         )
 
     def test_udf_with_string_return_type(self):
-        from pyspark.sql.functions import UserDefinedFunction
-
         add_one = UserDefinedFunction(lambda x: x + 1, "integer")
         make_pair = UserDefinedFunction(lambda x: (-x, x), "struct<x:integer,y:integer>")
         make_array = UserDefinedFunction(
@@ -460,13 +444,11 @@ def test_udf_with_string_return_type(self):
         self.assertTupleEqual(expected, actual)
 
     def test_udf_shouldnt_accept_noncallable_object(self):
-        from pyspark.sql.functions import UserDefinedFunction
-
         non_callable = None
         self.assertRaises(TypeError, UserDefinedFunction, non_callable, StringType())
 
     def test_udf_with_decorator(self):
-        from pyspark.sql.functions import lit, udf
+        from pyspark.sql.functions import lit
         from pyspark.sql.types import IntegerType, DoubleType
 
         @udf(IntegerType())
@@ -523,7 +505,6 @@ def as_double(x):
         )
 
     def test_udf_wrapper(self):
-        from pyspark.sql.functions import udf
         from pyspark.sql.types import IntegerType
 
         def f(x):
@@ -569,7 +550,7 @@ def test_nonparam_udf_with_aggregate(self):
     # SPARK-24721
     @unittest.skipIf(not test_compiled, test_not_compiled_message)
     def test_datasource_with_udf(self):
-        from pyspark.sql.functions import udf, lit, col
+        from pyspark.sql.functions import lit, col
 
         path = tempfile.mkdtemp()
         shutil.rmtree(path)
@@ -609,8 +590,6 @@ def test_datasource_with_udf(self):
 
     # SPARK-25591
     def test_same_accumulator_in_udfs(self):
-        from pyspark.sql.functions import udf
-
         data_schema = StructType([StructField("a", IntegerType(), True),
                                   StructField("b", IntegerType(), True)])
         data = self.spark.createDataFrame([[1, 2]], schema=data_schema)
@@ -632,6 +611,15 @@ def second_udf(x):
         data.collect()
         self.assertEqual(test_accum.value, 101)
 
+    # SPARK-26293
+    def test_udf_in_subquery(self):
+        f = udf(lambda x: x, "long")
+        with self.tempView("v"):
+            self.spark.range(1).filter(f("id") >= 0).createTempView("v")
+            sql = self.spark.sql
+            result = sql("select i from values(0L) as data(i) where i in (select id from v)")
+            self.assertEqual(result.collect(), [Row(i=0)])
+
 
 class UDFInitializationTests(unittest.TestCase):
     def tearDown(self):
@@ -642,8 +630,6 @@ def tearDown(self):
             SparkContext._active_spark_context.stop()
 
     def test_udf_init_shouldnt_initialize_context(self):
-        from pyspark.sql.functions import UserDefinedFunction
-
         UserDefinedFunction(lambda x: x, StringType())
 
         self.assertIsNone(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
index 2b87796dc6833..a5203daea9cd0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
@@ -60,8 +60,12 @@ private class BatchIterator[T](iter: Iterator[T], batchSize: Int)
 /**
  * A logical plan that evaluates a [[PythonUDF]].
  */
-case class ArrowEvalPython(udfs: Seq[PythonUDF], output: Seq[Attribute], child: LogicalPlan)
-  extends UnaryNode
+case class ArrowEvalPython(
+    udfs: Seq[PythonUDF],
+    output: Seq[Attribute],
+    child: LogicalPlan) extends UnaryNode {
+  override def producedAttributes: AttributeSet = AttributeSet(output.drop(child.output.length))
+}
 
 /**
  * A physical plan that evaluates a [[PythonUDF]].
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
index b08b7e60e130b..d3736d24e5019 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
@@ -32,8 +32,12 @@ import org.apache.spark.sql.types.{StructField, StructType}
 /**
  * A logical plan that evaluates a [[PythonUDF]]
  */
-case class BatchEvalPython(udfs: Seq[PythonUDF], output: Seq[Attribute], child: LogicalPlan)
-  extends UnaryNode
+case class BatchEvalPython(
+    udfs: Seq[PythonUDF],
+    output: Seq[Attribute],
+    child: LogicalPlan) extends UnaryNode {
+  override def producedAttributes: AttributeSet = AttributeSet(output.drop(child.output.length))
+}
 
 /**
  * A physical plan that evaluates a [[PythonUDF]]
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
index 90b5325919e96..380c31baa6213 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
@@ -24,7 +24,7 @@ import org.apache.spark.api.python.PythonEvalType
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.Rule
 
 
@@ -131,8 +131,20 @@ object ExtractPythonUDFs extends Rule[LogicalPlan] with PredicateHelper {
     expressions.flatMap(collectEvaluableUDFs)
   }
 
-  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
-    case plan: LogicalPlan => extract(plan)
+  def apply(plan: LogicalPlan): LogicalPlan = plan match {
+    // SPARK-26293: A subquery will be rewritten into join later, and will go through this rule
+    // eventually. Here we skip subquery, as Python UDF only needs to be extracted once.
+    case _: Subquery => plan
+
+    case _ => plan transformUp {
+      // A safeguard. `ExtractPythonUDFs` only runs once, so we will not hit `BatchEvalPython` and
+      // `ArrowEvalPython` in the input plan. However if we hit them, we must skip them, as we can't
+      // extract Python UDFs from them.
+      case p: BatchEvalPython => p
+      case p: ArrowEvalPython => p
+
+      case plan: LogicalPlan => extract(plan)
+    }
   }
 
   /**


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org