Posted to commits@spark.apache.org by gu...@apache.org on 2023/06/13 02:17:39 UTC
[spark] branch master updated: [SPARK-43798][SQL][PYTHON] Support Python user-defined table functions
This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 37ab190dc5b [SPARK-43798][SQL][PYTHON] Support Python user-defined table functions
37ab190dc5b is described below
commit 37ab190dc5bfa59b4e06af9551c35ab179a05733
Author: allisonwang-db <al...@databricks.com>
AuthorDate: Tue Jun 13 11:17:10 2023 +0900
[SPARK-43798][SQL][PYTHON] Support Python user-defined table functions
### What changes were proposed in this pull request?
This PR adds initial support for Python user-defined table functions (UDTFs), allowing users to create UDTFs in PySpark and use them in both PySpark and SQL.
Here are examples of creating and using Python UDTFs:
```python
# Implement the UDTF class
class TestUDTF:
    def __init__(self):
        ...

    def eval(self, *args):
        yield "hello", "world"

    def terminate(self):
        ...


# Create the UDTF
from pyspark.sql.functions import udtf

test_udtf = udtf(TestUDTF, returnType="c1: string, c2: string")

# Invoke the UDTF
test_udtf().show()
+-----+-----+
|   c1|   c2|
+-----+-----+
|hello|world|
+-----+-----+

# Register the UDTF
spark.udtf.register(name="test_udtf", f=test_udtf)

# Invoke the UDTF in SQL
spark.sql("SELECT * FROM test_udtf()").show()
+-----+-----+
|   c1|   c2|
+-----+-----+
|hello|world|
+-----+-----+
```
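The new `udtf` function also supports a decorator form, and registered UDTFs can be called from SQL in lateral joins. Here is a short sketch based on the doctests added in this PR (it assumes an active `SparkSession` bound to `spark`):

```python
from pyspark.sql.functions import lit, udtf

# Define a UDTF with the decorator form, declaring the output schema inline.
@udtf(returnType="c1: int, c2: int")
class PlusOne:
    def eval(self, x: int):
        yield x, x + 1

# Invoke it directly with column arguments.
PlusOne(lit(1)).show()

# Register it and invoke it from SQL, including in a lateral join
# that applies the UDTF to each input row.
spark.udtf.register(name="plus_one", f=PlusOne)
spark.sql("SELECT * FROM VALUES (0, 1), (1, 2) t(x, y), LATERAL plus_one(x)").show()
```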
Note that this is the initial PR; there will be follow-up work to make Python UDTFs more user-friendly and performant.
### Why are the changes needed?
To support another type of user-defined function in PySpark.
### Does this PR introduce _any_ user-facing change?
Yes. After this PR, users can create Python user-defined table functions.
### How was this patch tested?
New unit tests.
Closes #41316 from allisonwang-db/spark-43798-py-udtf.
Authored-by: allisonwang-db <al...@databricks.com>
Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
.../CheckConnectJvmClientCompatibility.scala | 1 +
.../org/apache/spark/api/python/PythonRunner.scala | 3 +
dev/sparktestsupport/modules.py | 2 +
.../source/reference/pyspark.sql/core_classes.rst | 2 +
.../source/reference/pyspark.sql/functions.rst | 1 +
python/docs/source/reference/pyspark.sql/index.rst | 1 +
.../source/reference/pyspark.sql/spark_session.rst | 1 +
.../pyspark.sql/{core_classes.rst => udtf.rst} | 26 +-
python/pyspark/rdd.py | 10 +-
python/pyspark/sql/__init__.py | 3 +-
python/pyspark/sql/_typing.pyi | 1 +
python/pyspark/sql/connect/session.py | 2 +-
python/pyspark/sql/context.py | 13 +
python/pyspark/sql/functions.py | 96 +++++
python/pyspark/sql/session.py | 15 +
.../sql/tests/connect/test_connect_basic.py | 11 +
python/pyspark/sql/tests/test_udtf.py | 413 +++++++++++++++++++++
python/pyspark/sql/udtf.py | 233 ++++++++++++
python/pyspark/worker.py | 84 ++++-
.../sql/catalyst/expressions/ExpressionInfo.java | 3 +-
.../spark/sql/catalyst/expressions/PythonUDF.scala | 39 +-
.../sql/catalyst/plans/logical/PlanHelper.scala | 3 +-
.../plans/logical/pythonLogicalOperators.scala | 35 +-
.../spark/sql/catalyst/trees/TreePatterns.scala | 1 +
.../scala/org/apache/spark/sql/SparkSession.scala | 2 +
.../org/apache/spark/sql/UDTFRegistration.scala | 49 +++
.../spark/sql/execution/SparkOptimizer.scala | 3 +-
.../spark/sql/execution/SparkStrategies.scala | 2 +
.../sql/execution/python/BatchEvalPythonExec.scala | 75 ++--
.../execution/python/BatchEvalPythonUDTFExec.scala | 175 +++++++++
.../sql/execution/python/ExtractPythonUDFs.scala | 16 +
.../python/UserDefinedPythonFunction.scala | 40 +-
.../sql/internal/BaseSessionStateBuilder.scala | 3 +
.../apache/spark/sql/internal/SessionState.scala | 3 +
.../apache/spark/sql/IntegratedUDFTestUtils.scala | 49 ++-
.../sql/execution/python/PythonUDTFSuite.scala | 78 ++++
36 files changed, 1427 insertions(+), 67 deletions(-)
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
index 5e3401ebf50..7a9a889706d 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
@@ -225,6 +225,7 @@ object CheckConnectJvmClientCompatibility {
ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.listenerManager"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.experimental"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.udf"),
+ ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.udtf"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.streams"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.createDataFrame"),
ProblemFilters.exclude[Problem](
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
index 912e76005f0..e5c42c721fe 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -56,6 +56,8 @@ private[spark] object PythonEvalType {
val SQL_MAP_ARROW_ITER_UDF = 207
val SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE = 208
+ val SQL_TABLE_UDF = 300
+
def toString(pythonEvalType: Int): String = pythonEvalType match {
case NON_UDF => "NON_UDF"
case SQL_BATCHED_UDF => "SQL_BATCHED_UDF"
@@ -69,6 +71,7 @@ private[spark] object PythonEvalType {
case SQL_COGROUPED_MAP_PANDAS_UDF => "SQL_COGROUPED_MAP_PANDAS_UDF"
case SQL_MAP_ARROW_ITER_UDF => "SQL_MAP_ARROW_ITER_UDF"
case SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE => "SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE"
+ case SQL_TABLE_UDF => "SQL_TABLE_UDF"
}
}
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 92d4e2b7fbf..82293cbabf0 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -456,6 +456,7 @@ pyspark_sql = Module(
"pyspark.sql.streaming.readwriter",
"pyspark.sql.streaming.listener",
"pyspark.sql.udf",
+ "pyspark.sql.udtf",
"pyspark.sql.window",
"pyspark.sql.avro.functions",
"pyspark.sql.protobuf.functions",
@@ -501,6 +502,7 @@ pyspark_sql = Module(
"pyspark.sql.tests.test_types",
"pyspark.sql.tests.test_udf",
"pyspark.sql.tests.test_udf_profiler",
+ "pyspark.sql.tests.test_udtf",
"pyspark.sql.tests.test_utils",
],
)
diff --git a/python/docs/source/reference/pyspark.sql/core_classes.rst b/python/docs/source/reference/pyspark.sql/core_classes.rst
index 90c5c412797..3cf19686cdd 100644
--- a/python/docs/source/reference/pyspark.sql/core_classes.rst
+++ b/python/docs/source/reference/pyspark.sql/core_classes.rst
@@ -39,4 +39,6 @@ Core Classes
DataFrameWriter
DataFrameWriterV2
UDFRegistration
+ UDTFRegistration
udf.UserDefinedFunction
+ udtf.UserDefinedTableFunction
diff --git a/python/docs/source/reference/pyspark.sql/functions.rst b/python/docs/source/reference/pyspark.sql/functions.rst
index 186ee0dce8d..fb39ded16c6 100644
--- a/python/docs/source/reference/pyspark.sql/functions.rst
+++ b/python/docs/source/reference/pyspark.sql/functions.rst
@@ -374,6 +374,7 @@ UDF
call_udf
pandas_udf
udf
+ udtf
unwrap_udt
Misc Functions
diff --git a/python/docs/source/reference/pyspark.sql/index.rst b/python/docs/source/reference/pyspark.sql/index.rst
index fc4569486a7..233c8b238a6 100644
--- a/python/docs/source/reference/pyspark.sql/index.rst
+++ b/python/docs/source/reference/pyspark.sql/index.rst
@@ -40,4 +40,5 @@ This page gives an overview of all public Spark SQL API.
avro
observation
udf
+ udtf
protobuf
diff --git a/python/docs/source/reference/pyspark.sql/spark_session.rst b/python/docs/source/reference/pyspark.sql/spark_session.rst
index 15724306d75..edd5e746161 100644
--- a/python/docs/source/reference/pyspark.sql/spark_session.rst
+++ b/python/docs/source/reference/pyspark.sql/spark_session.rst
@@ -51,4 +51,5 @@ See also :class:`SparkSession`.
SparkSession.streams
SparkSession.table
SparkSession.udf
+ SparkSession.udtf
SparkSession.version
diff --git a/python/docs/source/reference/pyspark.sql/core_classes.rst b/python/docs/source/reference/pyspark.sql/udtf.rst
similarity index 72%
copy from python/docs/source/reference/pyspark.sql/core_classes.rst
copy to python/docs/source/reference/pyspark.sql/udtf.rst
index 90c5c412797..c251e101bbb 100644
--- a/python/docs/source/reference/pyspark.sql/core_classes.rst
+++ b/python/docs/source/reference/pyspark.sql/udtf.rst
@@ -16,27 +16,15 @@
under the License.
-============
-Core Classes
-============
+====
+UDTF
+====
+
.. currentmodule:: pyspark.sql
.. autosummary::
:toctree: api/
- SparkSession
- Catalog
- DataFrame
- Column
- Observation
- Row
- GroupedData
- PandasCogroupedOps
- DataFrameNaFunctions
- DataFrameStatFunctions
- Window
- DataFrameReader
- DataFrameWriter
- DataFrameWriterV2
- UDFRegistration
- udf.UserDefinedFunction
+ udtf.UserDefinedTableFunction.asNondeterministic
+ udtf.UserDefinedTableFunction.returnType
+ UDTFRegistration.register
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index e6ef7f6108e..37c06561f72 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -110,7 +110,13 @@ if TYPE_CHECKING:
)
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.types import AtomicType, StructType
- from pyspark.sql._typing import AtomicValue, RowLike, SQLArrowBatchedUDFType, SQLBatchedUDFType
+ from pyspark.sql._typing import (
+ AtomicValue,
+ RowLike,
+ SQLArrowBatchedUDFType,
+ SQLBatchedUDFType,
+ SQLTableUDFType,
+ )
from py4j.java_gateway import JavaObject
from py4j.java_collections import JavaArray
@@ -152,6 +158,8 @@ class PythonEvalType:
SQL_MAP_ARROW_ITER_UDF: "ArrowMapIterUDFType" = 207
SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE: "PandasGroupedMapUDFWithStateType" = 208
+ SQL_TABLE_UDF: "SQLTableUDFType" = 300
+
def portable_hash(x: Hashable) -> int:
"""
diff --git a/python/pyspark/sql/__init__.py b/python/pyspark/sql/__init__.py
index 22149e8adb8..d0d69488fa4 100644
--- a/python/pyspark/sql/__init__.py
+++ b/python/pyspark/sql/__init__.py
@@ -40,7 +40,7 @@ Important classes of Spark SQL and DataFrames:
For working with window functions.
"""
from pyspark.sql.types import Row
-from pyspark.sql.context import SQLContext, HiveContext, UDFRegistration
+from pyspark.sql.context import SQLContext, HiveContext, UDFRegistration, UDTFRegistration
from pyspark.sql.session import SparkSession
from pyspark.sql.column import Column
from pyspark.sql.catalog import Catalog
@@ -57,6 +57,7 @@ __all__ = [
"SQLContext",
"HiveContext",
"UDFRegistration",
+ "UDTFRegistration",
"DataFrame",
"GroupedData",
"Column",
diff --git a/python/pyspark/sql/_typing.pyi b/python/pyspark/sql/_typing.pyi
index aafd916dbb7..7b5c206c98f 100644
--- a/python/pyspark/sql/_typing.pyi
+++ b/python/pyspark/sql/_typing.pyi
@@ -58,6 +58,7 @@ RowLike = TypeVar("RowLike", List[Any], Tuple[Any, ...], pyspark.sql.types.Row)
SQLBatchedUDFType = Literal[100]
SQLArrowBatchedUDFType = Literal[101]
+SQLTableUDFType = Literal[300]
class SupportsOpen(Protocol):
def open(self, partition_id: int, epoch_id: int) -> bool: ...
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index 2b35ca3d7ea..ba0d8caaeca 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -592,7 +592,7 @@ class SparkSession:
raise PySparkAttributeError(
error_class="JVM_ATTRIBUTE_NOT_SUPPORTED", message_parameters={"attr_name": name}
)
- elif name in ["newSession", "sparkContext"]:
+ elif name in ["newSession", "sparkContext", "udtf"]:
raise PySparkNotImplementedError(
error_class="NOT_IMPLEMENTED", message_parameters={"feature": f"{name}()"}
)
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index 99f97977ccc..817c3b97337 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -41,6 +41,7 @@ from pyspark.sql.dataframe import DataFrame
from pyspark.sql.readwriter import DataFrameReader
from pyspark.sql.streaming import DataStreamReader
from pyspark.sql.udf import UDFRegistration # noqa: F401
+from pyspark.sql.udtf import UDTFRegistration
from pyspark.errors.exceptions.captured import install_exception_handler
from pyspark.context import SparkContext
from pyspark.rdd import RDD
@@ -228,6 +229,18 @@ class SQLContext:
"""
return self.sparkSession.udf
+ @property
+ def udtf(self) -> UDTFRegistration:
+ """Returns a :class:`UDTFRegistration` for UDTF registration.
+
+ .. versionadded:: 3.5.0
+
+ Returns
+ -------
+ :class:`UDTFRegistration`
+ """
+ return self.sparkSession.udtf
+
def range(
self,
start: int,
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 6b03025b614..54b7b312bf5 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -32,6 +32,7 @@ from typing import (
overload,
Optional,
Tuple,
+ Type,
TYPE_CHECKING,
Union,
ValuesView,
@@ -47,6 +48,7 @@ from pyspark.sql.types import ArrayType, DataType, StringType, StructType, _from
# Keep UserDefinedFunction import for backwards compatible import; moved in SPARK-22409
from pyspark.sql.udf import UserDefinedFunction, _create_py_udf # noqa: F401
+from pyspark.sql.udtf import UserDefinedTableFunction, _create_udtf
# Keep pandas_udf and PandasUDFType import for backwards compatible import; moved in SPARK-28264
from pyspark.sql.pandas.functions import pandas_udf, PandasUDFType # noqa: F401
@@ -11451,6 +11453,100 @@ def udf(
return _create_py_udf(f=f, returnType=returnType, useArrow=useArrow)
+def udtf(
+ cls: Optional[Type] = None,
+ *,
+ returnType: Union[StructType, str],
+) -> Union[UserDefinedTableFunction, functools.partial]:
+ """Creates a user defined table function (UDTF).
+
+ .. versionadded:: 3.5.0
+
+ Parameters
+ ----------
+ cls : class
+ the Python user-defined table function handler class.
+ returnType : :class:`pyspark.sql.types.StructType` or str
+ the return type of the user-defined table function. The value can be either a
+ :class:`pyspark.sql.types.StructType` object or a DDL-formatted struct type string.
+
+ Examples
+ --------
+ Implement the UDTF class
+
+ >>> class TestUDTF:
+ ... def eval(self, *args: Any):
+ ... yield "hello", "world"
+
+ Create the UDTF
+
+ >>> from pyspark.sql.functions import udtf
+ >>> test_udtf = udtf(TestUDTF, returnType="c1: string, c2: string")
+
+ Create the UDTF using the decorator
+
+ >>> @udtf(returnType="c1: int, c2: int")
+ ... class PlusOne:
+ ... def eval(self, x: int):
+ ... yield x, x + 1
+
+ Invoke the UDTF
+
+ >>> test_udtf().show()
+ +-----+-----+
+ | c1| c2|
+ +-----+-----+
+ |hello|world|
+ +-----+-----+
+
+ Invoke the UDTF with parameters
+
+ >>> from pyspark.sql.functions import lit
+ >>> PlusOne(lit(1)).show()
+ +---+---+
+ | c1| c2|
+ +---+---+
+ | 1| 2|
+ +---+---+
+
+ Notes
+ -----
+ User-defined table functions (UDTFs) are considered deterministic by default.
+ Use `asNondeterministic()` to mark a function as non-deterministic. E.g.:
+
+ >>> import random
+ >>> class RandomUDTF:
+ ... def eval(self, a: int):
+ ... yield a * int(random.random() * 100),
+ >>> random_udtf = udtf(RandomUDTF, returnType="r: int").asNondeterministic()
+
+ Use "yield" to produce one row for the UDTF result relation as many times
+ as needed. In the context of a lateral join, each such result row will be
+ associated with the most recent input row consumed from the "eval" method.
+ Or, use "return" to produce multiple rows for the UDTF result relation at
+ once.
+
+ >>> class TestUDTF:
+ ... def eval(self, a: int):
+ ... return [(a, a + 1), (a, a + 2)]
+ >>> test_udtf = udtf(TestUDTF, returnType="x: int, y: int")
+
+ User-defined table functions are considered opaque to the optimizer by default.
+ As a result, operations like filters from WHERE clauses or limits from
+ LIMIT/OFFSET clauses that appear after the UDTF call will execute on the
+ UDTF's result relation. By the same token, any relations forwarded as input
+ to UDTFs will plan as full table scans in the absence of any explicit such
+ filtering or other logic explicitly written in a table subquery surrounding the
+ provided input relation.
+
+ User-defined table functions do not accept keyword arguments on the calling side.
+ """
+ if cls is None:
+ return functools.partial(_create_udtf, returnType=returnType)
+ else:
+ return _create_udtf(cls=cls, returnType=returnType)
+
+
def _test() -> None:
import doctest
from pyspark.sql import SparkSession
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index e96dc9cee3f..823164475ea 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -72,6 +72,7 @@ if TYPE_CHECKING:
from pyspark.sql.pandas._typing import ArrayLike, DataFrameLike as PandasDataFrameLike
from pyspark.sql.streaming import StreamingQueryManager
from pyspark.sql.udf import UDFRegistration
+ from pyspark.sql.udtf import UDTFRegistration
__all__ = ["SparkSession"]
@@ -792,6 +793,20 @@ class SparkSession(SparkConversionMixin):
return UDFRegistration(self)
+ @property
+ def udtf(self) -> "UDTFRegistration":
+ """Returns a :class:`UDTFRegistration` for UDTF registration.
+
+ .. versionadded:: 3.5.0
+
+ Returns
+ -------
+ :class:`UDTFRegistration`
+ """
+ from pyspark.sql.udtf import UDTFRegistration
+
+ return UDTFRegistration(self)
+
def range(
self,
start: int,
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 18a7d8f19b4..89384b24e45 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -27,6 +27,7 @@ from collections import defaultdict
from pyspark.errors import (
PySparkAttributeError,
+ PySparkNotImplementedError,
PySparkTypeError,
PySparkException,
PySparkValueError,
@@ -3180,6 +3181,16 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
rows = [cols] * row_count
self.assertEqual(row_count, self.connect.createDataFrame(data=rows).count())
+ def test_unsupported_udtf(self):
+ with self.assertRaises(PySparkNotImplementedError) as e:
+ self.connect.udtf.register()
+
+ self.check_error(
+ exception=e.exception,
+ error_class="NOT_IMPLEMENTED",
+ message_parameters={"feature": "udtf()"},
+ )
+
def test_unsupported_jvm_attribute(self):
# Unsupported jvm attributes for Spark session.
unsupported_attrs = ["_jsc", "_jconf", "_jvm", "_jsparkSession"]
diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py
new file mode 100644
index 00000000000..628f2696b84
--- /dev/null
+++ b/python/pyspark/sql/tests/test_udtf.py
@@ -0,0 +1,413 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from typing import Iterator
+
+from py4j.protocol import Py4JJavaError
+
+from pyspark.errors import PythonException, AnalysisException
+from pyspark.sql.functions import lit, udtf
+from pyspark.sql.types import Row
+from pyspark.testing.sqlutils import ReusedSQLTestCase
+
+
+class UDTFTestsMixin(ReusedSQLTestCase):
+ def test_simple_udtf(self):
+ class TestUDTF:
+ def eval(self):
+ yield "hello", "world"
+
+ func = udtf(TestUDTF, returnType="c1: string, c2: string")
+ rows = func().collect()
+ self.assertEqual(rows, [Row(c1="hello", c2="world")])
+
+ def test_udtf_yield_single_row_col(self):
+ class TestUDTF:
+ def eval(self, a: int):
+ yield a,
+
+ func = udtf(TestUDTF, returnType="a: int")
+ rows = func(lit(1)).collect()
+ self.assertEqual(rows, [Row(a=1)])
+
+ def test_udtf_yield_multi_cols(self):
+ class TestUDTF:
+ def eval(self, a: int):
+ yield a, a + 1
+
+ func = udtf(TestUDTF, returnType="a: int, b: int")
+ rows = func(lit(1)).collect()
+ self.assertEqual(rows, [Row(a=1, b=2)])
+
+ def test_udtf_yield_multi_rows(self):
+ class TestUDTF:
+ def eval(self, a: int):
+ yield a,
+ yield a + 1,
+
+ func = udtf(TestUDTF, returnType="a: int")
+ rows = func(lit(1)).collect()
+ self.assertEqual(rows, [Row(a=1), Row(a=2)])
+
+ def test_udtf_yield_multi_row_col(self):
+ class TestUDTF:
+ def eval(self, a: int, b: int):
+ yield a, b, a + b
+ yield a, b, a - b
+ yield a, b, b - a
+
+ func = udtf(TestUDTF, returnType="a: int, b: int, c: int")
+ rows = func(lit(1), lit(2)).collect()
+ self.assertEqual(rows, [Row(a=1, b=2, c=3), Row(a=1, b=2, c=-1), Row(a=1, b=2, c=1)])
+
+ def test_udtf_decorator(self):
+ @udtf(returnType="a: int, b: int")
+ class TestUDTF:
+ def eval(self, a: int):
+ yield a, a + 1
+
+ rows = TestUDTF(lit(1)).collect()
+ self.assertEqual(rows, [Row(a=1, b=2)])
+
+ def test_udtf_registration(self):
+ class TestUDTF:
+ def eval(self, a: int, b: int):
+ yield a, b, a + b
+ yield a, b, a - b
+ yield a, b, b - a
+
+ func = udtf(TestUDTF, returnType="a: int, b: int, c: int")
+ self.spark.udtf.register("testUDTF", func)
+ df = self.spark.sql("SELECT * FROM testUDTF(1, 2)")
+ self.assertEqual(
+ df.collect(), [Row(a=1, b=2, c=3), Row(a=1, b=2, c=-1), Row(a=1, b=2, c=1)]
+ )
+
+ def test_udtf_with_lateral_join(self):
+ class TestUDTF:
+ def eval(self, a: int, b: int) -> Iterator:
+ yield a, b, a + b
+ yield a, b, a - b
+
+ func = udtf(TestUDTF, returnType="a: int, b: int, c: int")
+ self.spark.udtf.register("testUDTF", func)
+ df = self.spark.sql(
+ "SELECT f.* FROM values (0, 1), (1, 2) t(a, b), LATERAL testUDTF(a, b) f"
+ )
+ expected = self.spark.createDataFrame(
+ [(0, 1, 1), (0, 1, -1), (1, 2, 3), (1, 2, -1)], schema=["a", "b", "c"]
+ )
+ self.assertEqual(df.collect(), expected.collect())
+
+ def test_udtf_eval_with_return_stmt(self):
+ class TestUDTF:
+ def eval(self, a: int, b: int):
+ return [(a, a + 1), (b, b + 1)]
+
+ func = udtf(TestUDTF, returnType="a: int, b: int")
+ rows = func(lit(1), lit(2)).collect()
+ self.assertEqual(rows, [Row(a=1, b=2), Row(a=2, b=3)])
+
+ def test_udtf_eval_returning_non_tuple(self):
+ class TestUDTF:
+ def eval(self, a: int):
+ yield a
+
+ func = udtf(TestUDTF, returnType="a: int")
+ # TODO(SPARK-44005): improve this error message
+ with self.assertRaisesRegex(PythonException, "Unexpected tuple 1 with StructType"):
+ func(lit(1)).collect()
+
+ def test_udtf_eval_returning_non_generator(self):
+ class TestUDTF:
+ def eval(self, a: int):
+ return (a,)
+
+ func = udtf(TestUDTF, returnType="a: int")
+ # TODO(SPARK-44005): improve this error message
+ with self.assertRaisesRegex(PythonException, "Unexpected tuple 1 with StructType"):
+ func(lit(1)).collect()
+
+ def test_udtf_eval_with_no_return(self):
+ @udtf(returnType="a: int")
+ class TestUDTF:
+ def eval(self, a: int):
+ ...
+
+ # TODO(SPARK-43967): Support Python UDTFs with empty return values
+ with self.assertRaisesRegex(PythonException, "TypeError"):
+ TestUDTF(lit(1)).collect()
+
+ @udtf(returnType="a: int")
+ class TestUDTF:
+ def eval(self, a: int):
+ return
+
+ with self.assertRaisesRegex(PythonException, "TypeError"):
+ TestUDTF(lit(1)).collect()
+
+ def test_udtf_with_conditional_return(self):
+ class TestUDTF:
+ def eval(self, a: int):
+ if a > 5:
+ yield a,
+
+ func = udtf(TestUDTF, returnType="a: int")
+ self.spark.udtf.register("test_udtf", func)
+ self.assertEqual(
+ self.spark.sql("SELECT * FROM range(0, 8) JOIN LATERAL test_udtf(id)").collect(),
+ [Row(id=6, a=6), Row(id=7, a=7)],
+ )
+
+ def test_udtf_with_empty_yield(self):
+ @udtf(returnType="a: int")
+ class TestUDTF:
+ def eval(self, a: int):
+ yield
+
+ # TODO(SPARK-43967): Support Python UDTFs with empty return values
+ with self.assertRaisesRegex(Py4JJavaError, "java.lang.NullPointerException"):
+ TestUDTF(lit(1)).collect()
+
+ def test_udtf_with_none_output(self):
+ @udtf(returnType="a: int")
+ class TestUDTF:
+ def eval(self, a: int):
+ yield a,
+ yield None,
+
+ self.assertEqual(TestUDTF(lit(1)).collect(), [Row(a=1), Row(a=None)])
+ df = self.spark.createDataFrame([(0, 1), (1, 2)], schema=["a", "b"])
+ self.assertEqual(TestUDTF(lit(1)).join(df, "a", "inner").collect(), [Row(a=1, b=2)])
+ self.assertEqual(
+ TestUDTF(lit(1)).join(df, "a", "left").collect(), [Row(a=None, b=None), Row(a=1, b=2)]
+ )
+
+ def test_udtf_with_none_input(self):
+ @udtf(returnType="a: int")
+ class TestUDTF:
+ def eval(self, a: int):
+ yield a,
+
+ self.assertEqual(TestUDTF(lit(None)).collect(), [Row(a=None)])
+ self.spark.udtf.register("testUDTF", TestUDTF)
+ df = self.spark.sql("SELECT * FROM testUDTF(null)")
+ self.assertEqual(df.collect(), [Row(a=None)])
+
+ def test_udtf_with_wrong_num_input(self):
+ @udtf(returnType="a: int, b: int")
+ class TestUDTF:
+ def eval(self, a: int):
+ yield a, a + 1
+
+ with self.assertRaisesRegex(
+ PythonException, r"eval\(\) missing 1 required positional argument: 'a'"
+ ):
+ TestUDTF().collect()
+
+ with self.assertRaisesRegex(
+ PythonException, r"eval\(\) takes 2 positional arguments but 3 were given"
+ ):
+ TestUDTF(lit(1), lit(2)).collect()
+
+ def test_udtf_with_wrong_num_output(self):
+ # TODO(SPARK-43968): check this during compile time instead of runtime
+ err_msg = (
+ "java.lang.IllegalStateException: Input row doesn't have expected number of "
+ + "values required by the schema."
+ )
+
+ @udtf(returnType="a: int, b: int")
+ class TestUDTF:
+ def eval(self, a: int):
+ yield a,
+
+ with self.assertRaisesRegex(Py4JJavaError, err_msg):
+ TestUDTF(lit(1)).collect()
+
+ @udtf(returnType="a: int")
+ class TestUDTF:
+ def eval(self, a: int):
+ yield a, a + 1
+
+ with self.assertRaisesRegex(Py4JJavaError, err_msg):
+ TestUDTF(lit(1)).collect()
+
+ def test_udtf_init(self):
+ @udtf(returnType="a: int, b: int, c: string")
+ class TestUDTF:
+ def __init__(self):
+ self.key = "test"
+
+ def eval(self, a: int):
+ yield a, a + 1, self.key
+
+ rows = TestUDTF(lit(1)).collect()
+ self.assertEqual(rows, [Row(a=1, b=2, c="test")])
+
+ def test_udtf_terminate(self):
+ @udtf(returnType="key: string, value: float")
+ class TestUDTF:
+ def __init__(self):
+ self._count = 0
+ self._sum = 0
+
+ def eval(self, x: int):
+ self._count += 1
+ self._sum += x
+ yield "input", float(x)
+
+ def terminate(self):
+ yield "count", float(self._count)
+ yield "avg", self._sum / self._count
+
+ self.assertEqual(
+ TestUDTF(lit(1)).collect(),
+ [Row(key="input", value=1), Row(key="count", value=1.0), Row(key="avg", value=1.0)],
+ )
+
+ self.spark.udtf.register("test_udtf", TestUDTF)
+ df = self.spark.sql(
+ "SELECT id, key, value FROM range(0, 10, 1, 2), "
+ "LATERAL test_udtf(id) WHERE key != 'input'"
+ )
+ self.assertEqual(
+ df.collect(),
+ [
+ Row(id=4, key="count", value=5.0),
+ Row(id=4, key="avg", value=2.0),
+ Row(id=9, key="count", value=5.0),
+ Row(id=9, key="avg", value=7.0),
+ ],
+ )
+
+ def test_terminate_with_exceptions(self):
+ @udtf(returnType="a: int, b: int")
+ class TestUDTF:
+ def eval(self, a: int):
+ yield a, a + 1
+
+ def terminate(self):
+ raise ValueError("terminate error")
+
+ with self.assertRaisesRegex(
+ PythonException,
+ "User defined table function encountered an error in the 'terminate' "
+ "method: terminate error",
+ ):
+ TestUDTF(lit(1)).collect()
+
+ def test_udtf_terminate_with_wrong_num_output(self):
+ err_msg = (
+ "java.lang.IllegalStateException: Input row doesn't have expected number of "
+ "values required by the schema."
+ )
+
+ @udtf(returnType="a: int, b: int")
+ class TestUDTF:
+ def eval(self, a: int):
+ yield a, a + 1
+
+ def terminate(self):
+ yield 1, 2, 3
+
+ with self.assertRaisesRegex(Py4JJavaError, err_msg):
+ TestUDTF(lit(1)).show()
+
+ @udtf(returnType="a: int, b: int")
+ class TestUDTF:
+ def eval(self, a: int):
+ yield a, a + 1
+
+ def terminate(self):
+ yield 1,
+
+ with self.assertRaisesRegex(Py4JJavaError, err_msg):
+ TestUDTF(lit(1)).show()
+
+ def test_nondeterministic_udtf(self):
+ import random
+
+ class RandomUDTF:
+ def eval(self, a: int):
+ yield a * int(random.random() * 100),
+
+ random_udtf = udtf(RandomUDTF, returnType="x: int").asNondeterministic()
+ # TODO(SPARK-43966): support non-deterministic UDTFs
+ with self.assertRaisesRegex(AnalysisException, "nondeterministic expressions"):
+ random_udtf(lit(1)).collect()
+
+ def test_udtf_with_nondeterministic_input(self):
+ from pyspark.sql.functions import rand
+
+ @udtf(returnType="x: int")
+ class TestUDTF:
+ def eval(self, a: int):
+ yield a + 1,
+
+ # TODO(SPARK-43966): support non-deterministic UDTFs
+ with self.assertRaisesRegex(AnalysisException, "nondeterministic expressions"):
+ TestUDTF(rand(0) * 100).collect()
+
+ def test_udtf_no_eval(self):
+ @udtf(returnType="a: int, b: int")
+ class TestUDTF:
+ def run(self, a: int):
+ yield a, a + 1
+
+ with self.assertRaisesRegex(
+ PythonException,
+ "Failed to execute the user defined table function because it has not "
+ "implemented the 'eval' method. Please add the 'eval' method and try the "
+ "query again.",
+ ):
+ TestUDTF(lit(1)).collect()
+
+ def test_udtf_with_no_handler_class(self):
+ err_msg = "the function handler must be a class"
+ with self.assertRaisesRegex(TypeError, err_msg):
+
+ @udtf(returnType="a: int")
+ def test_udtf(a: int):
+ yield a,
+
+ def test_udtf(a: int):
+ yield a
+
+ with self.assertRaisesRegex(TypeError, err_msg):
+ udtf(test_udtf, returnType="a: int")
+
+
+class UDTFTests(UDTFTestsMixin, ReusedSQLTestCase):
+ @classmethod
+ def setUpClass(cls):
+ super(UDTFTests, cls).setUpClass()
+ cls.spark.conf.set("spark.sql.execution.pythonUDF.arrow.enabled", "false")
+
+
+if __name__ == "__main__":
+ from pyspark.sql.tests.test_udtf import * # noqa: F401
+
+ try:
+ import xmlrunner # type: ignore
+
+ testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/udtf.py b/python/pyspark/sql/udtf.py
new file mode 100644
index 00000000000..e1b0825c170
--- /dev/null
+++ b/python/pyspark/sql/udtf.py
@@ -0,0 +1,233 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""
+User-defined table function related classes and functions
+"""
+import sys
+from typing import Type, TYPE_CHECKING, Optional, Union
+
+from py4j.java_gateway import JavaObject
+
+from pyspark.errors import PySparkTypeError
+from pyspark.sql.column import _to_java_column, _to_seq
+from pyspark.sql.types import StructType, _parse_datatype_string
+from pyspark.sql.udf import _wrap_function
+
+if TYPE_CHECKING:
+ from pyspark.sql._typing import ColumnOrName
+ from pyspark.sql.dataframe import DataFrame
+ from pyspark.sql.session import SparkSession
+
+__all__ = ["UDTFRegistration"]
+
+
+def _create_udtf(
+ cls: Type,
+ returnType: Union[StructType, str],
+ name: Optional[str] = None,
+ deterministic: bool = True,
+) -> "UserDefinedTableFunction":
+ """Create a Python UDTF."""
+ udtf_obj = UserDefinedTableFunction(
+ cls, returnType=returnType, name=name, deterministic=deterministic
+ )
+ return udtf_obj
+
+
+class UserDefinedTableFunction:
+ """
+ User-defined table function in Python
+
+ .. versionadded:: 3.5.0
+
+ Notes
+ -----
+ The constructor of this class is not supposed to be directly called.
+ Use :meth:`pyspark.sql.functions.udtf` to create this instance.
+
+ This API is evolving.
+ """
+
+ def __init__(
+ self,
+ func: Type,
+ returnType: Union[StructType, str],
+ name: Optional[str] = None,
+ deterministic: bool = True,
+ ):
+
+ if not isinstance(func, type):
+ raise PySparkTypeError(
+ f"Invalid user defined table function: the function handler "
+ f"must be a class, but got {type(func).__name__}. Please provide "
+ "a class as the handler."
+ )
+
+ # TODO(SPARK-43968): add more compile time checks for UDTFs
+ self.func = func
+ self._returnType = returnType
+ self._returnType_placeholder: Optional[StructType] = None
+ self._inputTypes_placeholder = None
+ self._judtf_placeholder = None
+ self._name = name or func.__name__
+ self.deterministic = deterministic
+
+ @property
+ def returnType(self) -> StructType:
+ # `_parse_datatype_string` accesses to JVM for parsing a DDL formatted string.
+ # This makes sure this is called after SparkContext is initialized.
+ if self._returnType_placeholder is None:
+ if isinstance(self._returnType, StructType):
+ self._returnType_placeholder = self._returnType
+ else:
+ assert isinstance(self._returnType, str)
+ parsed = _parse_datatype_string(self._returnType)
+ if not isinstance(parsed, StructType):
+ raise PySparkTypeError(
+ f"Invalid return type for the user defined table function "
+ f"'{self._name}': {self._returnType}. The return type of a "
+ f"UDTF must be a 'StructType'. Please ensure the return "
+ "type is a correctly formatted 'StructType' string."
+ )
+ self._returnType_placeholder = parsed
+ return self._returnType_placeholder
+
+ @property
+ def _judtf(self) -> JavaObject:
+ if self._judtf_placeholder is None:
+ self._judtf_placeholder = self._create_judtf(self.func)
+ return self._judtf_placeholder
+
+ def _create_judtf(self, func: Type) -> JavaObject:
+ from pyspark.sql import SparkSession
+
+ spark = SparkSession._getActiveSessionOrCreate()
+ sc = spark.sparkContext
+
+ wrapped_func = _wrap_function(sc, func, self.returnType)
+ jdt = spark._jsparkSession.parseDataType(self.returnType.json())
+ assert sc._jvm is not None
+ judtf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonTableFunction(
+ self._name, wrapped_func, jdt, self.deterministic
+ )
+ return judtf
+
+ def __call__(self, *cols: "ColumnOrName") -> "DataFrame":
+ from pyspark.sql import DataFrame, SparkSession
+
+ spark = SparkSession._getActiveSessionOrCreate()
+ sc = spark.sparkContext
+
+ judtf = self._judtf
+ jPythonUDTF = judtf.apply(spark._jsparkSession, _to_seq(sc, cols, _to_java_column))
+ return DataFrame(jPythonUDTF, spark)
+
+ def asNondeterministic(self) -> "UserDefinedTableFunction":
+ """
+ Updates UserDefinedTableFunction to nondeterministic.
+ """
+ # Explicitly clean the cache to create a JVM UDTF instance.
+ self._judtf_placeholder = None
+ self.deterministic = False
+ return self
+
+
+class UDTFRegistration:
+ """
+ Wrapper for user-defined table function registration. This instance can be accessed by
+ :attr:`spark.udtf` or :attr:`sqlContext.udtf`.
+
+ .. versionadded:: 3.5.0
+ """
+
+ def __init__(self, sparkSession: "SparkSession"):
+ self.sparkSession = sparkSession
+
+ def register(
+ self,
+ name: str,
+ f: UserDefinedTableFunction,
+ ) -> UserDefinedTableFunction:
+ """Register a Python user-defined table function as a SQL table function.
+
+ .. versionadded:: 3.5.0
+
+ Parameters
+ ----------
+ name : str
+ The name of the user-defined table function in SQL statements.
+ f : function or :meth:`pyspark.sql.functions.udtf`
+ The user-defined table function.
+
+ Returns
+ -------
+ function
+ The registered user-defined table function.
+
+ Notes
+ -----
+ Spark uses the return type of the given user-defined table function as the return
+ type of the registered user-defined function.
+
+ To register a nondeterministic Python table function, users need to first build
+ a nondeterministic user-defined table function and then register it as a SQL function.
+
+ Examples
+ --------
+ >>> from pyspark.sql.functions import udtf
+ >>> @udtf(returnType="c1: int, c2: int")
+ ... class PlusOne:
+ ... def eval(self, x: int):
+ ... yield x, x + 1
+ >>> _ = spark.udtf.register(name="plus_one", f=PlusOne)
+ >>> spark.sql("SELECT * FROM plus_one(1)").collect()
+ [Row(c1=1, c2=2)]
+
+ Use it with lateral join
+
+ >>> spark.sql("SELECT * FROM VALUES (0, 1), (1, 2) t(x, y), LATERAL plus_one(x)").collect()
+ [Row(x=0, y=1, c1=0, c2=1), Row(x=1, y=2, c1=1, c2=2)]
+ """
+ register_udtf = _create_udtf(
+ f.func,
+ returnType=f.returnType,
+ name=name,
+ deterministic=f.deterministic,
+ )
+ return_udtf = f
+ self.sparkSession._jsparkSession.udtf().registerPython(name, register_udtf._judtf)
+ return return_udtf
+
+
+def _test() -> None:
+ import doctest
+ from pyspark.sql import SparkSession
+ import pyspark.sql.udf
+
+ globs = pyspark.sql.udtf.__dict__.copy()
+ spark = SparkSession.builder.master("local[4]").appName("sql.udtf tests").getOrCreate()
+ globs["spark"] = spark
+ (failure_count, test_count) = doctest.testmod(
+ pyspark.sql.udtf, globs=globs, optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE
+ )
+ spark.stop()
+ if failure_count:
+ sys.exit(-1)
+
+
+if __name__ == "__main__":
+ _test()
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 06f0d1dc37f..77d0d548408 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -64,7 +64,7 @@ from pyspark.sql.pandas.types import to_arrow_type
from pyspark.sql.types import StructType
from pyspark.util import fail_on_stopiteration, try_simplify_traceback
from pyspark import shuffle
-from pyspark.errors import PySparkRuntimeError
+from pyspark.errors import PySparkRuntimeError, PySparkValueError
pickleSer = CPickleSerializer()
utf8_deserializer = UTF8Deserializer()
@@ -456,6 +456,86 @@ def assign_cols_by_name(runner_conf):
)
+# Read and process a serialized user-defined table function (UDTF) from a socket.
+# It expects the UDTF to be in a specific format and performs various checks to
+# ensure the UDTF is valid. This function also prepares a mapper function for applying
+# the UDTF logic to input rows.
+def read_udtf(pickleSer, infile, eval_type):
+ num_udtfs = read_int(infile)
+ if num_udtfs != 1:
+ raise PySparkValueError(f"Unexpected number of UDTFs. Expected 1 but got {num_udtfs}.")
+
+ # See `PythonUDFRunner.writeUDFs`.
+ num_arg = read_int(infile)
+ arg_offsets = [read_int(infile) for _ in range(num_arg)]
+ num_chained_funcs = read_int(infile)
+ if num_chained_funcs != 1:
+ raise PySparkValueError(
+ f"Unexpected number of chained UDTFs. Expected 1 but got {num_chained_funcs}."
+ )
+
+ handler, return_type = read_command(pickleSer, infile)
+ if not isinstance(handler, type):
+ raise PySparkRuntimeError(
+ f"Invalid UDTF handler type. Expected a class (type 'type'), but "
+ f"got an instance of {type(handler).__name__}."
+ )
+
+ # Instantiate the UDTF class.
+ try:
+ udtf = handler()
+ except Exception as e:
+ raise PySparkRuntimeError(
+ f"User defined table function encountered an error in "
+ f"the '__init__' method: {str(e)}"
+ )
+
+ # Validate the UDTF
+ if not hasattr(udtf, "eval"):
+ raise PySparkRuntimeError(
+ "Failed to execute the user defined table function because it has not "
+ "implemented the 'eval' method. Please add the 'eval' method and try "
+ "the query again."
+ )
+
+ # Wrap the UDTF and convert the results.
+ def wrap_udtf(f, return_type):
+ if return_type.needConversion():
+ toInternal = return_type.toInternal
+ return lambda *a: map(toInternal, f(*a))
+ else:
+ return lambda *a: f(*a)
+
+ eval = wrap_udtf(getattr(udtf, "eval"), return_type)
+
+ if hasattr(udtf, "terminate"):
+ terminate = wrap_udtf(getattr(udtf, "terminate"), return_type)
+ else:
+ terminate = None
+
+ def mapper(a):
+ results = tuple(eval(*[a[o] for o in arg_offsets]))
+ return results
+
+ # Return an iterator of iterators.
+ def func(_, it):
+ try:
+ yield from map(mapper, it)
+ finally:
+ if terminate is not None:
+ try:
+ yield tuple(terminate())
+ except BaseException as e:
+ raise PySparkRuntimeError(
+ f"User defined table function encountered an error in "
+ f"the 'terminate' method: {str(e)}"
+ )
+
+ ser = BatchedSerializer(CPickleSerializer(), 100)
+
+ return func, None, ser, ser
+
+
def read_udfs(pickleSer, infile, eval_type):
runner_conf = {}
@@ -859,6 +939,8 @@ def main(infile, outfile):
eval_type = read_int(infile)
if eval_type == PythonEvalType.NON_UDF:
func, profiler, deserializer, serializer = read_command(pickleSer, infile)
+ elif eval_type == PythonEvalType.SQL_TABLE_UDF:
+ func, profiler, deserializer, serializer = read_udtf(pickleSer, infile, eval_type)
else:
func, profiler, deserializer, serializer = read_udfs(pickleSer, infile, eval_type)
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionInfo.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionInfo.java
index be2b3dbe819..84ee30440b7 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionInfo.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionInfo.java
@@ -48,7 +48,8 @@ public class ExpressionInfo {
"window_funcs", "xml_funcs", "table_funcs", "url_funcs"));
private static final Set<String> validSources =
- new HashSet<>(Arrays.asList("built-in", "hive", "python_udf", "scala_udf", "java_udf"));
+ new HashSet<>(Arrays.asList("built-in", "hive", "python_udf", "scala_udf",
+ "java_udf", "python_udtf"));
public String getClassName() {
return className;
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala
index 8636eb61034..6905bde9c33 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala
@@ -57,8 +57,6 @@ trait PythonFuncExpression extends NonSQLExpression with UserDefinedExpression {
def udfDeterministic: Boolean
def resultId: ExprId
- final override val nodePatterns: Seq[TreePattern] = Seq(PYTHON_UDF)
-
override lazy val deterministic: Boolean = udfDeterministic && children.forall(_.deterministic)
override def toString: String = s"$name(${children.mkString(", ")})#${resultId.id}$typeSuffix"
@@ -89,6 +87,8 @@ case class PythonUDF(
this.copy(resultId = ExprId(-1)).withNewChildren(canonicalizedChildren)
}
+ final override val nodePatterns: Seq[TreePattern] = Seq(PYTHON_UDF)
+
override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): PythonUDF =
copy(children = newChildren)
}
@@ -138,10 +138,45 @@ case class PythonUDAF(
this.copy(resultId = ExprId(-1)).withNewChildren(canonicalizedChildren)
}
+ final override val nodePatterns: Seq[TreePattern] = Seq(PYTHON_UDF)
+
override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): PythonUDAF =
copy(children = newChildren)
}
+abstract class UnevaluableGenerator extends Generator {
+ final override def eval(input: InternalRow): TraversableOnce[InternalRow] =
+ throw QueryExecutionErrors.cannotEvaluateExpressionError(this)
+
+ final override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
+ throw QueryExecutionErrors.cannotGenerateCodeForExpressionError(this)
+}
+
+/**
+ * A serialized version of a Python table-valued function. This is a special expression,
+ * which needs a dedicated physical operator to execute it.
+ */
+case class PythonUDTF(
+ name: String,
+ func: PythonFunction,
+ override val elementSchema: StructType,
+ children: Seq[Expression],
+ udfDeterministic: Boolean,
+ resultId: ExprId = NamedExpression.newExprId)
+ extends UnevaluableGenerator with PythonFuncExpression {
+
+ override def evalType: Int = PythonEvalType.SQL_TABLE_UDF
+
+ override lazy val canonicalized: Expression = {
+ val canonicalizedChildren = children.map(_.canonicalized)
+ // `resultId` can be seen as cosmetic variation in PythonUDTF, as it doesn't affect the result.
+ this.copy(resultId = ExprId(-1)).withNewChildren(canonicalizedChildren)
+ }
+
+ override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): PythonUDTF =
+ copy(children = newChildren)
+}
+
/**
* A place holder used when printing expressions without debugging information such as the
* result id.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/PlanHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/PlanHelper.scala
index d7e8b8a9610..db89fc4bbbd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/PlanHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/PlanHelper.scala
@@ -48,7 +48,8 @@ object PlanHelper {
plan.isInstanceOf[CollectMetrics] ||
onlyInLateralSubquery(plan)) => e
case e: Generator
- if !plan.isInstanceOf[Generate] => e
+ if !(plan.isInstanceOf[Generate] ||
+ plan.isInstanceOf[BaseEvalPythonUDTF]) => e
}
}
invalidExpressions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
index fe5eee481be..f79c360a313 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.catalyst.plans.logical
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, PythonUDF}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, PythonUDF, PythonUDTF}
import org.apache.spark.sql.catalyst.trees.TreePattern._
import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode}
@@ -148,6 +148,21 @@ trait BaseEvalPython extends UnaryNode {
final override val nodePatterns: Seq[TreePattern] = Seq(EVAL_PYTHON_UDF)
}
+trait BaseEvalPythonUDTF extends UnaryNode {
+
+ def udtf: PythonUDTF
+
+ def requiredChildOutput: Seq[Attribute]
+
+ def resultAttrs: Seq[Attribute]
+
+ override def output: Seq[Attribute] = requiredChildOutput ++ resultAttrs
+
+ override def producedAttributes: AttributeSet = AttributeSet(resultAttrs)
+
+ final override val nodePatterns: Seq[TreePattern] = Seq(EVAL_PYTHON_UDTF)
+}
+
/**
* A logical plan that evaluates a [[PythonUDF]]
*/
@@ -171,6 +186,24 @@ case class ArrowEvalPython(
copy(child = newChild)
}
+/**
+ * A logical plan that evaluates a [[PythonUDTF]].
+ *
+ * @param udtf the user-defined Python function
+ * @param requiredChildOutput the required output of the child plan. It's used for omitting data
+ * generation that will be discarded next by a projection.
+ * @param resultAttrs the output schema of the Python UDTF.
+ * @param child the child plan
+ */
+case class BatchEvalPythonUDTF(
+ udtf: PythonUDTF,
+ requiredChildOutput: Seq[Attribute],
+ resultAttrs: Seq[Attribute],
+ child: LogicalPlan) extends BaseEvalPythonUDTF {
+ override protected def withNewChildInternal(newChild: LogicalPlan): BatchEvalPythonUDTF =
+ copy(child = newChild)
+}
+
/**
* A logical plan that adds a new long column with the name `name` that
* increases one by one. This is for 'distributed-sequence' default index
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
index 230d2b04c4f..11d5cf54df4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
@@ -106,6 +106,7 @@ object TreePattern extends Enumeration {
val CTE: Value = Value
val DISTINCT_LIKE: Value = Value
val EVAL_PYTHON_UDF: Value = Value
+ val EVAL_PYTHON_UDTF: Value = Value
val EVENT_TIME_WATERMARK: Value = Value
val EXCEPT: Value = Value
val FILTER: Value = Value
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 66215d05033..642006fb8dc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -225,6 +225,8 @@ class SparkSession private(
*/
def udf: UDFRegistration = sessionState.udfRegistration
+ def udtf: UDTFRegistration = sessionState.udtfRegistration
+
/**
* Returns a `StreamingQueryManager` that allows managing all the
* `StreamingQuery`s active on `this`.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDTFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDTFRegistration.scala
new file mode 100644
index 00000000000..3330597cb76
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UDTFRegistration.scala
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.annotation.Evolving
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.analysis.TableFunctionRegistry
+import org.apache.spark.sql.execution.python.UserDefinedPythonTableFunction
+
+/**
+ * Functions for registering user-defined table functions. Use `SparkSession.udtf` to access this.
+ *
+ * @since 3.5.0
+ */
+@Evolving
+class UDTFRegistration private[sql] (tableFunctionRegistry: TableFunctionRegistry)
+ extends Logging {
+
+ protected[sql] def registerPython(name: String, udtf: UserDefinedPythonTableFunction): Unit = {
+ log.debug(
+ s"""
+ | Registering new PythonUDTF:
+ | name: $name
+ | command: ${udtf.func.command}
+ | envVars: ${udtf.func.envVars}
+ | pythonIncludes: ${udtf.func.pythonIncludes}
+ | pythonExec: ${udtf.func.pythonExec}
+ | returnType: ${udtf.returnType}
+ | udfDeterministic: ${udtf.udfDeterministic}
+ """.stripMargin)
+
+ tableFunctionRegistry.createOrReplaceTempFunction(name, udtf.builder, "python_udtf")
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
index f05fe9d60fb..70a35ea9115 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.connector.catalog.CatalogManager
import org.apache.spark.sql.execution.datasources.{PruneFileSourcePartitions, SchemaPruning, V1Writes}
import org.apache.spark.sql.execution.datasources.v2.{GroupBasedRowLevelOperationScanPlanning, OptimizeMetadataOnlyDeleteFromTable, V2ScanPartitioningAndOrdering, V2ScanRelationPushDown, V2Writes}
import org.apache.spark.sql.execution.dynamicpruning.{CleanupDynamicPruningFilters, PartitionPruning, RowLevelOperationRuntimeGroupFiltering}
-import org.apache.spark.sql.execution.python.{ExtractGroupingPythonUDFFromAggregate, ExtractPythonUDFFromAggregate, ExtractPythonUDFs}
+import org.apache.spark.sql.execution.python.{ExtractGroupingPythonUDFFromAggregate, ExtractPythonUDFFromAggregate, ExtractPythonUDFs, ExtractPythonUDTFs}
class SparkOptimizer(
catalogManager: CatalogManager,
@@ -77,6 +77,7 @@ class SparkOptimizer(
// This must be executed after `ExtractPythonUDFFromAggregate` and before `ExtractPythonUDFs`.
ExtractGroupingPythonUDFFromAggregate,
ExtractPythonUDFs,
+ ExtractPythonUDTFs,
// The eval-python node may be between Project/Filter and the scan node, which breaks
// column pruning and filter push-down. Here we rerun the related optimizer rules.
ColumnPruning,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index a414b815018..a4cca1a248b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -747,6 +747,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
ArrowEvalPythonExec(udfs, output, planLater(child), evalType) :: Nil
case BatchEvalPython(udfs, output, child) =>
BatchEvalPythonExec(udfs, output, planLater(child)) :: Nil
+ case BatchEvalPythonUDTF(udtf, requiredChildOutput, resultAttrs, child) =>
+ BatchEvalPythonUDTFExec(udtf, requiredChildOutput, resultAttrs, planLater(child)) :: Nil
case _ =>
Nil
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
index ca7ca2e2f80..c8a798d5b70 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
@@ -42,39 +42,8 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute]
context: TaskContext): Iterator[InternalRow] = {
EvaluatePython.registerPicklers() // register pickler for Row
- val dataTypes = schema.map(_.dataType)
- val needConversion = dataTypes.exists(EvaluatePython.needConversionInPython)
-
- // enable memo iff we serialize the row with schema (schema and class should be memorized)
- // pyrolite 4.21+ can lookup objects in its cache by value, but `GenericRowWithSchema` objects,
- // that we pass from JVM to Python, don't define their `equals()` to take the type of the
- // values or the schema of the row into account. This causes like
- // `GenericRowWithSchema(Array(1.0, 1.0),
- // StructType(Seq(StructField("_1", DoubleType), StructField("_2", DoubleType))))`
- // and
- // `GenericRowWithSchema(Array(1, 1),
- // StructType(Seq(StructField("_1", IntegerType), StructField("_2", IntegerType))))`
- // to be `equal()` and so we need to disable this feature explicitly (`valueCompare=false`).
- // Please note that cache by reference is still enabled depending on `needConversion`.
- val pickle = new Pickler(/* useMemo = */ needConversion,
- /* valueCompare = */ false)
- // Input iterator to Python: input rows are grouped so we send them in batches to Python.
- // For each row, add it to the queue.
- val inputIterator = iter.map { row =>
- if (needConversion) {
- EvaluatePython.toJava(row, schema)
- } else {
- // fast path for these types that does not need conversion in Python
- val fields = new Array[Any](row.numFields)
- var i = 0
- while (i < row.numFields) {
- val dt = dataTypes(i)
- fields(i) = EvaluatePython.toJava(row.get(i, dt), dt)
- i += 1
- }
- fields
- }
- }.grouped(100).map(x => pickle.dumps(x.toArray))
+ // Input iterator to Python.
+ val inputIterator = BatchEvalPythonExec.getInputIterator(iter, schema)
// Output iterator for results from Python.
val outputIterator =
@@ -109,3 +78,43 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute]
override protected def withNewChildInternal(newChild: SparkPlan): BatchEvalPythonExec =
copy(child = newChild)
}
+
+object BatchEvalPythonExec {
+ def getInputIterator(
+ iter: Iterator[InternalRow],
+ schema: StructType): Iterator[Array[Byte]] = {
+ val dataTypes = schema.map(_.dataType)
+ val needConversion = dataTypes.exists(EvaluatePython.needConversionInPython)
+
+ // enable memo iff we serialize the row with schema (schema and class should be memoized)
+ // pyrolite 4.21+ can look up objects in its cache by value, but the `GenericRowWithSchema`
+ // objects that we pass from the JVM to Python don't define their `equals()` to take the
+ // types of the values or the schema of the row into account. This causes rows like
+ // `GenericRowWithSchema(Array(1.0, 1.0),
+ // StructType(Seq(StructField("_1", DoubleType), StructField("_2", DoubleType))))`
+ // and
+ // `GenericRowWithSchema(Array(1, 1),
+ // StructType(Seq(StructField("_1", IntegerType), StructField("_2", IntegerType))))`
+ // to compare as `equal()`, so we need to disable this feature explicitly (`valueCompare=false`).
+ // Note that caching by reference is still enabled, depending on `needConversion`.
+ val pickle = new Pickler(/* useMemo = */ needConversion,
+ /* valueCompare = */ false)
+ // Input iterator to Python: input rows are grouped so we can send them to Python in batches.
+ iter.map { row =>
+ if (needConversion) {
+ EvaluatePython.toJava(row, schema)
+ } else {
+ // fast path for types that do not need conversion in Python
+ val fields = new Array[Any](row.numFields)
+ var i = 0
+ while (i < row.numFields) {
+ val dt = dataTypes(i)
+ fields(i) = EvaluatePython.toJava(row.get(i, dt), dt)
+ i += 1
+ }
+ fields
+ }
+ }.grouped(100).map(x => pickle.dumps(x.toArray))
+ }
+}
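
The extracted `getInputIterator` keeps the existing behavior: rows are converted to Java objects (or a plain field array on the fast path), grouped into batches of 100, and pickled one batch at a time. A minimal Python sketch of the same batching idea, using the standard `pickle` module as a stand-in for Pyrolite on the JVM side:

```python
import pickle
from itertools import islice

def batched_pickles(rows, batch_size=100):
    """Group rows into fixed-size batches and pickle each batch,
    mirroring the grouped(100) + pickle.dumps flow in getInputIterator."""
    it = iter(rows)
    while True:
        batch = list(islice(it, batch_size))
        if not batch:
            return
        yield pickle.dumps(batch)

# Each element of `payloads` is one serialized batch of up to 100 rows.
payloads = list(batched_pickles(({"a": i} for i in range(250))))
assert len(payloads) == 3  # 100 + 100 + 50
```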
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala
new file mode 100644
index 00000000000..a7fdfb9d173
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala
@@ -0,0 +1,175 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.python
+
+import java.io.File
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
+
+import net.razorvine.pickle.Unpickler
+
+import org.apache.spark.{ContextAwareIterator, SparkEnv, TaskContext}
+import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.util.GenericArrayData
+import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
+import org.apache.spark.sql.types.{DataType, StructField, StructType}
+import org.apache.spark.util.Utils
+
+/**
+ * A physical plan that evaluates a [[PythonUDTF]]. This is similar to [[BatchEvalPythonExec]].
+ *
+ * @param udtf the user-defined Python function
+ * @param requiredChildOutput the required output of the child plan. It is used to avoid
+ * generating data that a subsequent projection would discard.
+ * @param resultAttrs the output schema of the Python UDTF.
+ * @param child the child plan
+ */
+case class BatchEvalPythonUDTFExec(
+ udtf: PythonUDTF,
+ requiredChildOutput: Seq[Attribute],
+ resultAttrs: Seq[Attribute],
+ child: SparkPlan)
+ extends UnaryExecNode with PythonSQLMetrics {
+
+ override def output: Seq[Attribute] = requiredChildOutput ++ resultAttrs
+
+ override def producedAttributes: AttributeSet = AttributeSet(resultAttrs)
+
+ protected override def doExecute(): RDD[InternalRow] = {
+ val inputRDD = child.execute().map(_.copy())
+
+ inputRDD.mapPartitions { iter =>
+ val context = TaskContext.get()
+ val contextAwareIterator = new ContextAwareIterator(context, iter)
+
+ // The queue used to buffer input rows so we can drain it later to
+ // combine the input with the output from Python.
+ val queue = HybridRowQueue(context.taskMemoryManager(),
+ new File(Utils.getLocalDir(SparkEnv.get.conf)), child.output.length)
+ context.addTaskCompletionListener[Unit] { ctx =>
+ queue.close()
+ }
+
+ val inputs = Seq(udtf.children)
+
+ // flatten all the arguments
+ val allInputs = new ArrayBuffer[Expression]
+ val dataTypes = new ArrayBuffer[DataType]
+ val argOffsets = inputs.map { input =>
+ input.map { e =>
+ if (allInputs.exists(_.semanticEquals(e))) {
+ allInputs.indexWhere(_.semanticEquals(e))
+ } else {
+ allInputs += e
+ dataTypes += e.dataType
+ allInputs.length - 1
+ }
+ }.toArray
+ }.toArray
+ val projection = MutableProjection.create(allInputs.toSeq, child.output)
+ projection.initialize(context.partitionId())
+ val schema = StructType(dataTypes.zipWithIndex.map { case (dt, i) =>
+ StructField(s"_$i", dt)
+ }.toArray)
+
+ // Add rows to the queue to join later with the result.
+ // Also keep track of the number of rows added to the queue.
+ // This is needed to process extra output rows from the `terminate()` call of the UDTF.
+ var count = 0L
+ val projectedRowIter = contextAwareIterator.map { inputRow =>
+ queue.add(inputRow.asInstanceOf[UnsafeRow])
+ count += 1
+ projection(inputRow)
+ }
+
+ val outputRowIterator = evaluate(udtf, argOffsets, projectedRowIter, schema, context)
+
+ val pruneChildForResult: InternalRow => InternalRow =
+ if (child.outputSet == AttributeSet(requiredChildOutput)) {
+ identity
+ } else {
+ UnsafeProjection.create(requiredChildOutput, child.output)
+ }
+
+ val joined = new JoinedRow
+ val resultProj = UnsafeProjection.create(output, output)
+
+ outputRowIterator.flatMap { outputRows =>
+ // If `count` is greater than zero, it means there are remaining input rows in the queue.
+ // In this case, the output rows of the UDTF are joined with the corresponding input row
+ // in the queue.
+ if (count > 0) {
+ val left = queue.remove()
+ count -= 1
+ joined.withLeft(pruneChildForResult(left))
+ }
+ // If `count` is zero, it means all input rows have been consumed. Any additional rows
+ // from the UDTF come from the `terminate()` call. We keep the last input row on the
+ // left side to stay consistent with the Generate implementation and Hive UDTFs.
+ outputRows.map(r => resultProj(joined.withRight(r)))
+ }
+ }
+ }
+
+ /**
+ * Evaluates a Python UDTF. It computes the results using a PythonUDFRunner and returns
+ * an iterator of internal rows for every input row.
+ */
+ private def evaluate(
+ udtf: PythonUDTF,
+ argOffsets: Array[Array[Int]],
+ iter: Iterator[InternalRow],
+ schema: StructType,
+ context: TaskContext): Iterator[Iterator[InternalRow]] = {
+ EvaluatePython.registerPicklers() // register pickler for Row
+
+ // Input iterator to Python.
+ val inputIterator = BatchEvalPythonExec.getInputIterator(iter, schema)
+
+ // Output iterator for results from Python.
+ val funcs = Seq(ChainedPythonFunctions(Seq(udtf.func)))
+ val outputIterator =
+ new PythonUDFRunner(funcs, PythonEvalType.SQL_TABLE_UDF, argOffsets, pythonMetrics)
+ .compute(inputIterator, context.partitionId(), context)
+
+ val unpickle = new Unpickler
+
+ // The return type of a UDTF is an array of struct.
+ val resultType = udtf.dataType
+ val fromJava = EvaluatePython.makeFromJava(resultType)
+
+ outputIterator.flatMap { pickedResult =>
+ val unpickledBatch = unpickle.loads(pickedResult)
+ unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala
+ }.map { results =>
+ assert(results.getClass.isArray)
+ val res = results.asInstanceOf[Array[_]]
+ pythonMetrics("pythonNumRowsReceived") += res.length
+ fromJava(results).asInstanceOf[GenericArrayData]
+ .array.map(_.asInstanceOf[InternalRow]).toIterator
+ }
+ }
+
+ override protected def withNewChildInternal(newChild: SparkPlan): BatchEvalPythonUDTFExec =
+ copy(child = newChild)
+}
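
The queue-based join above can be reasoned about as follows: each batch of UDTF output rows is paired with the input row that produced it, while rows emitted after the input is exhausted (i.e. from `terminate()`) reuse the last input row on the left. A rough Python model of that pairing logic — names here are illustrative, not Spark APIs:

```python
def join_udtf_output(input_rows, eval_fn, terminate_fn=None):
    """Toy model of BatchEvalPythonUDTFExec's queue-based join: each
    output row is paired with the input row that produced it, and rows
    emitted after the input is exhausted (terminate()) reuse the last
    input row, matching the Generate/Hive UDTF convention."""
    last = None
    for row in input_rows:
        last = row
        for out in eval_fn(row):
            yield row + out
    if terminate_fn is not None:
        for out in terminate_fn():
            yield (last if last is not None else ()) + out

# eval yields two rows per input; terminate appends one more row,
# which is joined with the last input row, (2,).
rows = list(join_udtf_output(
    [(1,), (2,)],
    lambda r: [(r[0] * 10,), (r[0] * 100,)],
    lambda: [(0,)]))
# [(1, 10), (1, 100), (2, 20), (2, 200), (2, 0)]
```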
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
index 57c3e1ad88e..f6993eaf6bd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
@@ -305,3 +305,19 @@ object ExtractPythonUDFs extends Rule[LogicalPlan] {
}
}
}
+
+/**
+ * Extracts PythonUDTFs from operators, rewriting the query plan so that UDTFs can be evaluated.
+ */
+object ExtractPythonUDTFs extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan match {
+ // A correlated subquery will be rewritten into a join later, and will go through this rule
+ // eventually. Here we skip subqueries, as Python UDTFs only need to be extracted once.
+ case s: Subquery if s.correlated => plan
+
+ case _ => plan.transformUpWithPruning(_.containsPattern(GENERATE)) {
+ case g @ Generate(func: PythonUDTF, _, _, _, _, child) =>
+ BatchEvalPythonUDTF(func, g.requiredChildOutput, g.generatorOutput, child)
+ }
+ }
+}
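
The rule is a straightforward bottom-up rewrite: every `Generate` whose generator is a `PythonUDTF` becomes a `BatchEvalPythonUDTF` node. A toy Python sketch of the same transform-up pattern on a minimal plan tree — all class names below are illustrative stand-ins, not Spark classes:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class Plan:
    name: str
    generator: Optional[str] = None
    child: Optional["Plan"] = None

def transform_up(plan: Plan, rule) -> Plan:
    # Rewrite children first, then the node itself (bottom-up), like
    # transformUpWithPruning with the pruning condition omitted.
    child = transform_up(plan.child, rule) if plan.child else None
    return rule(Plan(plan.name, plan.generator, child))

def extract_udtf(node: Plan) -> Plan:
    # Generate(PythonUDTF) becomes a dedicated evaluation node.
    if node.name == "Generate" and node.generator == "PythonUDTF":
        return Plan("BatchEvalPythonUDTF", child=node.child)
    return node

plan = Plan("Project", child=Plan("Generate", "PythonUDTF", Plan("Scan")))
print(transform_up(plan, extract_udtf))
# Plan(name='Project', ..., child=Plan(name='BatchEvalPythonUDTF', ...))
```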
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
index bc76eaed04b..08c7c8730d3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
@@ -18,9 +18,10 @@
package org.apache.spark.sql.execution.python
import org.apache.spark.api.python.{PythonEvalType, PythonFunction}
-import org.apache.spark.sql.Column
-import org.apache.spark.sql.catalyst.expressions.{Expression, PythonUDAF, PythonUDF}
-import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.{Column, DataFrame, Dataset, SparkSession}
+import org.apache.spark.sql.catalyst.expressions.{Expression, PythonUDAF, PythonUDF, PythonUDTF}
+import org.apache.spark.sql.catalyst.plans.logical.{Generate, LogicalPlan, OneRowRelation}
+import org.apache.spark.sql.types.{DataType, StructType}
/**
* A user-defined Python function. This is used by the Python API.
@@ -55,3 +56,36 @@ case class UserDefinedPythonFunction(
}
}
}
+
+/**
+ * A user-defined Python table function. This is used by the Python API.
+ */
+case class UserDefinedPythonTableFunction(
+ name: String,
+ func: PythonFunction,
+ returnType: StructType,
+ udfDeterministic: Boolean) {
+
+ def builder(e: Seq[Expression]): LogicalPlan = {
+ val udtf = PythonUDTF(
+ name = name,
+ func = func,
+ elementSchema = returnType,
+ children = e,
+ udfDeterministic = udfDeterministic)
+ Generate(
+ udtf,
+ unrequiredChildIndex = Nil,
+ outer = false,
+ qualifier = None,
+ generatorOutput = Nil,
+ child = OneRowRelation()
+ )
+ }
+
+ /** Returns a [[DataFrame]] that evaluates this UDTF with the given input arguments. */
+ def apply(session: SparkSession, exprs: Column*): DataFrame = {
+ val udtf = builder(exprs.map(_.expr))
+ Dataset.ofRows(session, udtf)
+ }
+}
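
This `apply` is what backs calling a UDTF object with `Column` arguments from the Python API: the generator is wrapped in a `Generate` node over `OneRowRelation` and returned as a DataFrame. A hedged usage sketch, assuming the Python-side call mirrors the Scala `apply` (the UDTF class and names below are examples, not part of this commit):

```python
from pyspark.sql.functions import lit, udtf

# Illustrative UDTF taking an argument.
class SquareUDTF:
    def eval(self, x: int):
        yield x, x * x

square_udtf = udtf(SquareUDTF, returnType="x: int, sq: int")

# Calling with Column arguments goes through
# UserDefinedPythonTableFunction.apply; expect a single row: x=3, sq=9.
square_udtf(lit(3)).show()
```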
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
index 460e5b68ff8..bc4b308f75d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
@@ -175,6 +175,8 @@ abstract class BaseSessionStateBuilder(
*/
protected def udfRegistration: UDFRegistration = new UDFRegistration(functionRegistry)
+ protected def udtfRegistration: UDTFRegistration = new UDTFRegistration(tableFunctionRegistry)
+
/**
* Logical query plan analyzer for resolving unresolved attributes and relations.
*
@@ -365,6 +367,7 @@ abstract class BaseSessionStateBuilder(
functionRegistry,
tableFunctionRegistry,
udfRegistration,
+ udtfRegistration,
() => catalog,
sqlParser,
() => analyzer,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
index eb0b71d155b..177a25b45fc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
@@ -46,6 +46,8 @@ import org.apache.spark.util.{DependencyUtils, Utils}
* @param experimentalMethods Interface to add custom planning strategies and optimizers.
* @param functionRegistry Internal catalog for managing functions registered by the user.
* @param udfRegistration Interface exposed to the user for registering user-defined functions.
+ * @param udtfRegistration Interface exposed to the user for registering user-defined
+ * table functions.
* @param catalogBuilder a function to create an internal catalog for managing table and database
* states.
* @param sqlParser Parser that extracts expressions, plans, table identifiers etc. from SQL texts.
@@ -69,6 +71,7 @@ private[sql] class SessionState(
val functionRegistry: FunctionRegistry,
val tableFunctionRegistry: TableFunctionRegistry,
val udfRegistration: UDFRegistration,
+ val udtfRegistration: UDTFRegistration,
catalogBuilder: () => SessionCatalog,
val sqlParser: ParserInterface,
analyzerBuilder: () => Analyzer,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala
index c95b5f7c27f..f3d8e883f7f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala
@@ -30,9 +30,9 @@ import org.apache.spark.api.python.{PythonBroadcast, PythonEvalType, PythonFunct
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, ExprId, PythonUDF}
import org.apache.spark.sql.catalyst.plans.SQLHelper
-import org.apache.spark.sql.execution.python.UserDefinedPythonFunction
+import org.apache.spark.sql.execution.python.{UserDefinedPythonFunction, UserDefinedPythonTableFunction}
import org.apache.spark.sql.expressions.SparkUserDefinedFunction
-import org.apache.spark.sql.types.{DataType, IntegerType, NullType, StringType}
+import org.apache.spark.sql.types.{DataType, IntegerType, NullType, StringType, StructType}
/**
* This object aims to integrate various UDF test cases so that Scalar UDF, Python UDF,
@@ -191,6 +191,32 @@ object IntegratedUDFTestUtils extends SQLHelper {
throw new RuntimeException(s"Python executable [$pythonExec] and/or pyspark are unavailable.")
}
+ private def createPythonUDTF(funcName: String, pythonScript: String): Array[Byte] = {
+ if (!shouldTestPythonUDFs) {
+ throw new RuntimeException(s"Python executable [$pythonExec] and/or pyspark are unavailable.")
+ }
+ var binaryPythonUDTF: Array[Byte] = null
+ withTempPath { codePath =>
+ Files.write(codePath.toPath, pythonScript.getBytes(StandardCharsets.UTF_8))
+ withTempPath { path =>
+ Process(
+ Seq(
+ pythonExec,
+ "-c",
+ "from pyspark.serializers import CloudPickleSerializer; " +
+ s"f = open('$path', 'wb');" +
+ s"exec(open('$codePath', 'r').read());" +
+ "f.write(CloudPickleSerializer().dumps(" +
+ s"($funcName, returnType)))"),
+ None,
+ "PYTHONPATH" -> s"$pysparkPythonPath:$pythonPath").!!
+ binaryPythonUDTF = Files.readAllBytes(path.toPath)
+ }
+ }
+ assert(binaryPythonUDTF != null)
+ binaryPythonUDTF
+ }
+
private lazy val pandasFunc: Array[Byte] = if (shouldTestPandasUDFs) {
var binaryPandasFunc: Array[Byte] = null
withTempPath { path =>
@@ -360,6 +386,25 @@ object IntegratedUDFTestUtils extends SQLHelper {
val prettyName: String = "Regular Python UDF"
}
+ def createUserDefinedPythonTableFunction(
+ name: String,
+ pythonScript: String,
+ returnType: StructType,
+ deterministic: Boolean = true): UserDefinedPythonTableFunction = {
+ UserDefinedPythonTableFunction(
+ name = name,
+ func = SimplePythonFunction(
+ command = createPythonUDTF(name, pythonScript),
+ envVars = workerEnv.clone().asInstanceOf[java.util.Map[String, String]],
+ pythonIncludes = List.empty[String].asJava,
+ pythonExec = pythonExec,
+ pythonVer = pythonVer,
+ broadcastVars = List.empty[Broadcast[PythonBroadcast]].asJava,
+ accumulator = null),
+ returnType = returnType,
+ udfDeterministic = deterministic)
+ }
+
/**
* A Scalar Pandas UDF that takes one column, casts into string, executes the
* Python native function, and casts back to the type of input column.
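
The `createPythonUDTF` test helper above shells out to a Python child process that cloudpickles the `(function, returnType)` pair for the JVM to read back from disk. A minimal sketch of what that child process does, assuming pyspark is importable (the file path is illustrative):

```python
from pyspark.serializers import CloudPickleSerializer
from pyspark.sql.types import IntegerType, StructField, StructType

returnType = StructType([StructField("a", IntegerType())])

class SimpleUDTF:
    def eval(self, a: int):
        yield (a,)

# Serialize the (callable, schema) pair, as the helper's inline
# "-c" script does before the JVM reads the bytes back from disk.
payload = CloudPickleSerializer().dumps((SimpleUDTF, returnType))
with open("/tmp/udtf.bin", "wb") as f:  # illustrative path
    f.write(payload)
```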
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDTFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDTFSuite.scala
new file mode 100644
index 00000000000..fa4d80c331a
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDTFSuite.scala
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.python
+
+import org.apache.spark.sql.{IntegratedUDFTestUtils, QueryTest, Row}
+import org.apache.spark.sql.functions.lit
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.StructType
+
+class PythonUDTFSuite extends QueryTest with SharedSparkSession {
+
+ import testImplicits._
+
+ import IntegratedUDFTestUtils._
+
+ private val pythonScript: String =
+ """
+ |from pyspark.sql.types import StructType, StructField, IntegerType
+ |returnType = StructType([
+ | StructField("a", IntegerType()),
+ | StructField("b", IntegerType()),
+ | StructField("c", IntegerType()),
+ |])
+ |class SimpleUDTF:
+ | def eval(self, a: int, b: int):
+ | yield a, b, a + b
+ | yield a, b, a - b
+ | yield a, b, b - a
+ |""".stripMargin
+
+ private val returnType: StructType = StructType.fromDDL("a int, b int, c int")
+
+ private val pythonUDTF: UserDefinedPythonTableFunction =
+ createUserDefinedPythonTableFunction("SimpleUDTF", pythonScript, returnType)
+
+ test("Simple PythonUDTF") {
+ assume(shouldTestPythonUDFs)
+ val df = pythonUDTF(spark, lit(1), lit(2))
+ checkAnswer(df, Seq(Row(1, 2, -1), Row(1, 2, 1), Row(1, 2, 3)))
+ }
+
+ test("PythonUDTF with lateral join") {
+ assume(shouldTestPythonUDFs)
+ withTempView("t") {
+ spark.udtf.registerPython("testUDTF", pythonUDTF)
+ Seq((0, 1), (1, 2)).toDF("a", "b").createOrReplaceTempView("t")
+ checkAnswer(
+ sql("SELECT f.* FROM t, LATERAL testUDTF(a, b) f"),
+ sql("SELECT * FROM t, LATERAL explode(array(a + b, a - b, b - a)) t(c)"))
+ }
+ }
+
+ test("PythonUDTF in correlated subquery") {
+ assume(shouldTestPythonUDFs)
+ withTempView("t") {
+ spark.udtf.registerPython("testUDTF", pythonUDTF)
+ Seq((0, 1), (1, 2)).toDF("a", "b").createOrReplaceTempView("t")
+ checkAnswer(
+ sql("SELECT (SELECT sum(f.b) AS r FROM testUDTF(1, 2) f WHERE f.a = t.a) FROM t"),
+ Seq(Row(6), Row(null)))
+ }
+ }
+}