You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by gu...@apache.org on 2018/07/28 05:41:14 UTC
spark git commit: [SPARK-24624][SQL][PYTHON] Support mixture of
Python UDF and Scalar Pandas UDF
Repository: spark
Updated Branches:
refs/heads/master 6424b146c -> e8752095a
[SPARK-24624][SQL][PYTHON] Support mixture of Python UDF and Scalar Pandas UDF
## What changes were proposed in this pull request?
This PR adds support for mixing Python UDFs and Scalar Pandas UDFs in the same query, in the following two cases:
(1) A Python UDF and a Scalar Pandas UDF used in separate column expressions of the same query:
```
from pyspark.sql.functions import col, udf, pandas_udf

@udf('int')
def f1(x):
    return x + 1

@pandas_udf('int')
def f2(x):
    return x + 1

df = spark.range(0, 1).toDF('v') \
    .withColumn('foo', f1(col('v'))) \
    .withColumn('bar', f2(col('v')))
```
QueryPlan:
```
>>> df.explain(True)
== Parsed Logical Plan ==
'Project [v#2L, foo#5, f2('v) AS bar#9]
+- AnalysisBarrier
      +- Project [v#2L, f1(v#2L) AS foo#5]
         +- Project [id#0L AS v#2L]
            +- Range (0, 1, step=1, splits=Some(4))

== Analyzed Logical Plan ==
v: bigint, foo: int, bar: int
Project [v#2L, foo#5, f2(v#2L) AS bar#9]
+- Project [v#2L, f1(v#2L) AS foo#5]
   +- Project [id#0L AS v#2L]
      +- Range (0, 1, step=1, splits=Some(4))

== Optimized Logical Plan ==
Project [id#0L AS v#2L, f1(id#0L) AS foo#5, f2(id#0L) AS bar#9]
+- Range (0, 1, step=1, splits=Some(4))

== Physical Plan ==
*(2) Project [id#0L AS v#2L, pythonUDF0#13 AS foo#5, pythonUDF0#14 AS bar#9]
+- ArrowEvalPython [f2(id#0L)], [id#0L, pythonUDF0#13, pythonUDF0#14]
   +- BatchEvalPython [f1(id#0L)], [id#0L, pythonUDF0#13]
      +- *(1) Range (0, 1, step=1, splits=4)
```
(2) A Scalar Pandas UDF applied to the result of a Python UDF (chained in a single expression):
```
from pyspark.sql.functions import udf, pandas_udf

@udf('int')
def f1(x):
    return x + 1

@pandas_udf('int')
def f2(x):
    return x + 1

df = spark.range(0, 1).toDF('v')
df = df.withColumn('foo', f2(f1(df['v'])))
```
QueryPlan:
```
>>> df.explain(True)
== Parsed Logical Plan ==
Project [v#21L, f2(f1(v#21L)) AS foo#46]
+- AnalysisBarrier
      +- Project [v#21L, f1(f2(v#21L)) AS foo#39]
         +- Project [v#21L, <lambda>(<lambda>(v#21L)) AS foo#32]
            +- Project [v#21L, <lambda>(<lambda>(v#21L)) AS foo#25]
               +- Project [id#19L AS v#21L]
                  +- Range (0, 1, step=1, splits=Some(4))

== Analyzed Logical Plan ==
v: bigint, foo: int
Project [v#21L, f2(f1(v#21L)) AS foo#46]
+- Project [v#21L, f1(f2(v#21L)) AS foo#39]
   +- Project [v#21L, <lambda>(<lambda>(v#21L)) AS foo#32]
      +- Project [v#21L, <lambda>(<lambda>(v#21L)) AS foo#25]
         +- Project [id#19L AS v#21L]
            +- Range (0, 1, step=1, splits=Some(4))

== Optimized Logical Plan ==
Project [id#19L AS v#21L, f2(f1(id#19L)) AS foo#46]
+- Range (0, 1, step=1, splits=Some(4))

== Physical Plan ==
*(2) Project [id#19L AS v#21L, pythonUDF0#50 AS foo#46]
+- ArrowEvalPython [f2(pythonUDF0#49)], [id#19L, pythonUDF0#49, pythonUDF0#50]
   +- BatchEvalPython [f1(id#19L)], [id#19L, pythonUDF0#49]
      +- *(1) Range (0, 1, step=1, splits=4)
```
## How was this patch tested?
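New tests are added to ScalarPandasUDFTests and GroupedMapPandasUDFTests (python/pyspark/sql/tests.py) and to the new ExtractPythonUDFsSuite. BatchEvalPythonExecSuite.scala gains the MyDummyScalarPandasUDF helper used by the new Scala tests.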
Author: Li Jin <ic...@gmail.com>
Closes #21650 from icexelloss/SPARK-24624-mix-udf.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e8752095
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e8752095
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e8752095
Branch: refs/heads/master
Commit: e8752095a00aba453a92bc822131c001602f0829
Parents: 6424b14
Author: Li Jin <ic...@gmail.com>
Authored: Sat Jul 28 13:41:07 2018 +0800
Committer: hyukjinkwon <gu...@apache.org>
Committed: Sat Jul 28 13:41:07 2018 +0800
----------------------------------------------------------------------
python/pyspark/sql/tests.py | 186 +++++++++++++++++--
.../execution/python/ExtractPythonUDFs.scala | 42 +++--
.../python/BatchEvalPythonExecSuite.scala | 7 +
.../python/ExtractPythonUDFsSuite.scala | 92 +++++++++
4 files changed, 304 insertions(+), 23 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/e8752095/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 2d6b9f0..a294d70 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -4763,17 +4763,6 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
'Result vector from pandas_udf was not the required length'):
df.select(raise_exception(col('id'))).collect()
- def test_vectorized_udf_mix_udf(self):
- from pyspark.sql.functions import pandas_udf, udf, col
- df = self.spark.range(10)
- row_by_row_udf = udf(lambda x: x, LongType())
- pd_udf = pandas_udf(lambda x: x, LongType())
- with QuietTest(self.sc):
- with self.assertRaisesRegexp(
- Exception,
- 'Can not mix vectorized and non-vectorized UDFs'):
- df.select(row_by_row_udf(col('id')), pd_udf(col('id'))).collect()
-
def test_vectorized_udf_chained(self):
from pyspark.sql.functions import pandas_udf, col
df = self.spark.range(10)
@@ -5060,6 +5049,166 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
df = self.spark.range(1).select(pandas_udf(f=_locals['noop'], returnType='bigint')('id'))
self.assertEqual(df.first()[0], 0)
+ def test_mixed_udf(self):
+ import pandas as pd
+ from pyspark.sql.functions import col, udf, pandas_udf
+
+ df = self.spark.range(0, 1).toDF('v')
+
+ # Test mixture of multiple UDFs and Pandas UDFs.
+
+ @udf('int')
+ def f1(x):
+ assert type(x) == int
+ return x + 1
+
+ @pandas_udf('int')
+ def f2(x):
+ assert type(x) == pd.Series
+ return x + 10
+
+ @udf('int')
+ def f3(x):
+ assert type(x) == int
+ return x + 100
+
+ @pandas_udf('int')
+ def f4(x):
+ assert type(x) == pd.Series
+ return x + 1000
+
+ # Test single expression with chained UDFs
+ df_chained_1 = df.withColumn('f2_f1', f2(f1(df['v'])))
+ df_chained_2 = df.withColumn('f3_f2_f1', f3(f2(f1(df['v']))))
+ df_chained_3 = df.withColumn('f4_f3_f2_f1', f4(f3(f2(f1(df['v'])))))
+ df_chained_4 = df.withColumn('f4_f2_f1', f4(f2(f1(df['v']))))
+ df_chained_5 = df.withColumn('f4_f3_f1', f4(f3(f1(df['v']))))
+
+ expected_chained_1 = df.withColumn('f2_f1', df['v'] + 11)
+ expected_chained_2 = df.withColumn('f3_f2_f1', df['v'] + 111)
+ expected_chained_3 = df.withColumn('f4_f3_f2_f1', df['v'] + 1111)
+ expected_chained_4 = df.withColumn('f4_f2_f1', df['v'] + 1011)
+ expected_chained_5 = df.withColumn('f4_f3_f1', df['v'] + 1101)
+
+ self.assertEquals(expected_chained_1.collect(), df_chained_1.collect())
+ self.assertEquals(expected_chained_2.collect(), df_chained_2.collect())
+ self.assertEquals(expected_chained_3.collect(), df_chained_3.collect())
+ self.assertEquals(expected_chained_4.collect(), df_chained_4.collect())
+ self.assertEquals(expected_chained_5.collect(), df_chained_5.collect())
+
+ # Test multiple mixed UDF expressions in a single projection
+ df_multi_1 = df \
+ .withColumn('f1', f1(col('v'))) \
+ .withColumn('f2', f2(col('v'))) \
+ .withColumn('f3', f3(col('v'))) \
+ .withColumn('f4', f4(col('v'))) \
+ .withColumn('f2_f1', f2(col('f1'))) \
+ .withColumn('f3_f1', f3(col('f1'))) \
+ .withColumn('f4_f1', f4(col('f1'))) \
+ .withColumn('f3_f2', f3(col('f2'))) \
+ .withColumn('f4_f2', f4(col('f2'))) \
+ .withColumn('f4_f3', f4(col('f3'))) \
+ .withColumn('f3_f2_f1', f3(col('f2_f1'))) \
+ .withColumn('f4_f2_f1', f4(col('f2_f1'))) \
+ .withColumn('f4_f3_f1', f4(col('f3_f1'))) \
+ .withColumn('f4_f3_f2', f4(col('f3_f2'))) \
+ .withColumn('f4_f3_f2_f1', f4(col('f3_f2_f1')))
+
+ # Test mixed udfs in a single expression
+ df_multi_2 = df \
+ .withColumn('f1', f1(col('v'))) \
+ .withColumn('f2', f2(col('v'))) \
+ .withColumn('f3', f3(col('v'))) \
+ .withColumn('f4', f4(col('v'))) \
+ .withColumn('f2_f1', f2(f1(col('v')))) \
+ .withColumn('f3_f1', f3(f1(col('v')))) \
+ .withColumn('f4_f1', f4(f1(col('v')))) \
+ .withColumn('f3_f2', f3(f2(col('v')))) \
+ .withColumn('f4_f2', f4(f2(col('v')))) \
+ .withColumn('f4_f3', f4(f3(col('v')))) \
+ .withColumn('f3_f2_f1', f3(f2(f1(col('v'))))) \
+ .withColumn('f4_f2_f1', f4(f2(f1(col('v'))))) \
+ .withColumn('f4_f3_f1', f4(f3(f1(col('v'))))) \
+ .withColumn('f4_f3_f2', f4(f3(f2(col('v'))))) \
+ .withColumn('f4_f3_f2_f1', f4(f3(f2(f1(col('v'))))))
+
+ expected = df \
+ .withColumn('f1', df['v'] + 1) \
+ .withColumn('f2', df['v'] + 10) \
+ .withColumn('f3', df['v'] + 100) \
+ .withColumn('f4', df['v'] + 1000) \
+ .withColumn('f2_f1', df['v'] + 11) \
+ .withColumn('f3_f1', df['v'] + 101) \
+ .withColumn('f4_f1', df['v'] + 1001) \
+ .withColumn('f3_f2', df['v'] + 110) \
+ .withColumn('f4_f2', df['v'] + 1010) \
+ .withColumn('f4_f3', df['v'] + 1100) \
+ .withColumn('f3_f2_f1', df['v'] + 111) \
+ .withColumn('f4_f2_f1', df['v'] + 1011) \
+ .withColumn('f4_f3_f1', df['v'] + 1101) \
+ .withColumn('f4_f3_f2', df['v'] + 1110) \
+ .withColumn('f4_f3_f2_f1', df['v'] + 1111)
+
+ self.assertEquals(expected.collect(), df_multi_1.collect())
+ self.assertEquals(expected.collect(), df_multi_2.collect())
+
+ def test_mixed_udf_and_sql(self):
+ import pandas as pd
+ from pyspark.sql import Column
+ from pyspark.sql.functions import udf, pandas_udf
+
+ df = self.spark.range(0, 1).toDF('v')
+
+ # Test mixture of UDFs, Pandas UDFs and SQL expression.
+
+ @udf('int')
+ def f1(x):
+ assert type(x) == int
+ return x + 1
+
+ def f2(x):
+ assert type(x) == Column
+ return x + 10
+
+ @pandas_udf('int')
+ def f3(x):
+ assert type(x) == pd.Series
+ return x + 100
+
+ df1 = df.withColumn('f1', f1(df['v'])) \
+ .withColumn('f2', f2(df['v'])) \
+ .withColumn('f3', f3(df['v'])) \
+ .withColumn('f1_f2', f1(f2(df['v']))) \
+ .withColumn('f1_f3', f1(f3(df['v']))) \
+ .withColumn('f2_f1', f2(f1(df['v']))) \
+ .withColumn('f2_f3', f2(f3(df['v']))) \
+ .withColumn('f3_f1', f3(f1(df['v']))) \
+ .withColumn('f3_f2', f3(f2(df['v']))) \
+ .withColumn('f1_f2_f3', f1(f2(f3(df['v'])))) \
+ .withColumn('f1_f3_f2', f1(f3(f2(df['v'])))) \
+ .withColumn('f2_f1_f3', f2(f1(f3(df['v'])))) \
+ .withColumn('f2_f3_f1', f2(f3(f1(df['v'])))) \
+ .withColumn('f3_f1_f2', f3(f1(f2(df['v'])))) \
+ .withColumn('f3_f2_f1', f3(f2(f1(df['v']))))
+
+ expected = df.withColumn('f1', df['v'] + 1) \
+ .withColumn('f2', df['v'] + 10) \
+ .withColumn('f3', df['v'] + 100) \
+ .withColumn('f1_f2', df['v'] + 11) \
+ .withColumn('f1_f3', df['v'] + 101) \
+ .withColumn('f2_f1', df['v'] + 11) \
+ .withColumn('f2_f3', df['v'] + 110) \
+ .withColumn('f3_f1', df['v'] + 101) \
+ .withColumn('f3_f2', df['v'] + 110) \
+ .withColumn('f1_f2_f3', df['v'] + 111) \
+ .withColumn('f1_f3_f2', df['v'] + 111) \
+ .withColumn('f2_f1_f3', df['v'] + 111) \
+ .withColumn('f2_f3_f1', df['v'] + 111) \
+ .withColumn('f3_f1_f2', df['v'] + 111) \
+ .withColumn('f3_f2_f1', df['v'] + 111)
+
+ self.assertEquals(expected.collect(), df1.collect())
+
@unittest.skipIf(
not _have_pandas or not _have_pyarrow,
@@ -5487,6 +5636,21 @@ class GroupedMapPandasUDFTests(ReusedSQLTestCase):
F.col('temp0.key') == F.col('temp1.key'))
self.assertEquals(res.count(), 5)
+ def test_mixed_scalar_udfs_followed_by_grouby_apply(self):
+ import pandas as pd
+ from pyspark.sql.functions import udf, pandas_udf, PandasUDFType
+
+ df = self.spark.range(0, 10).toDF('v1')
+ df = df.withColumn('v2', udf(lambda x: x + 1, 'int')(df['v1'])) \
+ .withColumn('v3', pandas_udf(lambda x: x + 2, 'int')(df['v1']))
+
+ result = df.groupby() \
+ .apply(pandas_udf(lambda x: pd.DataFrame([x.sum().sum()]),
+ 'sum int',
+ PandasUDFType.GROUPED_MAP))
+
+ self.assertEquals(result.collect()[0]['sum'], 165)
+
@unittest.skipIf(
not _have_pandas or not _have_pyarrow,
http://git-wip-us.apache.org/repos/asf/spark/blob/e8752095/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
index 1e09610..cb75874 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
@@ -21,6 +21,7 @@ import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.api.python.PythonEvalType
+import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}
@@ -94,28 +95,44 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] {
*/
object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
- private def hasPythonUDF(e: Expression): Boolean = {
+ private type EvalType = Int
+ private type EvalTypeChecker = EvalType => Boolean
+
+ private def hasScalarPythonUDF(e: Expression): Boolean = {
e.find(PythonUDF.isScalarPythonUDF).isDefined
}
private def canEvaluateInPython(e: PythonUDF): Boolean = {
e.children match {
// single PythonUDF child could be chained and evaluated in Python
- case Seq(u: PythonUDF) => canEvaluateInPython(u)
+ case Seq(u: PythonUDF) => e.evalType == u.evalType && canEvaluateInPython(u)
// Python UDF can't be evaluated directly in JVM
- case children => !children.exists(hasPythonUDF)
+ case children => !children.exists(hasScalarPythonUDF)
}
}
- private def collectEvaluatableUDF(expr: Expression): Seq[PythonUDF] = expr match {
- case udf: PythonUDF if PythonUDF.isScalarPythonUDF(udf) && canEvaluateInPython(udf) => Seq(udf)
- case e => e.children.flatMap(collectEvaluatableUDF)
+ private def collectEvaluableUDFsFromExpressions(expressions: Seq[Expression]): Seq[PythonUDF] = {
+ // Eval type checker is set once when we find the first evaluable UDF and its value
+ // shouldn't change later.
+ // Used to check if subsequent UDFs are of the same type as the first UDF. (since we can only
+ // extract UDFs of the same eval type)
+ var evalTypeChecker: Option[EvalTypeChecker] = None
+
+ def collectEvaluableUDFs(expr: Expression): Seq[PythonUDF] = expr match {
+ case udf: PythonUDF if PythonUDF.isScalarPythonUDF(udf) && canEvaluateInPython(udf)
+ && evalTypeChecker.isEmpty =>
+ evalTypeChecker = Some((otherEvalType: EvalType) => otherEvalType == udf.evalType)
+ Seq(udf)
+ case udf: PythonUDF if PythonUDF.isScalarPythonUDF(udf) && canEvaluateInPython(udf)
+ && evalTypeChecker.get(udf.evalType) =>
+ Seq(udf)
+ case e => e.children.flatMap(collectEvaluableUDFs)
+ }
+
+ expressions.flatMap(collectEvaluableUDFs)
}
def apply(plan: SparkPlan): SparkPlan = plan transformUp {
- // AggregateInPandasExec and FlatMapGroupsInPandas can be evaluated directly in python worker
- // Therefore we don't need to extract the UDFs
- case plan: FlatMapGroupsInPandasExec => plan
case plan: SparkPlan => extract(plan)
}
@@ -123,7 +140,7 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
* Extract all the PythonUDFs from the current operator and evaluate them before the operator.
*/
private def extract(plan: SparkPlan): SparkPlan = {
- val udfs = plan.expressions.flatMap(collectEvaluatableUDF)
+ val udfs = collectEvaluableUDFsFromExpressions(plan.expressions)
// ignore the PythonUDF that come from second/third aggregate, which is not used
.filter(udf => udf.references.subsetOf(plan.inputSet))
if (udfs.isEmpty) {
@@ -167,7 +184,8 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
case (vectorizedUdfs, plainUdfs) if vectorizedUdfs.isEmpty =>
BatchEvalPythonExec(plainUdfs, child.output ++ resultAttrs, child)
case _ =>
- throw new IllegalArgumentException("Can not mix vectorized and non-vectorized UDFs")
+ throw new AnalysisException(
+ "Expected either Scalar Pandas UDFs or Batched UDFs but got both")
}
attributeMap ++= validUdfs.zip(resultAttrs)
@@ -205,7 +223,7 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
case filter: FilterExec =>
val (candidates, nonDeterministic) =
splitConjunctivePredicates(filter.condition).partition(_.deterministic)
- val (pushDown, rest) = candidates.partition(!hasPythonUDF(_))
+ val (pushDown, rest) = candidates.partition(!hasScalarPythonUDF(_))
if (pushDown.nonEmpty) {
val newChild = FilterExec(pushDown.reduceLeft(And), filter.child)
FilterExec((rest ++ nonDeterministic).reduceLeft(And), newChild)
http://git-wip-us.apache.org/repos/asf/spark/blob/e8752095/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala
index d456c93..2cc55ff 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala
@@ -115,3 +115,10 @@ class MyDummyPythonUDF extends UserDefinedPythonFunction(
dataType = BooleanType,
pythonEvalType = PythonEvalType.SQL_BATCHED_UDF,
udfDeterministic = true)
+
+class MyDummyScalarPandasUDF extends UserDefinedPythonFunction(
+ name = "dummyScalarPandasUDF",
+ func = new DummyUDF,
+ dataType = BooleanType,
+ pythonEvalType = PythonEvalType.SQL_SCALAR_PANDAS_UDF,
+ udfDeterministic = true)
http://git-wip-us.apache.org/repos/asf/spark/blob/e8752095/sql/core/src/test/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFsSuite.scala
new file mode 100644
index 0000000..76b609d
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFsSuite.scala
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.python
+
+import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest}
+import org.apache.spark.sql.functions.col
+import org.apache.spark.sql.test.SharedSQLContext
+
+class ExtractPythonUDFsSuite extends SparkPlanTest with SharedSQLContext {
+ import testImplicits.newProductEncoder
+ import testImplicits.localSeqToDatasetHolder
+
+ val batchedPythonUDF = new MyDummyPythonUDF
+ val scalarPandasUDF = new MyDummyScalarPandasUDF
+
+ private def collectBatchExec(plan: SparkPlan): Seq[BatchEvalPythonExec] = plan.collect {
+ case b: BatchEvalPythonExec => b
+ }
+
+ private def collectArrowExec(plan: SparkPlan): Seq[ArrowEvalPythonExec] = plan.collect {
+ case b: ArrowEvalPythonExec => b
+ }
+
+ test("Chained Batched Python UDFs should be combined to a single physical node") {
+ val df = Seq(("Hello", 4)).toDF("a", "b")
+ val df2 = df.withColumn("c", batchedPythonUDF(col("a")))
+ .withColumn("d", batchedPythonUDF(col("c")))
+ val pythonEvalNodes = collectBatchExec(df2.queryExecution.executedPlan)
+ assert(pythonEvalNodes.size == 1)
+ }
+
+ test("Chained Scalar Pandas UDFs should be combined to a single physical node") {
+ val df = Seq(("Hello", 4)).toDF("a", "b")
+ val df2 = df.withColumn("c", scalarPandasUDF(col("a")))
+ .withColumn("d", scalarPandasUDF(col("c")))
+ val arrowEvalNodes = collectArrowExec(df2.queryExecution.executedPlan)
+ assert(arrowEvalNodes.size == 1)
+ }
+
+ test("Mixed Batched Python UDFs and Pandas UDF should be separate physical node") {
+ val df = Seq(("Hello", 4)).toDF("a", "b")
+ val df2 = df.withColumn("c", batchedPythonUDF(col("a")))
+ .withColumn("d", scalarPandasUDF(col("b")))
+
+ val pythonEvalNodes = collectBatchExec(df2.queryExecution.executedPlan)
+ val arrowEvalNodes = collectArrowExec(df2.queryExecution.executedPlan)
+ assert(pythonEvalNodes.size == 1)
+ assert(arrowEvalNodes.size == 1)
+ }
+
+ test("Independent Batched Python UDFs and Scalar Pandas UDFs should be combined separately") {
+ val df = Seq(("Hello", 4)).toDF("a", "b")
+ val df2 = df.withColumn("c1", batchedPythonUDF(col("a")))
+ .withColumn("c2", batchedPythonUDF(col("c1")))
+ .withColumn("d1", scalarPandasUDF(col("a")))
+ .withColumn("d2", scalarPandasUDF(col("d1")))
+
+ val pythonEvalNodes = collectBatchExec(df2.queryExecution.executedPlan)
+ val arrowEvalNodes = collectArrowExec(df2.queryExecution.executedPlan)
+ assert(pythonEvalNodes.size == 1)
+ assert(arrowEvalNodes.size == 1)
+ }
+
+ test("Dependent Batched Python UDFs and Scalar Pandas UDFs should not be combined") {
+ val df = Seq(("Hello", 4)).toDF("a", "b")
+ val df2 = df.withColumn("c1", batchedPythonUDF(col("a")))
+ .withColumn("d1", scalarPandasUDF(col("c1")))
+ .withColumn("c2", batchedPythonUDF(col("d1")))
+ .withColumn("d2", scalarPandasUDF(col("c2")))
+
+ val pythonEvalNodes = collectBatchExec(df2.queryExecution.executedPlan)
+ val arrowEvalNodes = collectArrowExec(df2.queryExecution.executedPlan)
+ assert(pythonEvalNodes.size == 2)
+ assert(arrowEvalNodes.size == 2)
+ }
+}
+
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org