Posted to commits@spark.apache.org by gu...@apache.org on 2018/07/28 05:41:14 UTC

spark git commit: [SPARK-24624][SQL][PYTHON] Support mixture of Python UDF and Scalar Pandas UDF

Repository: spark
Updated Branches:
  refs/heads/master 6424b146c -> e8752095a


[SPARK-24624][SQL][PYTHON] Support mixture of Python UDF and Scalar Pandas UDF

## What changes were proposed in this pull request?

This PR adds support for mixing Python UDFs and Scalar Pandas UDFs; previously such queries failed with "Can not mix vectorized and non-vectorized UDFs" (see the test removed from ScalarPandasUDFTests in the diff). The following two cases are now supported:

(1)
```
from pyspark.sql.functions import col, udf, pandas_udf

@udf('int')
def f1(x):
    return x + 1

@pandas_udf('int')
def f2(x):
    return x + 1

df = spark.range(0, 1).toDF('v') \
    .withColumn('foo', f1(col('v'))) \
    .withColumn('bar', f2(col('v')))
```

Query plan:
```
>>> df.explain(True)
== Parsed Logical Plan ==
'Project [v#2L, foo#5, f2('v) AS bar#9]
+- AnalysisBarrier
      +- Project [v#2L, f1(v#2L) AS foo#5]
         +- Project [id#0L AS v#2L]
            +- Range (0, 1, step=1, splits=Some(4))

== Analyzed Logical Plan ==
v: bigint, foo: int, bar: int
Project [v#2L, foo#5, f2(v#2L) AS bar#9]
+- Project [v#2L, f1(v#2L) AS foo#5]
   +- Project [id#0L AS v#2L]
      +- Range (0, 1, step=1, splits=Some(4))

== Optimized Logical Plan ==
Project [id#0L AS v#2L, f1(id#0L) AS foo#5, f2(id#0L) AS bar#9]
+- Range (0, 1, step=1, splits=Some(4))

== Physical Plan ==
*(2) Project [id#0L AS v#2L, pythonUDF0#13 AS foo#5, pythonUDF0#14 AS bar#9]
+- ArrowEvalPython [f2(id#0L)], [id#0L, pythonUDF0#13, pythonUDF0#14]
   +- BatchEvalPython [f1(id#0L)], [id#0L, pythonUDF0#13]
      +- *(1) Range (0, 1, step=1, splits=4)
```
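
As a quick sanity check (reconstructed here for the reader, not part of the original message): with v = 0, f1 and f2 each add 1, and the physical plan above evaluates f1 in a BatchEvalPython node and f2 in a separate ArrowEvalPython node:

```
>>> df.show()
+---+---+---+
|  v|foo|bar|
+---+---+---+
|  0|  1|  1|
+---+---+---+
```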

(2)
```
from pyspark.sql.functions import udf, pandas_udf

@udf('int')
def f1(x):
    return x + 1

@pandas_udf('int')
def f2(x):
    return x + 1

df = spark.range(0, 1).toDF('v')
df = df.withColumn('foo', f2(f1(df['v'])))
```

Query plan:
```
>>> df.explain(True)
== Parsed Logical Plan ==
Project [v#21L, f2(f1(v#21L)) AS foo#46]
+- AnalysisBarrier
      +- Project [v#21L, f1(f2(v#21L)) AS foo#39]
         +- Project [v#21L, <lambda>(<lambda>(v#21L)) AS foo#32]
            +- Project [v#21L, <lambda>(<lambda>(v#21L)) AS foo#25]
               +- Project [id#19L AS v#21L]
                  +- Range (0, 1, step=1, splits=Some(4))

== Analyzed Logical Plan ==
v: bigint, foo: int
Project [v#21L, f2(f1(v#21L)) AS foo#46]
+- Project [v#21L, f1(f2(v#21L)) AS foo#39]
   +- Project [v#21L, <lambda>(<lambda>(v#21L)) AS foo#32]
      +- Project [v#21L, <lambda>(<lambda>(v#21L)) AS foo#25]
         +- Project [id#19L AS v#21L]
            +- Range (0, 1, step=1, splits=Some(4))

== Optimized Logical Plan ==
Project [id#19L AS v#21L, f2(f1(id#19L)) AS foo#46]
+- Range (0, 1, step=1, splits=Some(4))

== Physical Plan ==
*(2) Project [id#19L AS v#21L, pythonUDF0#50 AS foo#46]
+- ArrowEvalPython [f2(pythonUDF0#49)], [id#19L, pythonUDF0#49, pythonUDF0#50]
   +- BatchEvalPython [f1(id#19L)], [id#19L, pythonUDF0#49]
      +- *(1) Range (0, 1, step=1, splits=4)
```
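
Here the BatchEvalPython node computes f1 and hands its result (pythonUDF0#49) to the ArrowEvalPython node that computes f2, so foo = f2(f1(0)) = 2 (again reconstructed, not from the original message):

```
>>> df.show()
+---+---+
|  v|foo|
+---+---+
|  0|  2|
+---+---+
```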

## How was this patch tested?

New tests are added to ScalarPandasUDFTests, GroupedMapPandasUDFTests, and the new ExtractPythonUDFsSuite; BatchEvalPythonExecSuite gains a dummy Scalar Pandas UDF used by those Scala tests.
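
To build intuition for how the extraction rule groups UDFs into physical nodes (see ExtractPythonUDFs.scala and ExtractPythonUDFsSuite in the diff below), here is a toy model in plain Python — an illustrative sketch only, not the Catalyst implementation: consecutive UDFs of the same eval type are chained into one eval node, and a change of eval type starts a new node.

```
# Toy model of the grouping behavior (illustrative only; the real rule lives
# in ExtractPythonUDFs.scala and operates on Catalyst expression trees).
BATCHED, SCALAR_PANDAS = 'batched', 'scalar_pandas'  # stand-ins for PythonEvalType

def split_into_eval_nodes(udf_chain):
    """udf_chain: innermost-first eval types, e.g. [BATCHED, SCALAR_PANDAS]
    for f2(f1(v)) where f1 is a Python UDF and f2 a Scalar Pandas UDF."""
    nodes = []
    for eval_type in udf_chain:
        if nodes and nodes[-1][0] == eval_type:
            nodes[-1][1] += 1  # same eval type: chain into the previous node
        else:
            nodes.append([eval_type, 1])  # eval type changed: start a new node
    return nodes

# Matches the suite's expectations: a same-type chain collapses to one node,
# while an alternating chain like f4(f3(f2(f1(v)))) needs a separate node per UDF.
assert split_into_eval_nodes([BATCHED, BATCHED]) == [[BATCHED, 2]]
assert split_into_eval_nodes([BATCHED, SCALAR_PANDAS, BATCHED, SCALAR_PANDAS]) == \
    [[BATCHED, 1], [SCALAR_PANDAS, 1], [BATCHED, 1], [SCALAR_PANDAS, 1]]
```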

Author: Li Jin <ic...@gmail.com>

Closes #21650 from icexelloss/SPARK-24624-mix-udf.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e8752095
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e8752095
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e8752095

Branch: refs/heads/master
Commit: e8752095a00aba453a92bc822131c001602f0829
Parents: 6424b14
Author: Li Jin <ic...@gmail.com>
Authored: Sat Jul 28 13:41:07 2018 +0800
Committer: hyukjinkwon <gu...@apache.org>
Committed: Sat Jul 28 13:41:07 2018 +0800

----------------------------------------------------------------------
 python/pyspark/sql/tests.py                     | 186 +++++++++++++++++--
 .../execution/python/ExtractPythonUDFs.scala    |  42 +++--
 .../python/BatchEvalPythonExecSuite.scala       |   7 +
 .../python/ExtractPythonUDFsSuite.scala         |  92 +++++++++
 4 files changed, 304 insertions(+), 23 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/e8752095/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 2d6b9f0..a294d70 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -4763,17 +4763,6 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
                     'Result vector from pandas_udf was not the required length'):
                 df.select(raise_exception(col('id'))).collect()
 
-    def test_vectorized_udf_mix_udf(self):
-        from pyspark.sql.functions import pandas_udf, udf, col
-        df = self.spark.range(10)
-        row_by_row_udf = udf(lambda x: x, LongType())
-        pd_udf = pandas_udf(lambda x: x, LongType())
-        with QuietTest(self.sc):
-            with self.assertRaisesRegexp(
-                    Exception,
-                    'Can not mix vectorized and non-vectorized UDFs'):
-                df.select(row_by_row_udf(col('id')), pd_udf(col('id'))).collect()
-
     def test_vectorized_udf_chained(self):
         from pyspark.sql.functions import pandas_udf, col
         df = self.spark.range(10)
@@ -5060,6 +5049,166 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
         df = self.spark.range(1).select(pandas_udf(f=_locals['noop'], returnType='bigint')('id'))
         self.assertEqual(df.first()[0], 0)
 
+    def test_mixed_udf(self):
+        import pandas as pd
+        from pyspark.sql.functions import col, udf, pandas_udf
+
+        df = self.spark.range(0, 1).toDF('v')
+
+        # Test mixture of multiple UDFs and Pandas UDFs.
+
+        @udf('int')
+        def f1(x):
+            assert type(x) == int
+            return x + 1
+
+        @pandas_udf('int')
+        def f2(x):
+            assert type(x) == pd.Series
+            return x + 10
+
+        @udf('int')
+        def f3(x):
+            assert type(x) == int
+            return x + 100
+
+        @pandas_udf('int')
+        def f4(x):
+            assert type(x) == pd.Series
+            return x + 1000
+
+        # Test single expression with chained UDFs
+        df_chained_1 = df.withColumn('f2_f1', f2(f1(df['v'])))
+        df_chained_2 = df.withColumn('f3_f2_f1', f3(f2(f1(df['v']))))
+        df_chained_3 = df.withColumn('f4_f3_f2_f1', f4(f3(f2(f1(df['v'])))))
+        df_chained_4 = df.withColumn('f4_f2_f1', f4(f2(f1(df['v']))))
+        df_chained_5 = df.withColumn('f4_f3_f1', f4(f3(f1(df['v']))))
+
+        expected_chained_1 = df.withColumn('f2_f1', df['v'] + 11)
+        expected_chained_2 = df.withColumn('f3_f2_f1', df['v'] + 111)
+        expected_chained_3 = df.withColumn('f4_f3_f2_f1', df['v'] + 1111)
+        expected_chained_4 = df.withColumn('f4_f2_f1', df['v'] + 1011)
+        expected_chained_5 = df.withColumn('f4_f3_f1', df['v'] + 1101)
+
+        self.assertEquals(expected_chained_1.collect(), df_chained_1.collect())
+        self.assertEquals(expected_chained_2.collect(), df_chained_2.collect())
+        self.assertEquals(expected_chained_3.collect(), df_chained_3.collect())
+        self.assertEquals(expected_chained_4.collect(), df_chained_4.collect())
+        self.assertEquals(expected_chained_5.collect(), df_chained_5.collect())
+
+        # Test multiple mixed UDF expressions in a single projection
+        df_multi_1 = df \
+            .withColumn('f1', f1(col('v'))) \
+            .withColumn('f2', f2(col('v'))) \
+            .withColumn('f3', f3(col('v'))) \
+            .withColumn('f4', f4(col('v'))) \
+            .withColumn('f2_f1', f2(col('f1'))) \
+            .withColumn('f3_f1', f3(col('f1'))) \
+            .withColumn('f4_f1', f4(col('f1'))) \
+            .withColumn('f3_f2', f3(col('f2'))) \
+            .withColumn('f4_f2', f4(col('f2'))) \
+            .withColumn('f4_f3', f4(col('f3'))) \
+            .withColumn('f3_f2_f1', f3(col('f2_f1'))) \
+            .withColumn('f4_f2_f1', f4(col('f2_f1'))) \
+            .withColumn('f4_f3_f1', f4(col('f3_f1'))) \
+            .withColumn('f4_f3_f2', f4(col('f3_f2'))) \
+            .withColumn('f4_f3_f2_f1', f4(col('f3_f2_f1')))
+
+        # Test mixed udfs in a single expression
+        df_multi_2 = df \
+            .withColumn('f1', f1(col('v'))) \
+            .withColumn('f2', f2(col('v'))) \
+            .withColumn('f3', f3(col('v'))) \
+            .withColumn('f4', f4(col('v'))) \
+            .withColumn('f2_f1', f2(f1(col('v')))) \
+            .withColumn('f3_f1', f3(f1(col('v')))) \
+            .withColumn('f4_f1', f4(f1(col('v')))) \
+            .withColumn('f3_f2', f3(f2(col('v')))) \
+            .withColumn('f4_f2', f4(f2(col('v')))) \
+            .withColumn('f4_f3', f4(f3(col('v')))) \
+            .withColumn('f3_f2_f1', f3(f2(f1(col('v'))))) \
+            .withColumn('f4_f2_f1', f4(f2(f1(col('v'))))) \
+            .withColumn('f4_f3_f1', f4(f3(f1(col('v'))))) \
+            .withColumn('f4_f3_f2', f4(f3(f2(col('v'))))) \
+            .withColumn('f4_f3_f2_f1', f4(f3(f2(f1(col('v'))))))
+
+        expected = df \
+            .withColumn('f1', df['v'] + 1) \
+            .withColumn('f2', df['v'] + 10) \
+            .withColumn('f3', df['v'] + 100) \
+            .withColumn('f4', df['v'] + 1000) \
+            .withColumn('f2_f1', df['v'] + 11) \
+            .withColumn('f3_f1', df['v'] + 101) \
+            .withColumn('f4_f1', df['v'] + 1001) \
+            .withColumn('f3_f2', df['v'] + 110) \
+            .withColumn('f4_f2', df['v'] + 1010) \
+            .withColumn('f4_f3', df['v'] + 1100) \
+            .withColumn('f3_f2_f1', df['v'] + 111) \
+            .withColumn('f4_f2_f1', df['v'] + 1011) \
+            .withColumn('f4_f3_f1', df['v'] + 1101) \
+            .withColumn('f4_f3_f2', df['v'] + 1110) \
+            .withColumn('f4_f3_f2_f1', df['v'] + 1111)
+
+        self.assertEquals(expected.collect(), df_multi_1.collect())
+        self.assertEquals(expected.collect(), df_multi_2.collect())
+
+    def test_mixed_udf_and_sql(self):
+        import pandas as pd
+        from pyspark.sql import Column
+        from pyspark.sql.functions import udf, pandas_udf
+
+        df = self.spark.range(0, 1).toDF('v')
+
+        # Test mixture of UDFs, Pandas UDFs and SQL expression.
+
+        @udf('int')
+        def f1(x):
+            assert type(x) == int
+            return x + 1
+
+        def f2(x):
+            assert type(x) == Column
+            return x + 10
+
+        @pandas_udf('int')
+        def f3(x):
+            assert type(x) == pd.Series
+            return x + 100
+
+        df1 = df.withColumn('f1', f1(df['v'])) \
+            .withColumn('f2', f2(df['v'])) \
+            .withColumn('f3', f3(df['v'])) \
+            .withColumn('f1_f2', f1(f2(df['v']))) \
+            .withColumn('f1_f3', f1(f3(df['v']))) \
+            .withColumn('f2_f1', f2(f1(df['v']))) \
+            .withColumn('f2_f3', f2(f3(df['v']))) \
+            .withColumn('f3_f1', f3(f1(df['v']))) \
+            .withColumn('f3_f2', f3(f2(df['v']))) \
+            .withColumn('f1_f2_f3', f1(f2(f3(df['v'])))) \
+            .withColumn('f1_f3_f2', f1(f3(f2(df['v'])))) \
+            .withColumn('f2_f1_f3', f2(f1(f3(df['v'])))) \
+            .withColumn('f2_f3_f1', f2(f3(f1(df['v'])))) \
+            .withColumn('f3_f1_f2', f3(f1(f2(df['v'])))) \
+            .withColumn('f3_f2_f1', f3(f2(f1(df['v']))))
+
+        expected = df.withColumn('f1', df['v'] + 1) \
+            .withColumn('f2', df['v'] + 10) \
+            .withColumn('f3', df['v'] + 100) \
+            .withColumn('f1_f2', df['v'] + 11) \
+            .withColumn('f1_f3', df['v'] + 101) \
+            .withColumn('f2_f1', df['v'] + 11) \
+            .withColumn('f2_f3', df['v'] + 110) \
+            .withColumn('f3_f1', df['v'] + 101) \
+            .withColumn('f3_f2', df['v'] + 110) \
+            .withColumn('f1_f2_f3', df['v'] + 111) \
+            .withColumn('f1_f3_f2', df['v'] + 111) \
+            .withColumn('f2_f1_f3', df['v'] + 111) \
+            .withColumn('f2_f3_f1', df['v'] + 111) \
+            .withColumn('f3_f1_f2', df['v'] + 111) \
+            .withColumn('f3_f2_f1', df['v'] + 111)
+
+        self.assertEquals(expected.collect(), df1.collect())
+
 
 @unittest.skipIf(
     not _have_pandas or not _have_pyarrow,
@@ -5487,6 +5636,21 @@ class GroupedMapPandasUDFTests(ReusedSQLTestCase):
                                                  F.col('temp0.key') == F.col('temp1.key'))
         self.assertEquals(res.count(), 5)
 
+    def test_mixed_scalar_udfs_followed_by_grouby_apply(self):
+        import pandas as pd
+        from pyspark.sql.functions import udf, pandas_udf, PandasUDFType
+
+        df = self.spark.range(0, 10).toDF('v1')
+        df = df.withColumn('v2', udf(lambda x: x + 1, 'int')(df['v1'])) \
+            .withColumn('v3', pandas_udf(lambda x: x + 2, 'int')(df['v1']))
+
+        result = df.groupby() \
+            .apply(pandas_udf(lambda x: pd.DataFrame([x.sum().sum()]),
+                              'sum int',
+                              PandasUDFType.GROUPED_MAP))
+
+        self.assertEquals(result.collect()[0]['sum'], 165)
+
 
 @unittest.skipIf(
     not _have_pandas or not _have_pyarrow,

http://git-wip-us.apache.org/repos/asf/spark/blob/e8752095/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
index 1e09610..cb75874 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
@@ -21,6 +21,7 @@ import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.api.python.PythonEvalType
+import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}
@@ -94,28 +95,44 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] {
  */
 object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
 
-  private def hasPythonUDF(e: Expression): Boolean = {
+  private type EvalType = Int
+  private type EvalTypeChecker = EvalType => Boolean
+
+  private def hasScalarPythonUDF(e: Expression): Boolean = {
     e.find(PythonUDF.isScalarPythonUDF).isDefined
   }
 
   private def canEvaluateInPython(e: PythonUDF): Boolean = {
     e.children match {
       // single PythonUDF child could be chained and evaluated in Python
-      case Seq(u: PythonUDF) => canEvaluateInPython(u)
+      case Seq(u: PythonUDF) => e.evalType == u.evalType && canEvaluateInPython(u)
       // Python UDF can't be evaluated directly in JVM
-      case children => !children.exists(hasPythonUDF)
+      case children => !children.exists(hasScalarPythonUDF)
     }
   }
 
-  private def collectEvaluatableUDF(expr: Expression): Seq[PythonUDF] = expr match {
-    case udf: PythonUDF if PythonUDF.isScalarPythonUDF(udf) && canEvaluateInPython(udf) => Seq(udf)
-    case e => e.children.flatMap(collectEvaluatableUDF)
+  private def collectEvaluableUDFsFromExpressions(expressions: Seq[Expression]): Seq[PythonUDF] = {
+    // Eval type checker is set once when we find the first evaluable UDF and its value
+    // shouldn't change later.
+    // Used to check if subsequent UDFs are of the same type as the first UDF. (since we can only
+    // extract UDFs of the same eval type)
+    var evalTypeChecker: Option[EvalTypeChecker] = None
+
+    def collectEvaluableUDFs(expr: Expression): Seq[PythonUDF] = expr match {
+      case udf: PythonUDF if PythonUDF.isScalarPythonUDF(udf) && canEvaluateInPython(udf)
+        && evalTypeChecker.isEmpty =>
+        evalTypeChecker = Some((otherEvalType: EvalType) => otherEvalType == udf.evalType)
+        Seq(udf)
+      case udf: PythonUDF if PythonUDF.isScalarPythonUDF(udf) && canEvaluateInPython(udf)
+        && evalTypeChecker.get(udf.evalType) =>
+        Seq(udf)
+      case e => e.children.flatMap(collectEvaluableUDFs)
+    }
+
+    expressions.flatMap(collectEvaluableUDFs)
   }
 
   def apply(plan: SparkPlan): SparkPlan = plan transformUp {
-    // AggregateInPandasExec and FlatMapGroupsInPandas can be evaluated directly in python worker
-    // Therefore we don't need to extract the UDFs
-    case plan: FlatMapGroupsInPandasExec => plan
     case plan: SparkPlan => extract(plan)
   }
 
@@ -123,7 +140,7 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
    * Extract all the PythonUDFs from the current operator and evaluate them before the operator.
    */
   private def extract(plan: SparkPlan): SparkPlan = {
-    val udfs = plan.expressions.flatMap(collectEvaluatableUDF)
+    val udfs = collectEvaluableUDFsFromExpressions(plan.expressions)
       // ignore the PythonUDF that come from second/third aggregate, which is not used
       .filter(udf => udf.references.subsetOf(plan.inputSet))
     if (udfs.isEmpty) {
@@ -167,7 +184,8 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
             case (vectorizedUdfs, plainUdfs) if vectorizedUdfs.isEmpty =>
               BatchEvalPythonExec(plainUdfs, child.output ++ resultAttrs, child)
             case _ =>
-              throw new IllegalArgumentException("Can not mix vectorized and non-vectorized UDFs")
+              throw new AnalysisException(
+                "Expected either Scalar Pandas UDFs or Batched UDFs but got both")
           }
 
           attributeMap ++= validUdfs.zip(resultAttrs)
@@ -205,7 +223,7 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
       case filter: FilterExec =>
         val (candidates, nonDeterministic) =
           splitConjunctivePredicates(filter.condition).partition(_.deterministic)
-        val (pushDown, rest) = candidates.partition(!hasPythonUDF(_))
+        val (pushDown, rest) = candidates.partition(!hasScalarPythonUDF(_))
         if (pushDown.nonEmpty) {
           val newChild = FilterExec(pushDown.reduceLeft(And), filter.child)
           FilterExec((rest ++ nonDeterministic).reduceLeft(And), newChild)

http://git-wip-us.apache.org/repos/asf/spark/blob/e8752095/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala
index d456c93..2cc55ff 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala
@@ -115,3 +115,10 @@ class MyDummyPythonUDF extends UserDefinedPythonFunction(
   dataType = BooleanType,
   pythonEvalType = PythonEvalType.SQL_BATCHED_UDF,
   udfDeterministic = true)
+
+class MyDummyScalarPandasUDF extends UserDefinedPythonFunction(
+  name = "dummyScalarPandasUDF",
+  func = new DummyUDF,
+  dataType = BooleanType,
+  pythonEvalType = PythonEvalType.SQL_SCALAR_PANDAS_UDF,
+  udfDeterministic = true)

http://git-wip-us.apache.org/repos/asf/spark/blob/e8752095/sql/core/src/test/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFsSuite.scala
new file mode 100644
index 0000000..76b609d
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFsSuite.scala
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.python
+
+import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest}
+import org.apache.spark.sql.functions.col
+import org.apache.spark.sql.test.SharedSQLContext
+
+class ExtractPythonUDFsSuite extends SparkPlanTest with SharedSQLContext {
+  import testImplicits.newProductEncoder
+  import testImplicits.localSeqToDatasetHolder
+
+  val batchedPythonUDF = new MyDummyPythonUDF
+  val scalarPandasUDF = new MyDummyScalarPandasUDF
+
+  private def collectBatchExec(plan: SparkPlan): Seq[BatchEvalPythonExec] = plan.collect {
+    case b: BatchEvalPythonExec => b
+  }
+
+  private def collectArrowExec(plan: SparkPlan): Seq[ArrowEvalPythonExec] = plan.collect {
+    case b: ArrowEvalPythonExec => b
+  }
+
+  test("Chained Batched Python UDFs should be combined to a single physical node") {
+    val df = Seq(("Hello", 4)).toDF("a", "b")
+    val df2 = df.withColumn("c", batchedPythonUDF(col("a")))
+      .withColumn("d", batchedPythonUDF(col("c")))
+    val pythonEvalNodes = collectBatchExec(df2.queryExecution.executedPlan)
+    assert(pythonEvalNodes.size == 1)
+  }
+
+  test("Chained Scalar Pandas UDFs should be combined to a single physical node") {
+    val df = Seq(("Hello", 4)).toDF("a", "b")
+    val df2 = df.withColumn("c", scalarPandasUDF(col("a")))
+      .withColumn("d", scalarPandasUDF(col("c")))
+    val arrowEvalNodes = collectArrowExec(df2.queryExecution.executedPlan)
+    assert(arrowEvalNodes.size == 1)
+  }
+
+  test("Mixed Batched Python UDFs and Pandas UDF should be separate physical node") {
+    val df = Seq(("Hello", 4)).toDF("a", "b")
+    val df2 = df.withColumn("c", batchedPythonUDF(col("a")))
+      .withColumn("d", scalarPandasUDF(col("b")))
+
+    val pythonEvalNodes = collectBatchExec(df2.queryExecution.executedPlan)
+    val arrowEvalNodes = collectArrowExec(df2.queryExecution.executedPlan)
+    assert(pythonEvalNodes.size == 1)
+    assert(arrowEvalNodes.size == 1)
+  }
+
+  test("Independent Batched Python UDFs and Scalar Pandas UDFs should be combined separately") {
+    val df = Seq(("Hello", 4)).toDF("a", "b")
+    val df2 = df.withColumn("c1", batchedPythonUDF(col("a")))
+      .withColumn("c2", batchedPythonUDF(col("c1")))
+      .withColumn("d1", scalarPandasUDF(col("a")))
+      .withColumn("d2", scalarPandasUDF(col("d1")))
+
+    val pythonEvalNodes = collectBatchExec(df2.queryExecution.executedPlan)
+    val arrowEvalNodes = collectArrowExec(df2.queryExecution.executedPlan)
+    assert(pythonEvalNodes.size == 1)
+    assert(arrowEvalNodes.size == 1)
+  }
+
+  test("Dependent Batched Python UDFs and Scalar Pandas UDFs should not be combined") {
+    val df = Seq(("Hello", 4)).toDF("a", "b")
+    val df2 = df.withColumn("c1", batchedPythonUDF(col("a")))
+      .withColumn("d1", scalarPandasUDF(col("c1")))
+      .withColumn("c2", batchedPythonUDF(col("d1")))
+      .withColumn("d2", scalarPandasUDF(col("c2")))
+
+    val pythonEvalNodes = collectBatchExec(df2.queryExecution.executedPlan)
+    val arrowEvalNodes = collectArrowExec(df2.queryExecution.executedPlan)
+    assert(pythonEvalNodes.size == 2)
+    assert(arrowEvalNodes.size == 2)
+  }
+}
+

