Posted to commits@spark.apache.org by gu...@apache.org on 2023/06/13 02:17:39 UTC

[spark] branch master updated: [SPARK-43798][SQL][PYTHON] Support Python user-defined table functions

This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 37ab190dc5b [SPARK-43798][SQL][PYTHON] Support Python user-defined table functions
37ab190dc5b is described below

commit 37ab190dc5bfa59b4e06af9551c35ab179a05733
Author: allisonwang-db <al...@databricks.com>
AuthorDate: Tue Jun 13 11:17:10 2023 +0900

    [SPARK-43798][SQL][PYTHON] Support Python user-defined table functions
    
    ### What changes were proposed in this pull request?
    
    This PR adds initial support for Python user-defined table functions (UDTFs). It allows users to create UDTFs in PySpark and invoke them from both the DataFrame API and SQL.
    
    Here are examples of creating and using Python UDTFs:
    ```python
    # Implement the UDTF class
    class TestUDTF:
      def __init__(self):
        ...
    
      def eval(self, *args):
        yield "hello", "world"
    
      def terminate(self):
        ...
    
    # Create the UDTF
    from pyspark.sql.functions import udtf
    
    test_udtf = udtf(TestUDTF, returnType="c1: string, c2: string")
    
    # Invoke the UDTF
    test_udtf().show()
    +-----+-----+
    |   c1|   c2|
    +-----+-----+
    |hello|world|
    +-----+-----+
    
    # Register the UDTF
    spark.udtf.register(name="test_udtf", f=test_udtf)
    
    # Invoke the UDTF in SQL
    spark.sql("SELECT * FROM test_udtf()").show()
    +-----+-----+
    |   c1|   c2|
    +-----+-----+
    |hello|world|
    +-----+-----+
    ```
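
    The new `udtf()` API also supports decorator syntax. A minimal sketch, following the docstring added in this PR:
    ```python
    from pyspark.sql.functions import lit, udtf

    # Define the UDTF handler class with the decorator
    @udtf(returnType="c1: int, c2: int")
    class PlusOne:
        def eval(self, x: int):
            yield x, x + 1

    # Invoke the UDTF with a column argument
    PlusOne(lit(1)).show()
    +---+---+
    | c1| c2|
    +---+---+
    |  1|  2|
    +---+---+
    ```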
    
    Note that this is the initial PR; subsequent follow-up work will make UDTFs more user-friendly and performant.
    
    ### Why are the changes needed?
    
    To support another type of user-defined function in PySpark.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. After this PR, users can create Python user-defined table functions.
    
    ### How was this patch tested?
    
    New unit tests.
    
    Closes #41316 from allisonwang-db/spark-43798-py-udtf.
    
    Authored-by: allisonwang-db <al...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
 .../CheckConnectJvmClientCompatibility.scala       |   1 +
 .../org/apache/spark/api/python/PythonRunner.scala |   3 +
 dev/sparktestsupport/modules.py                    |   2 +
 .../source/reference/pyspark.sql/core_classes.rst  |   2 +
 .../source/reference/pyspark.sql/functions.rst     |   1 +
 python/docs/source/reference/pyspark.sql/index.rst |   1 +
 .../source/reference/pyspark.sql/spark_session.rst |   1 +
 .../pyspark.sql/{core_classes.rst => udtf.rst}     |  26 +-
 python/pyspark/rdd.py                              |  10 +-
 python/pyspark/sql/__init__.py                     |   3 +-
 python/pyspark/sql/_typing.pyi                     |   1 +
 python/pyspark/sql/connect/session.py              |   2 +-
 python/pyspark/sql/context.py                      |  13 +
 python/pyspark/sql/functions.py                    |  96 +++++
 python/pyspark/sql/session.py                      |  15 +
 .../sql/tests/connect/test_connect_basic.py        |  11 +
 python/pyspark/sql/tests/test_udtf.py              | 413 +++++++++++++++++++++
 python/pyspark/sql/udtf.py                         | 233 ++++++++++++
 python/pyspark/worker.py                           |  84 ++++-
 .../sql/catalyst/expressions/ExpressionInfo.java   |   3 +-
 .../spark/sql/catalyst/expressions/PythonUDF.scala |  39 +-
 .../sql/catalyst/plans/logical/PlanHelper.scala    |   3 +-
 .../plans/logical/pythonLogicalOperators.scala     |  35 +-
 .../spark/sql/catalyst/trees/TreePatterns.scala    |   1 +
 .../scala/org/apache/spark/sql/SparkSession.scala  |   2 +
 .../org/apache/spark/sql/UDTFRegistration.scala    |  49 +++
 .../spark/sql/execution/SparkOptimizer.scala       |   3 +-
 .../spark/sql/execution/SparkStrategies.scala      |   2 +
 .../sql/execution/python/BatchEvalPythonExec.scala |  75 ++--
 .../execution/python/BatchEvalPythonUDTFExec.scala | 175 +++++++++
 .../sql/execution/python/ExtractPythonUDFs.scala   |  16 +
 .../python/UserDefinedPythonFunction.scala         |  40 +-
 .../sql/internal/BaseSessionStateBuilder.scala     |   3 +
 .../apache/spark/sql/internal/SessionState.scala   |   3 +
 .../apache/spark/sql/IntegratedUDFTestUtils.scala  |  49 ++-
 .../sql/execution/python/PythonUDTFSuite.scala     |  78 ++++
 36 files changed, 1427 insertions(+), 67 deletions(-)

diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
index 5e3401ebf50..7a9a889706d 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
@@ -225,6 +225,7 @@ object CheckConnectJvmClientCompatibility {
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.listenerManager"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.experimental"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.udf"),
+      ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.udtf"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.streams"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.createDataFrame"),
       ProblemFilters.exclude[Problem](
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
index 912e76005f0..e5c42c721fe 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -56,6 +56,8 @@ private[spark] object PythonEvalType {
   val SQL_MAP_ARROW_ITER_UDF = 207
   val SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE = 208
 
+  val SQL_TABLE_UDF = 300
+
   def toString(pythonEvalType: Int): String = pythonEvalType match {
     case NON_UDF => "NON_UDF"
     case SQL_BATCHED_UDF => "SQL_BATCHED_UDF"
@@ -69,6 +71,7 @@ private[spark] object PythonEvalType {
     case SQL_COGROUPED_MAP_PANDAS_UDF => "SQL_COGROUPED_MAP_PANDAS_UDF"
     case SQL_MAP_ARROW_ITER_UDF => "SQL_MAP_ARROW_ITER_UDF"
     case SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE => "SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE"
+    case SQL_TABLE_UDF => "SQL_TABLE_UDF"
   }
 }
 
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 92d4e2b7fbf..82293cbabf0 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -456,6 +456,7 @@ pyspark_sql = Module(
         "pyspark.sql.streaming.readwriter",
         "pyspark.sql.streaming.listener",
         "pyspark.sql.udf",
+        "pyspark.sql.udtf",
         "pyspark.sql.window",
         "pyspark.sql.avro.functions",
         "pyspark.sql.protobuf.functions",
@@ -501,6 +502,7 @@ pyspark_sql = Module(
         "pyspark.sql.tests.test_types",
         "pyspark.sql.tests.test_udf",
         "pyspark.sql.tests.test_udf_profiler",
+        "pyspark.sql.tests.test_udtf",
         "pyspark.sql.tests.test_utils",
     ],
 )
diff --git a/python/docs/source/reference/pyspark.sql/core_classes.rst b/python/docs/source/reference/pyspark.sql/core_classes.rst
index 90c5c412797..3cf19686cdd 100644
--- a/python/docs/source/reference/pyspark.sql/core_classes.rst
+++ b/python/docs/source/reference/pyspark.sql/core_classes.rst
@@ -39,4 +39,6 @@ Core Classes
     DataFrameWriter
     DataFrameWriterV2
     UDFRegistration
+    UDTFRegistration
     udf.UserDefinedFunction
+    udtf.UserDefinedTableFunction
diff --git a/python/docs/source/reference/pyspark.sql/functions.rst b/python/docs/source/reference/pyspark.sql/functions.rst
index 186ee0dce8d..fb39ded16c6 100644
--- a/python/docs/source/reference/pyspark.sql/functions.rst
+++ b/python/docs/source/reference/pyspark.sql/functions.rst
@@ -374,6 +374,7 @@ UDF
     call_udf
     pandas_udf
     udf
+    udtf
     unwrap_udt
 
 Misc Functions
diff --git a/python/docs/source/reference/pyspark.sql/index.rst b/python/docs/source/reference/pyspark.sql/index.rst
index fc4569486a7..233c8b238a6 100644
--- a/python/docs/source/reference/pyspark.sql/index.rst
+++ b/python/docs/source/reference/pyspark.sql/index.rst
@@ -40,4 +40,5 @@ This page gives an overview of all public Spark SQL API.
     avro
     observation
     udf
+    udtf
     protobuf
diff --git a/python/docs/source/reference/pyspark.sql/spark_session.rst b/python/docs/source/reference/pyspark.sql/spark_session.rst
index 15724306d75..edd5e746161 100644
--- a/python/docs/source/reference/pyspark.sql/spark_session.rst
+++ b/python/docs/source/reference/pyspark.sql/spark_session.rst
@@ -51,4 +51,5 @@ See also :class:`SparkSession`.
     SparkSession.streams
     SparkSession.table
     SparkSession.udf
+    SparkSession.udtf
     SparkSession.version
diff --git a/python/docs/source/reference/pyspark.sql/core_classes.rst b/python/docs/source/reference/pyspark.sql/udtf.rst
similarity index 72%
copy from python/docs/source/reference/pyspark.sql/core_classes.rst
copy to python/docs/source/reference/pyspark.sql/udtf.rst
index 90c5c412797..c251e101bbb 100644
--- a/python/docs/source/reference/pyspark.sql/core_classes.rst
+++ b/python/docs/source/reference/pyspark.sql/udtf.rst
@@ -16,27 +16,15 @@
     under the License.
 
 
-============
-Core Classes
-============
+====
+UDTF
+====
+
 .. currentmodule:: pyspark.sql
 
 .. autosummary::
     :toctree: api/
 
-    SparkSession
-    Catalog
-    DataFrame
-    Column
-    Observation
-    Row
-    GroupedData
-    PandasCogroupedOps
-    DataFrameNaFunctions
-    DataFrameStatFunctions
-    Window
-    DataFrameReader
-    DataFrameWriter
-    DataFrameWriterV2
-    UDFRegistration
-    udf.UserDefinedFunction
+    udtf.UserDefinedTableFunction.asNondeterministic
+    udtf.UserDefinedTableFunction.returnType
+    UDTFRegistration.register
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index e6ef7f6108e..37c06561f72 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -110,7 +110,13 @@ if TYPE_CHECKING:
     )
     from pyspark.sql.dataframe import DataFrame
     from pyspark.sql.types import AtomicType, StructType
-    from pyspark.sql._typing import AtomicValue, RowLike, SQLArrowBatchedUDFType, SQLBatchedUDFType
+    from pyspark.sql._typing import (
+        AtomicValue,
+        RowLike,
+        SQLArrowBatchedUDFType,
+        SQLBatchedUDFType,
+        SQLTableUDFType,
+    )
 
     from py4j.java_gateway import JavaObject
     from py4j.java_collections import JavaArray
@@ -152,6 +158,8 @@ class PythonEvalType:
     SQL_MAP_ARROW_ITER_UDF: "ArrowMapIterUDFType" = 207
     SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE: "PandasGroupedMapUDFWithStateType" = 208
 
+    SQL_TABLE_UDF: "SQLTableUDFType" = 300
+
 
 def portable_hash(x: Hashable) -> int:
     """
diff --git a/python/pyspark/sql/__init__.py b/python/pyspark/sql/__init__.py
index 22149e8adb8..d0d69488fa4 100644
--- a/python/pyspark/sql/__init__.py
+++ b/python/pyspark/sql/__init__.py
@@ -40,7 +40,7 @@ Important classes of Spark SQL and DataFrames:
       For working with window functions.
 """
 from pyspark.sql.types import Row
-from pyspark.sql.context import SQLContext, HiveContext, UDFRegistration
+from pyspark.sql.context import SQLContext, HiveContext, UDFRegistration, UDTFRegistration
 from pyspark.sql.session import SparkSession
 from pyspark.sql.column import Column
 from pyspark.sql.catalog import Catalog
@@ -57,6 +57,7 @@ __all__ = [
     "SQLContext",
     "HiveContext",
     "UDFRegistration",
+    "UDTFRegistration",
     "DataFrame",
     "GroupedData",
     "Column",
diff --git a/python/pyspark/sql/_typing.pyi b/python/pyspark/sql/_typing.pyi
index aafd916dbb7..7b5c206c98f 100644
--- a/python/pyspark/sql/_typing.pyi
+++ b/python/pyspark/sql/_typing.pyi
@@ -58,6 +58,7 @@ RowLike = TypeVar("RowLike", List[Any], Tuple[Any, ...], pyspark.sql.types.Row)
 
 SQLBatchedUDFType = Literal[100]
 SQLArrowBatchedUDFType = Literal[101]
+SQLTableUDFType = Literal[300]
 
 class SupportsOpen(Protocol):
     def open(self, partition_id: int, epoch_id: int) -> bool: ...
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index 2b35ca3d7ea..ba0d8caaeca 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -592,7 +592,7 @@ class SparkSession:
             raise PySparkAttributeError(
                 error_class="JVM_ATTRIBUTE_NOT_SUPPORTED", message_parameters={"attr_name": name}
             )
-        elif name in ["newSession", "sparkContext"]:
+        elif name in ["newSession", "sparkContext", "udtf"]:
             raise PySparkNotImplementedError(
                 error_class="NOT_IMPLEMENTED", message_parameters={"feature": f"{name}()"}
             )
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index 99f97977ccc..817c3b97337 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -41,6 +41,7 @@ from pyspark.sql.dataframe import DataFrame
 from pyspark.sql.readwriter import DataFrameReader
 from pyspark.sql.streaming import DataStreamReader
 from pyspark.sql.udf import UDFRegistration  # noqa: F401
+from pyspark.sql.udtf import UDTFRegistration
 from pyspark.errors.exceptions.captured import install_exception_handler
 from pyspark.context import SparkContext
 from pyspark.rdd import RDD
@@ -228,6 +229,18 @@ class SQLContext:
         """
         return self.sparkSession.udf
 
+    @property
+    def udtf(self) -> UDTFRegistration:
+        """Returns a :class:`UDTFRegistration` for UDTF registration.
+
+        .. versionadded:: 3.5.0
+
+        Returns
+        -------
+        :class:`UDTFRegistration`
+        """
+        return self.sparkSession.udtf
+
     def range(
         self,
         start: int,
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 6b03025b614..54b7b312bf5 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -32,6 +32,7 @@ from typing import (
     overload,
     Optional,
     Tuple,
+    Type,
     TYPE_CHECKING,
     Union,
     ValuesView,
@@ -47,6 +48,7 @@ from pyspark.sql.types import ArrayType, DataType, StringType, StructType, _from
 
 # Keep UserDefinedFunction import for backwards compatible import; moved in SPARK-22409
 from pyspark.sql.udf import UserDefinedFunction, _create_py_udf  # noqa: F401
+from pyspark.sql.udtf import UserDefinedTableFunction, _create_udtf
 
 # Keep pandas_udf and PandasUDFType import for backwards compatible import; moved in SPARK-28264
 from pyspark.sql.pandas.functions import pandas_udf, PandasUDFType  # noqa: F401
@@ -11451,6 +11453,100 @@ def udf(
         return _create_py_udf(f=f, returnType=returnType, useArrow=useArrow)
 
 
+def udtf(
+    cls: Optional[Type] = None,
+    *,
+    returnType: Union[StructType, str],
+) -> Union[UserDefinedTableFunction, functools.partial]:
+    """Creates a user defined table function (UDTF).
+
+    .. versionadded:: 3.5.0
+
+    Parameters
+    ----------
+    cls : class
+        the Python user-defined table function handler class.
+    returnType : :class:`pyspark.sql.types.StructType` or str
+        the return type of the user-defined table function. The value can be either a
+        :class:`pyspark.sql.types.StructType` object or a DDL-formatted struct type string.
+
+    Examples
+    --------
+    Implement the UDTF class
+
+    >>> class TestUDTF:
+    ...     def eval(self, *args: Any):
+    ...         yield "hello", "world"
+
+    Create the UDTF
+
+    >>> from pyspark.sql.functions import udtf
+    >>> test_udtf = udtf(TestUDTF, returnType="c1: string, c2: string")
+
+    Create the UDTF using the decorator
+
+    >>> @udtf(returnType="c1: int, c2: int")
+    ... class PlusOne:
+    ...     def eval(self, x: int):
+    ...         yield x, x + 1
+
+    Invoke the UDTF
+
+    >>> test_udtf().show()
+    +-----+-----+
+    |   c1|   c2|
+    +-----+-----+
+    |hello|world|
+    +-----+-----+
+
+    Invoke the UDTF with parameters
+
+    >>> from pyspark.sql.functions import lit
+    >>> PlusOne(lit(1)).show()
+    +---+---+
+    | c1| c2|
+    +---+---+
+    |  1|  2|
+    +---+---+
+
+    Notes
+    -----
+    User-defined table functions (UDTFs) are considered deterministic by default.
+    Use `asNondeterministic()` to mark a function as non-deterministic. E.g.:
+
+    >>> import random
+    >>> class RandomUDTF:
+    ...     def eval(self, a: int):
+    ...         yield a * int(random.random() * 100),
+    >>> random_udtf = udtf(RandomUDTF, returnType="r: int").asNondeterministic()
+
+    Use "yield" to produce one row for the UDTF result relation as many times
+    as needed. In the context of a lateral join, each such result row will be
+    associated with the most recent input row consumed from the "eval" method.
+    Or, use "return" to produce multiple rows for the UDTF result relation at
+    once.
+
+    >>> class TestUDTF:
+    ...     def eval(self, a: int):
+    ...         return [(a, a + 1), (a, a + 2)]
+    >>> test_udtf = udtf(TestUDTF, returnType="x: int, y: int")
+
+    User-defined table functions are considered opaque to the optimizer by default.
+    As a result, operations like filters from WHERE clauses or limits from
+    LIMIT/OFFSET clauses that appear after the UDTF call will execute on the
+    UDTF's result relation. By the same token, any relations forwarded as input
+    to UDTFs will plan as full table scans in the absence of any explicit such
+    filtering or other logic explicitly written in a table subquery surrounding the
+    provided input relation.
+
+    User-defined table functions do not accept keyword arguments on the calling side.
+    """
+    if cls is None:
+        return functools.partial(_create_udtf, returnType=returnType)
+    else:
+        return _create_udtf(cls=cls, returnType=returnType)
+
+
 def _test() -> None:
     import doctest
     from pyspark.sql import SparkSession
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index e96dc9cee3f..823164475ea 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -72,6 +72,7 @@ if TYPE_CHECKING:
     from pyspark.sql.pandas._typing import ArrayLike, DataFrameLike as PandasDataFrameLike
     from pyspark.sql.streaming import StreamingQueryManager
     from pyspark.sql.udf import UDFRegistration
+    from pyspark.sql.udtf import UDTFRegistration
 
 
 __all__ = ["SparkSession"]
@@ -792,6 +793,20 @@ class SparkSession(SparkConversionMixin):
 
         return UDFRegistration(self)
 
+    @property
+    def udtf(self) -> "UDTFRegistration":
+        """Returns a :class:`UDTFRegistration` for UDTF registration.
+
+        .. versionadded:: 3.5.0
+
+        Returns
+        -------
+        :class:`UDTFRegistration`
+        """
+        from pyspark.sql.udtf import UDTFRegistration
+
+        return UDTFRegistration(self)
+
     def range(
         self,
         start: int,
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 18a7d8f19b4..89384b24e45 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -27,6 +27,7 @@ from collections import defaultdict
 
 from pyspark.errors import (
     PySparkAttributeError,
+    PySparkNotImplementedError,
     PySparkTypeError,
     PySparkException,
     PySparkValueError,
@@ -3180,6 +3181,16 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
         rows = [cols] * row_count
         self.assertEqual(row_count, self.connect.createDataFrame(data=rows).count())
 
+    def test_unsupported_udtf(self):
+        with self.assertRaises(PySparkNotImplementedError) as e:
+            self.connect.udtf.register()
+
+        self.check_error(
+            exception=e.exception,
+            error_class="NOT_IMPLEMENTED",
+            message_parameters={"feature": "udtf()"},
+        )
+
     def test_unsupported_jvm_attribute(self):
         # Unsupported jvm attributes for Spark session.
         unsupported_attrs = ["_jsc", "_jconf", "_jvm", "_jsparkSession"]
diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py
new file mode 100644
index 00000000000..628f2696b84
--- /dev/null
+++ b/python/pyspark/sql/tests/test_udtf.py
@@ -0,0 +1,413 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from typing import Iterator
+
+from py4j.protocol import Py4JJavaError
+
+from pyspark.errors import PythonException, AnalysisException
+from pyspark.sql.functions import lit, udtf
+from pyspark.sql.types import Row
+from pyspark.testing.sqlutils import ReusedSQLTestCase
+
+
+class UDTFTestsMixin(ReusedSQLTestCase):
+    def test_simple_udtf(self):
+        class TestUDTF:
+            def eval(self):
+                yield "hello", "world"
+
+        func = udtf(TestUDTF, returnType="c1: string, c2: string")
+        rows = func().collect()
+        self.assertEqual(rows, [Row(c1="hello", c2="world")])
+
+    def test_udtf_yield_single_row_col(self):
+        class TestUDTF:
+            def eval(self, a: int):
+                yield a,
+
+        func = udtf(TestUDTF, returnType="a: int")
+        rows = func(lit(1)).collect()
+        self.assertEqual(rows, [Row(a=1)])
+
+    def test_udtf_yield_multi_cols(self):
+        class TestUDTF:
+            def eval(self, a: int):
+                yield a, a + 1
+
+        func = udtf(TestUDTF, returnType="a: int, b: int")
+        rows = func(lit(1)).collect()
+        self.assertEqual(rows, [Row(a=1, b=2)])
+
+    def test_udtf_yield_multi_rows(self):
+        class TestUDTF:
+            def eval(self, a: int):
+                yield a,
+                yield a + 1,
+
+        func = udtf(TestUDTF, returnType="a: int")
+        rows = func(lit(1)).collect()
+        self.assertEqual(rows, [Row(a=1), Row(a=2)])
+
+    def test_udtf_yield_multi_row_col(self):
+        class TestUDTF:
+            def eval(self, a: int, b: int):
+                yield a, b, a + b
+                yield a, b, a - b
+                yield a, b, b - a
+
+        func = udtf(TestUDTF, returnType="a: int, b: int, c: int")
+        rows = func(lit(1), lit(2)).collect()
+        self.assertEqual(rows, [Row(a=1, b=2, c=3), Row(a=1, b=2, c=-1), Row(a=1, b=2, c=1)])
+
+    def test_udtf_decorator(self):
+        @udtf(returnType="a: int, b: int")
+        class TestUDTF:
+            def eval(self, a: int):
+                yield a, a + 1
+
+        rows = TestUDTF(lit(1)).collect()
+        self.assertEqual(rows, [Row(a=1, b=2)])
+
+    def test_udtf_registration(self):
+        class TestUDTF:
+            def eval(self, a: int, b: int):
+                yield a, b, a + b
+                yield a, b, a - b
+                yield a, b, b - a
+
+        func = udtf(TestUDTF, returnType="a: int, b: int, c: int")
+        self.spark.udtf.register("testUDTF", func)
+        df = self.spark.sql("SELECT * FROM testUDTF(1, 2)")
+        self.assertEqual(
+            df.collect(), [Row(a=1, b=2, c=3), Row(a=1, b=2, c=-1), Row(a=1, b=2, c=1)]
+        )
+
+    def test_udtf_with_lateral_join(self):
+        class TestUDTF:
+            def eval(self, a: int, b: int) -> Iterator:
+                yield a, b, a + b
+                yield a, b, a - b
+
+        func = udtf(TestUDTF, returnType="a: int, b: int, c: int")
+        self.spark.udtf.register("testUDTF", func)
+        df = self.spark.sql(
+            "SELECT f.* FROM values (0, 1), (1, 2) t(a, b), LATERAL testUDTF(a, b) f"
+        )
+        expected = self.spark.createDataFrame(
+            [(0, 1, 1), (0, 1, -1), (1, 2, 3), (1, 2, -1)], schema=["a", "b", "c"]
+        )
+        self.assertEqual(df.collect(), expected.collect())
+
+    def test_udtf_eval_with_return_stmt(self):
+        class TestUDTF:
+            def eval(self, a: int, b: int):
+                return [(a, a + 1), (b, b + 1)]
+
+        func = udtf(TestUDTF, returnType="a: int, b: int")
+        rows = func(lit(1), lit(2)).collect()
+        self.assertEqual(rows, [Row(a=1, b=2), Row(a=2, b=3)])
+
+    def test_udtf_eval_returning_non_tuple(self):
+        class TestUDTF:
+            def eval(self, a: int):
+                yield a
+
+        func = udtf(TestUDTF, returnType="a: int")
+        # TODO(SPARK-44005): improve this error message
+        with self.assertRaisesRegex(PythonException, "Unexpected tuple 1 with StructType"):
+            func(lit(1)).collect()
+
+    def test_udtf_eval_returning_non_generator(self):
+        class TestUDTF:
+            def eval(self, a: int):
+                return (a,)
+
+        func = udtf(TestUDTF, returnType="a: int")
+        # TODO(SPARK-44005): improve this error message
+        with self.assertRaisesRegex(PythonException, "Unexpected tuple 1 with StructType"):
+            func(lit(1)).collect()
+
+    def test_udtf_eval_with_no_return(self):
+        @udtf(returnType="a: int")
+        class TestUDTF:
+            def eval(self, a: int):
+                ...
+
+        # TODO(SPARK-43967): Support Python UDTFs with empty return values
+        with self.assertRaisesRegex(PythonException, "TypeError"):
+            TestUDTF(lit(1)).collect()
+
+        @udtf(returnType="a: int")
+        class TestUDTF:
+            def eval(self, a: int):
+                return
+
+        with self.assertRaisesRegex(PythonException, "TypeError"):
+            TestUDTF(lit(1)).collect()
+
+    def test_udtf_with_conditional_return(self):
+        class TestUDTF:
+            def eval(self, a: int):
+                if a > 5:
+                    yield a,
+
+        func = udtf(TestUDTF, returnType="a: int")
+        self.spark.udtf.register("test_udtf", func)
+        self.assertEqual(
+            self.spark.sql("SELECT * FROM range(0, 8) JOIN LATERAL test_udtf(id)").collect(),
+            [Row(id=6, a=6), Row(id=7, a=7)],
+        )
+
+    def test_udtf_with_empty_yield(self):
+        @udtf(returnType="a: int")
+        class TestUDTF:
+            def eval(self, a: int):
+                yield
+
+        # TODO(SPARK-43967): Support Python UDTFs with empty return values
+        with self.assertRaisesRegex(Py4JJavaError, "java.lang.NullPointerException"):
+            TestUDTF(lit(1)).collect()
+
+    def test_udtf_with_none_output(self):
+        @udtf(returnType="a: int")
+        class TestUDTF:
+            def eval(self, a: int):
+                yield a,
+                yield None,
+
+        self.assertEqual(TestUDTF(lit(1)).collect(), [Row(a=1), Row(a=None)])
+        df = self.spark.createDataFrame([(0, 1), (1, 2)], schema=["a", "b"])
+        self.assertEqual(TestUDTF(lit(1)).join(df, "a", "inner").collect(), [Row(a=1, b=2)])
+        self.assertEqual(
+            TestUDTF(lit(1)).join(df, "a", "left").collect(), [Row(a=None, b=None), Row(a=1, b=2)]
+        )
+
+    def test_udtf_with_none_input(self):
+        @udtf(returnType="a: int")
+        class TestUDTF:
+            def eval(self, a: int):
+                yield a,
+
+        self.assertEqual(TestUDTF(lit(None)).collect(), [Row(a=None)])
+        self.spark.udtf.register("testUDTF", TestUDTF)
+        df = self.spark.sql("SELECT * FROM testUDTF(null)")
+        self.assertEqual(df.collect(), [Row(a=None)])
+
+    def test_udtf_with_wrong_num_input(self):
+        @udtf(returnType="a: int, b: int")
+        class TestUDTF:
+            def eval(self, a: int):
+                yield a, a + 1
+
+        with self.assertRaisesRegex(
+            PythonException, r"eval\(\) missing 1 required positional argument: 'a'"
+        ):
+            TestUDTF().collect()
+
+        with self.assertRaisesRegex(
+            PythonException, r"eval\(\) takes 2 positional arguments but 3 were given"
+        ):
+            TestUDTF(lit(1), lit(2)).collect()
+
+    def test_udtf_with_wrong_num_output(self):
+        # TODO(SPARK-43968): check this during compile time instead of runtime
+        err_msg = (
+            "java.lang.IllegalStateException: Input row doesn't have expected number of "
+            + "values required by the schema."
+        )
+
+        @udtf(returnType="a: int, b: int")
+        class TestUDTF:
+            def eval(self, a: int):
+                yield a,
+
+        with self.assertRaisesRegex(Py4JJavaError, err_msg):
+            TestUDTF(lit(1)).collect()
+
+        @udtf(returnType="a: int")
+        class TestUDTF:
+            def eval(self, a: int):
+                yield a, a + 1
+
+        with self.assertRaisesRegex(Py4JJavaError, err_msg):
+            TestUDTF(lit(1)).collect()
+
+    def test_udtf_init(self):
+        @udtf(returnType="a: int, b: int, c: string")
+        class TestUDTF:
+            def __init__(self):
+                self.key = "test"
+
+            def eval(self, a: int):
+                yield a, a + 1, self.key
+
+        rows = TestUDTF(lit(1)).collect()
+        self.assertEqual(rows, [Row(a=1, b=2, c="test")])
+
+    def test_udtf_terminate(self):
+        @udtf(returnType="key: string, value: float")
+        class TestUDTF:
+            def __init__(self):
+                self._count = 0
+                self._sum = 0
+
+            def eval(self, x: int):
+                self._count += 1
+                self._sum += x
+                yield "input", float(x)
+
+            def terminate(self):
+                yield "count", float(self._count)
+                yield "avg", self._sum / self._count
+
+        self.assertEqual(
+            TestUDTF(lit(1)).collect(),
+            [Row(key="input", value=1), Row(key="count", value=1.0), Row(key="avg", value=1.0)],
+        )
+
+        self.spark.udtf.register("test_udtf", TestUDTF)
+        df = self.spark.sql(
+            "SELECT id, key, value FROM range(0, 10, 1, 2), "
+            "LATERAL test_udtf(id) WHERE key != 'input'"
+        )
+        self.assertEqual(
+            df.collect(),
+            [
+                Row(id=4, key="count", value=5.0),
+                Row(id=4, key="avg", value=2.0),
+                Row(id=9, key="count", value=5.0),
+                Row(id=9, key="avg", value=7.0),
+            ],
+        )
+
+    def test_terminate_with_exceptions(self):
+        @udtf(returnType="a: int, b: int")
+        class TestUDTF:
+            def eval(self, a: int):
+                yield a, a + 1
+
+            def terminate(self):
+                raise ValueError("terminate error")
+
+        with self.assertRaisesRegex(
+            PythonException,
+            "User defined table function encountered an error in the 'terminate' "
+            "method: terminate error",
+        ):
+            TestUDTF(lit(1)).collect()
+
+    def test_udtf_terminate_with_wrong_num_output(self):
+        err_msg = (
+            "java.lang.IllegalStateException: Input row doesn't have expected number of "
+            "values required by the schema."
+        )
+
+        @udtf(returnType="a: int, b: int")
+        class TestUDTF:
+            def eval(self, a: int):
+                yield a, a + 1
+
+            def terminate(self):
+                yield 1, 2, 3
+
+        with self.assertRaisesRegex(Py4JJavaError, err_msg):
+            TestUDTF(lit(1)).show()
+
+        @udtf(returnType="a: int, b: int")
+        class TestUDTF:
+            def eval(self, a: int):
+                yield a, a + 1
+
+            def terminate(self):
+                yield 1,
+
+        with self.assertRaisesRegex(Py4JJavaError, err_msg):
+            TestUDTF(lit(1)).show()
+
+    def test_nondeterministic_udtf(self):
+        import random
+
+        class RandomUDTF:
+            def eval(self, a: int):
+                yield a * int(random.random() * 100),
+
+        random_udtf = udtf(RandomUDTF, returnType="x: int").asNondeterministic()
+        # TODO(SPARK-43966): support non-deterministic UDTFs
+        with self.assertRaisesRegex(AnalysisException, "nondeterministic expressions"):
+            random_udtf(lit(1)).collect()
+
+    def test_udtf_with_nondeterministic_input(self):
+        from pyspark.sql.functions import rand
+
+        @udtf(returnType="x: int")
+        class TestUDTF:
+            def eval(self, a: int):
+                yield a + 1,
+
+        # TODO(SPARK-43966): support non-deterministic UDTFs
+        with self.assertRaisesRegex(AnalysisException, "nondeterministic expressions"):
+            TestUDTF(rand(0) * 100).collect()
+
+    def test_udtf_no_eval(self):
+        @udtf(returnType="a: int, b: int")
+        class TestUDTF:
+            def run(self, a: int):
+                yield a, a + 1
+
+        with self.assertRaisesRegex(
+            PythonException,
+            "Failed to execute the user defined table function because it has not "
+            "implemented the 'eval' method. Please add the 'eval' method and try the "
+            "query again.",
+        ):
+            TestUDTF(lit(1)).collect()
+
+    def test_udtf_with_no_handler_class(self):
+        err_msg = "the function handler must be a class"
+        with self.assertRaisesRegex(TypeError, err_msg):
+
+            @udtf(returnType="a: int")
+            def test_udtf(a: int):
+                yield a,
+
+        def test_udtf(a: int):
+            yield a
+
+        with self.assertRaisesRegex(TypeError, err_msg):
+            udtf(test_udtf, returnType="a: int")
+
+
+class UDTFTests(UDTFTestsMixin, ReusedSQLTestCase):
+    @classmethod
+    def setUpClass(cls):
+        super(UDTFTests, cls).setUpClass()
+        cls.spark.conf.set("spark.sql.execution.pythonUDF.arrow.enabled", "false")
+
+
+if __name__ == "__main__":
+    from pyspark.sql.tests.test_udtf import *  # noqa: F401
+
+    try:
+        import xmlrunner  # type: ignore
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/udtf.py b/python/pyspark/sql/udtf.py
new file mode 100644
index 00000000000..e1b0825c170
--- /dev/null
+++ b/python/pyspark/sql/udtf.py
@@ -0,0 +1,233 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""
+User-defined table function related classes and functions
+"""
+import sys
+from typing import Type, TYPE_CHECKING, Optional, Union
+
+from py4j.java_gateway import JavaObject
+
+from pyspark.errors import PySparkTypeError
+from pyspark.sql.column import _to_java_column, _to_seq
+from pyspark.sql.types import StructType, _parse_datatype_string
+from pyspark.sql.udf import _wrap_function
+
+if TYPE_CHECKING:
+    from pyspark.sql._typing import ColumnOrName
+    from pyspark.sql.dataframe import DataFrame
+    from pyspark.sql.session import SparkSession
+
+__all__ = ["UDTFRegistration"]
+
+
+def _create_udtf(
+    cls: Type,
+    returnType: Union[StructType, str],
+    name: Optional[str] = None,
+    deterministic: bool = True,
+) -> "UserDefinedTableFunction":
+    """Create a Python UDTF."""
+    udtf_obj = UserDefinedTableFunction(
+        cls, returnType=returnType, name=name, deterministic=deterministic
+    )
+    return udtf_obj
+
+
+class UserDefinedTableFunction:
+    """
+    User-defined table function in Python
+
+    .. versionadded:: 3.5.0
+
+    Notes
+    -----
+    The constructor of this class is not supposed to be directly called.
+    Use :meth:`pyspark.sql.functions.udtf` to create this instance.
+
+    This API is evolving.
+    """
+
+    def __init__(
+        self,
+        func: Type,
+        returnType: Union[StructType, str],
+        name: Optional[str] = None,
+        deterministic: bool = True,
+    ):
+
+        if not isinstance(func, type):
+            raise PySparkTypeError(
+                f"Invalid user defined table function: the function handler "
+                f"must be a class, but got {type(func).__name__}. Please provide "
+                "a class as the handler."
+            )
+
+        # TODO(SPARK-43968): add more compile time checks for UDTFs
+        self.func = func
+        self._returnType = returnType
+        self._returnType_placeholder: Optional[StructType] = None
+        self._inputTypes_placeholder = None
+        self._judtf_placeholder = None
+        self._name = name or func.__name__
+        self.deterministic = deterministic
+
+    @property
+    def returnType(self) -> StructType:
+        # `_parse_datatype_string` accesses to JVM for parsing a DDL formatted string.
+        # This makes sure this is called after SparkContext is initialized.
+        if self._returnType_placeholder is None:
+            if isinstance(self._returnType, StructType):
+                self._returnType_placeholder = self._returnType
+            else:
+                assert isinstance(self._returnType, str)
+                parsed = _parse_datatype_string(self._returnType)
+                if not isinstance(parsed, StructType):
+                    raise PySparkTypeError(
+                        f"Invalid return type for the user defined table function "
+                        f"'{self._name}': {self._returnType}. The return type of a "
+                        f"UDTF must be a 'StructType'. Please ensure the return "
+                        "type is a correctly formatted 'StructType' string."
+                    )
+                self._returnType_placeholder = parsed
+        return self._returnType_placeholder
+
+    @property
+    def _judtf(self) -> JavaObject:
+        if self._judtf_placeholder is None:
+            self._judtf_placeholder = self._create_judtf(self.func)
+        return self._judtf_placeholder
+
+    def _create_judtf(self, func: Type) -> JavaObject:
+        from pyspark.sql import SparkSession
+
+        spark = SparkSession._getActiveSessionOrCreate()
+        sc = spark.sparkContext
+
+        wrapped_func = _wrap_function(sc, func, self.returnType)
+        jdt = spark._jsparkSession.parseDataType(self.returnType.json())
+        assert sc._jvm is not None
+        judtf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonTableFunction(
+            self._name, wrapped_func, jdt, self.deterministic
+        )
+        return judtf
+
+    def __call__(self, *cols: "ColumnOrName") -> "DataFrame":
+        from pyspark.sql import DataFrame, SparkSession
+
+        spark = SparkSession._getActiveSessionOrCreate()
+        sc = spark.sparkContext
+
+        judtf = self._judtf
+        jPythonUDTF = judtf.apply(spark._jsparkSession, _to_seq(sc, cols, _to_java_column))
+        return DataFrame(jPythonUDTF, spark)
+
+    def asNondeterministic(self) -> "UserDefinedTableFunction":
+        """
+        Updates UserDefinedTableFunction to nondeterministic.
+        """
+        # Explicitly clean the cache to create a JVM UDTF instance.
+        self._judtf_placeholder = None
+        self.deterministic = False
+        return self
+
+
+class UDTFRegistration:
+    """
+    Wrapper for user-defined table function registration. This instance can be accessed by
+    :attr:`spark.udtf` or :attr:`sqlContext.udtf`.
+
+    .. versionadded:: 3.5.0
+    """
+
+    def __init__(self, sparkSession: "SparkSession"):
+        self.sparkSession = sparkSession
+
+    def register(
+        self,
+        name: str,
+        f: UserDefinedTableFunction,
+    ) -> UserDefinedTableFunction:
+        """Register a Python user-defined table function as a SQL table function.
+
+        .. versionadded:: 3.5.0
+
+        Parameters
+        ----------
+        name : str
+            The name of the user-defined table function in SQL statements.
+        f : function or :meth:`pyspark.sql.functions.udtf`
+            The user-defined table function.
+
+        Returns
+        -------
+        function
+            The registered user-defined table function.
+
+        Notes
+        -----
+        Spark uses the return type of the given user-defined table function as the return
+        type of the registered user-defined function.
+
+        To register a nondeterministic Python table function, users need to first build
+        a nondeterministic user-defined table function and then register it as a SQL function.
+
+        Examples
+        --------
+        >>> from pyspark.sql.functions import udtf
+        >>> @udtf(returnType="c1: int, c2: int")
+        ... class PlusOne:
+        ...     def eval(self, x: int):
+        ...         yield x, x + 1
+        >>> _ = spark.udtf.register(name="plus_one", f=PlusOne)
+        >>> spark.sql("SELECT * FROM plus_one(1)").collect()
+        [Row(c1=1, c2=2)]
+
+        Use it with lateral join
+
+        >>> spark.sql("SELECT * FROM VALUES (0, 1), (1, 2) t(x, y), LATERAL plus_one(x)").collect()
+        [Row(x=0, y=1, c1=0, c2=1), Row(x=1, y=2, c1=1, c2=2)]
+        """
+        register_udtf = _create_udtf(
+            f.func,
+            returnType=f.returnType,
+            name=name,
+            deterministic=f.deterministic,
+        )
+        return_udtf = f
+        self.sparkSession._jsparkSession.udtf().registerPython(name, register_udtf._judtf)
+        return return_udtf
+
+
+def _test() -> None:
+    import doctest
+    from pyspark.sql import SparkSession
+    import pyspark.sql.udf
+
+    globs = pyspark.sql.udtf.__dict__.copy()
+    spark = SparkSession.builder.master("local[4]").appName("sql.udtf tests").getOrCreate()
+    globs["spark"] = spark
+    (failure_count, test_count) = doctest.testmod(
+        pyspark.sql.udtf, globs=globs, optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE
+    )
+    spark.stop()
+    if failure_count:
+        sys.exit(-1)
+
+
+if __name__ == "__main__":
+    _test()
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 06f0d1dc37f..77d0d548408 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -64,7 +64,7 @@ from pyspark.sql.pandas.types import to_arrow_type
 from pyspark.sql.types import StructType
 from pyspark.util import fail_on_stopiteration, try_simplify_traceback
 from pyspark import shuffle
-from pyspark.errors import PySparkRuntimeError
+from pyspark.errors import PySparkRuntimeError, PySparkValueError
 
 pickleSer = CPickleSerializer()
 utf8_deserializer = UTF8Deserializer()
@@ -456,6 +456,86 @@ def assign_cols_by_name(runner_conf):
     )
 
 
+# Read and process a serialized user-defined table function (UDTF) from a socket.
+# It expects the UDTF to be in a specific format and performs various checks to
+# ensure the UDTF is valid. This function also prepares a mapper function for applying
+# the UDTF logic to input rows.
+def read_udtf(pickleSer, infile, eval_type):
+    num_udtfs = read_int(infile)
+    if num_udtfs != 1:
+        raise PySparkValueError(f"Unexpected number of UDTFs. Expected 1 but got {num_udtfs}.")
+
+    # See `PythonUDFRunner.writeUDFs`.
+    num_arg = read_int(infile)
+    arg_offsets = [read_int(infile) for _ in range(num_arg)]
+    num_chained_funcs = read_int(infile)
+    if num_chained_funcs != 1:
+        raise PySparkValueError(
+            f"Unexpected number of chained UDTFs. Expected 1 but got {num_chained_funcs}."
+        )
+
+    handler, return_type = read_command(pickleSer, infile)
+    if not isinstance(handler, type):
+        raise PySparkRuntimeError(
+            f"Invalid UDTF handler type. Expected a class (type 'type'), but "
+            f"got an instance of {type(handler).__name__}."
+        )
+
+    # Instantiate the UDTF class.
+    try:
+        udtf = handler()
+    except Exception as e:
+        raise PySparkRuntimeError(
+            f"User defined table function encountered an error in "
+            f"the '__init__' method: {str(e)}"
+        )
+
+    # Validate the UDTF
+    if not hasattr(udtf, "eval"):
+        raise PySparkRuntimeError(
+            "Failed to execute the user defined table function because it has not "
+            "implemented the 'eval' method. Please add the 'eval' method and try "
+            "the query again."
+        )
+
+    # Wrap the UDTF and convert the results.
+    def wrap_udtf(f, return_type):
+        if return_type.needConversion():
+            toInternal = return_type.toInternal
+            return lambda *a: map(toInternal, f(*a))
+        else:
+            return lambda *a: f(*a)
+
+    eval = wrap_udtf(getattr(udtf, "eval"), return_type)
+
+    if hasattr(udtf, "terminate"):
+        terminate = wrap_udtf(getattr(udtf, "terminate"), return_type)
+    else:
+        terminate = None
+
+    def mapper(a):
+        results = tuple(eval(*[a[o] for o in arg_offsets]))
+        return results
+
+    # Return an iterator of iterators.
+    def func(_, it):
+        try:
+            yield from map(mapper, it)
+        finally:
+            if terminate is not None:
+                try:
+                    yield tuple(terminate())
+                except BaseException as e:
+                    raise PySparkRuntimeError(
+                        f"User defined table function encountered an error in "
+                        f"the 'terminate' method: {str(e)}"
+                    )
+
+    ser = BatchedSerializer(CPickleSerializer(), 100)
+
+    return func, None, ser, ser
+
+
 def read_udfs(pickleSer, infile, eval_type):
     runner_conf = {}
 
@@ -859,6 +939,8 @@ def main(infile, outfile):
         eval_type = read_int(infile)
         if eval_type == PythonEvalType.NON_UDF:
             func, profiler, deserializer, serializer = read_command(pickleSer, infile)
+        elif eval_type == PythonEvalType.SQL_TABLE_UDF:
+            func, profiler, deserializer, serializer = read_udtf(pickleSer, infile, eval_type)
         else:
             func, profiler, deserializer, serializer = read_udfs(pickleSer, infile, eval_type)
 
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionInfo.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionInfo.java
index be2b3dbe819..84ee30440b7 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionInfo.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionInfo.java
@@ -48,7 +48,8 @@ public class ExpressionInfo {
             "window_funcs", "xml_funcs", "table_funcs", "url_funcs"));
 
     private static final Set<String> validSources =
-            new HashSet<>(Arrays.asList("built-in", "hive", "python_udf", "scala_udf", "java_udf"));
+            new HashSet<>(Arrays.asList("built-in", "hive", "python_udf", "scala_udf",
+                    "java_udf", "python_udtf"));
 
     public String getClassName() {
         return className;
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala
index 8636eb61034..6905bde9c33 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala
@@ -57,8 +57,6 @@ trait PythonFuncExpression extends NonSQLExpression with UserDefinedExpression {
   def udfDeterministic: Boolean
   def resultId: ExprId
 
-  final override val nodePatterns: Seq[TreePattern] = Seq(PYTHON_UDF)
-
   override lazy val deterministic: Boolean = udfDeterministic && children.forall(_.deterministic)
 
   override def toString: String = s"$name(${children.mkString(", ")})#${resultId.id}$typeSuffix"
@@ -89,6 +87,8 @@ case class PythonUDF(
     this.copy(resultId = ExprId(-1)).withNewChildren(canonicalizedChildren)
   }
 
+  final override val nodePatterns: Seq[TreePattern] = Seq(PYTHON_UDF)
+
   override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): PythonUDF =
     copy(children = newChildren)
 }
@@ -138,10 +138,45 @@ case class PythonUDAF(
     this.copy(resultId = ExprId(-1)).withNewChildren(canonicalizedChildren)
   }
 
+  final override val nodePatterns: Seq[TreePattern] = Seq(PYTHON_UDF)
+
   override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): PythonUDAF =
     copy(children = newChildren)
 }
 
+abstract class UnevaluableGenerator extends Generator {
+  final override def eval(input: InternalRow): TraversableOnce[InternalRow] =
+    throw QueryExecutionErrors.cannotEvaluateExpressionError(this)
+
+  final override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
+    throw QueryExecutionErrors.cannotGenerateCodeForExpressionError(this)
+}
+
+/**
+ * A serialized version of a Python table-valued function. This is a special expression,
+ * which needs a dedicated physical operator to execute it.
+ */
+case class PythonUDTF(
+    name: String,
+    func: PythonFunction,
+    override val elementSchema: StructType,
+    children: Seq[Expression],
+    udfDeterministic: Boolean,
+    resultId: ExprId = NamedExpression.newExprId)
+  extends UnevaluableGenerator with PythonFuncExpression {
+
+  override def evalType: Int = PythonEvalType.SQL_TABLE_UDF
+
+  override lazy val canonicalized: Expression = {
+    val canonicalizedChildren = children.map(_.canonicalized)
+    // `resultId` can be seen as cosmetic variation in PythonUDTF, as it doesn't affect the result.
+    this.copy(resultId = ExprId(-1)).withNewChildren(canonicalizedChildren)
+  }
+
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): PythonUDTF =
+    copy(children = newChildren)
+}
+
 /**
  * A place holder used when printing expressions without debugging information such as the
  * result id.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/PlanHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/PlanHelper.scala
index d7e8b8a9610..db89fc4bbbd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/PlanHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/PlanHelper.scala
@@ -48,7 +48,8 @@ object PlanHelper {
                plan.isInstanceOf[CollectMetrics] ||
                onlyInLateralSubquery(plan)) => e
         case e: Generator
-          if !plan.isInstanceOf[Generate] => e
+          if !(plan.isInstanceOf[Generate] ||
+               plan.isInstanceOf[BaseEvalPythonUDTF]) => e
       }
     }
     invalidExpressions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
index fe5eee481be..f79c360a313 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.catalyst.plans.logical
 
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, PythonUDF}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, PythonUDF, PythonUDTF}
 import org.apache.spark.sql.catalyst.trees.TreePattern._
 import org.apache.spark.sql.catalyst.util.truncatedString
 import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode}
@@ -148,6 +148,21 @@ trait BaseEvalPython extends UnaryNode {
   final override val nodePatterns: Seq[TreePattern] = Seq(EVAL_PYTHON_UDF)
 }
 
+trait BaseEvalPythonUDTF extends UnaryNode {
+
+  def udtf: PythonUDTF
+
+  def requiredChildOutput: Seq[Attribute]
+
+  def resultAttrs: Seq[Attribute]
+
+  override def output: Seq[Attribute] = requiredChildOutput ++ resultAttrs
+
+  override def producedAttributes: AttributeSet = AttributeSet(resultAttrs)
+
+  final override val nodePatterns: Seq[TreePattern] = Seq(EVAL_PYTHON_UDTF)
+}
+
 /**
  * A logical plan that evaluates a [[PythonUDF]]
  */
@@ -171,6 +186,24 @@ case class ArrowEvalPython(
     copy(child = newChild)
 }
 
+/**
+ * A logical plan that evaluates a [[PythonUDTF]].
+ *
+ * @param udtf the user-defined Python function
+ * @param requiredChildOutput the required output of the child plan. It's used for omitting data
+ *                            generation that will be discarded next by a projection.
+ * @param resultAttrs the output schema of the Python UDTF.
+ * @param child the child plan
+ */
+case class BatchEvalPythonUDTF(
+    udtf: PythonUDTF,
+    requiredChildOutput: Seq[Attribute],
+    resultAttrs: Seq[Attribute],
+    child: LogicalPlan) extends BaseEvalPythonUDTF {
+  override protected def withNewChildInternal(newChild: LogicalPlan): BatchEvalPythonUDTF =
+    copy(child = newChild)
+}
+
 /**
  * A logical plan that adds a new long column with the name `name` that
  * increases one by one. This is for 'distributed-sequence' default index
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
index 230d2b04c4f..11d5cf54df4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
@@ -106,6 +106,7 @@ object TreePattern extends Enumeration  {
   val CTE: Value = Value
   val DISTINCT_LIKE: Value = Value
   val EVAL_PYTHON_UDF: Value = Value
+  val EVAL_PYTHON_UDTF: Value = Value
   val EVENT_TIME_WATERMARK: Value = Value
   val EXCEPT: Value = Value
   val FILTER: Value = Value
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 66215d05033..642006fb8dc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -225,6 +225,8 @@ class SparkSession private(
    */
   def udf: UDFRegistration = sessionState.udfRegistration
 
+  def udtf: UDTFRegistration = sessionState.udtfRegistration
+
   /**
    * Returns a `StreamingQueryManager` that allows managing all the
    * `StreamingQuery`s active on `this`.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDTFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDTFRegistration.scala
new file mode 100644
index 00000000000..3330597cb76
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UDTFRegistration.scala
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.annotation.Evolving
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.analysis.TableFunctionRegistry
+import org.apache.spark.sql.execution.python.UserDefinedPythonTableFunction
+
+/**
+ * Functions for registering user-defined table functions. Use `SparkSession.udtf` to access this.
+ *
+ * @since 3.5.0
+ */
+@Evolving
+class UDTFRegistration private[sql] (tableFunctionRegistry: TableFunctionRegistry)
+  extends Logging {
+
+  protected[sql] def registerPython(name: String, udtf: UserDefinedPythonTableFunction): Unit = {
+    log.debug(
+      s"""
+         | Registering new PythonUDTF:
+         | name: $name
+         | command: ${udtf.func.command}
+         | envVars: ${udtf.func.envVars}
+         | pythonIncludes: ${udtf.func.pythonIncludes}
+         | pythonExec: ${udtf.func.pythonExec}
+         | returnType: ${udtf.returnType}
+         | udfDeterministic: ${udtf.udfDeterministic}
+      """.stripMargin)
+
+    tableFunctionRegistry.createOrReplaceTempFunction(name, udtf.builder, "python_udtf")
+  }
+}
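
The sketch below shows how the new registry is intended to be used from the Scala side. It assumes a test-only context: registerPython is protected[sql], so the code has to live in the org.apache.spark.sql package (as PythonUDTFSuite does further down), and pythonUDTF is assumed to have been built with IntegratedUDFTestUtils.createUserDefinedPythonTableFunction.

```scala
package org.apache.spark.sql

import org.apache.spark.sql.execution.python.UserDefinedPythonTableFunction

object UDTFRegistrationSketch {
  // `spark` and `pythonUDTF` are assumed to be supplied by the surrounding test harness.
  def registerAndQuery(spark: SparkSession, pythonUDTF: UserDefinedPythonTableFunction): Unit = {
    // Register the Python UDTF as a temporary table function named "testUDTF".
    spark.udtf.registerPython("testUDTF", pythonUDTF)
    // Once registered, it can be invoked in the FROM clause, including with LATERAL.
    spark.sql("SELECT * FROM testUDTF(1, 2)").show()
  }
}
```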
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
index f05fe9d60fb..70a35ea9115 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.connector.catalog.CatalogManager
 import org.apache.spark.sql.execution.datasources.{PruneFileSourcePartitions, SchemaPruning, V1Writes}
 import org.apache.spark.sql.execution.datasources.v2.{GroupBasedRowLevelOperationScanPlanning, OptimizeMetadataOnlyDeleteFromTable, V2ScanPartitioningAndOrdering, V2ScanRelationPushDown, V2Writes}
 import org.apache.spark.sql.execution.dynamicpruning.{CleanupDynamicPruningFilters, PartitionPruning, RowLevelOperationRuntimeGroupFiltering}
-import org.apache.spark.sql.execution.python.{ExtractGroupingPythonUDFFromAggregate, ExtractPythonUDFFromAggregate, ExtractPythonUDFs}
+import org.apache.spark.sql.execution.python.{ExtractGroupingPythonUDFFromAggregate, ExtractPythonUDFFromAggregate, ExtractPythonUDFs, ExtractPythonUDTFs}
 
 class SparkOptimizer(
     catalogManager: CatalogManager,
@@ -77,6 +77,7 @@ class SparkOptimizer(
       // This must be executed after `ExtractPythonUDFFromAggregate` and before `ExtractPythonUDFs`.
       ExtractGroupingPythonUDFFromAggregate,
       ExtractPythonUDFs,
+      ExtractPythonUDTFs,
       // The eval-python node may be between Project/Filter and the scan node, which breaks
       // column pruning and filter push-down. Here we rerun the related optimizer rules.
       ColumnPruning,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index a414b815018..a4cca1a248b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -747,6 +747,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         ArrowEvalPythonExec(udfs, output, planLater(child), evalType) :: Nil
       case BatchEvalPython(udfs, output, child) =>
         BatchEvalPythonExec(udfs, output, planLater(child)) :: Nil
+      case BatchEvalPythonUDTF(udtf, requiredChildOutput, resultAttrs, child) =>
+        BatchEvalPythonUDTFExec(udtf, requiredChildOutput, resultAttrs, planLater(child)) :: Nil
       case _ =>
         Nil
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
index ca7ca2e2f80..c8a798d5b70 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
@@ -42,39 +42,8 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute]
       context: TaskContext): Iterator[InternalRow] = {
     EvaluatePython.registerPicklers()  // register pickler for Row
 
-    val dataTypes = schema.map(_.dataType)
-    val needConversion = dataTypes.exists(EvaluatePython.needConversionInPython)
-
-    // enable memo iff we serialize the row with schema (schema and class should be memorized)
-    // pyrolite 4.21+ can lookup objects in its cache by value, but `GenericRowWithSchema` objects,
-    // that we pass from JVM to Python, don't define their `equals()` to take the type of the
-    // values or the schema of the row into account. This causes like
-    // `GenericRowWithSchema(Array(1.0, 1.0),
-    //    StructType(Seq(StructField("_1", DoubleType), StructField("_2", DoubleType))))`
-    // and
-    // `GenericRowWithSchema(Array(1, 1),
-    //    StructType(Seq(StructField("_1", IntegerType), StructField("_2", IntegerType))))`
-    // to be `equal()` and so we need to disable this feature explicitly (`valueCompare=false`).
-    // Please note that cache by reference is still enabled depending on `needConversion`.
-    val pickle = new Pickler(/* useMemo = */ needConversion,
-      /* valueCompare = */ false)
-    // Input iterator to Python: input rows are grouped so we send them in batches to Python.
-    // For each row, add it to the queue.
-    val inputIterator = iter.map { row =>
-      if (needConversion) {
-        EvaluatePython.toJava(row, schema)
-      } else {
-        // fast path for these types that does not need conversion in Python
-        val fields = new Array[Any](row.numFields)
-        var i = 0
-        while (i < row.numFields) {
-          val dt = dataTypes(i)
-          fields(i) = EvaluatePython.toJava(row.get(i, dt), dt)
-          i += 1
-        }
-        fields
-      }
-    }.grouped(100).map(x => pickle.dumps(x.toArray))
+    // Input iterator to Python.
+    val inputIterator = BatchEvalPythonExec.getInputIterator(iter, schema)
 
     // Output iterator for results from Python.
     val outputIterator =
@@ -109,3 +78,43 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute]
   override protected def withNewChildInternal(newChild: SparkPlan): BatchEvalPythonExec =
     copy(child = newChild)
 }
+
+object BatchEvalPythonExec {
+  def getInputIterator(
+      iter: Iterator[InternalRow],
+      schema: StructType): Iterator[Array[Byte]] = {
+    val dataTypes = schema.map(_.dataType)
+    val needConversion = dataTypes.exists(EvaluatePython.needConversionInPython)
+
+    // enable memo iff we serialize the row with schema (schema and class should be memoized)
+    // pyrolite 4.21+ can look up objects in its cache by value, but `GenericRowWithSchema` objects,
+    // that we pass from JVM to Python, don't define their `equals()` to take the type of the
+    // values or the schema of the row into account. This causes rows like
+    // `GenericRowWithSchema(Array(1.0, 1.0),
+    //    StructType(Seq(StructField("_1", DoubleType), StructField("_2", DoubleType))))`
+    // and
+    // `GenericRowWithSchema(Array(1, 1),
+    //    StructType(Seq(StructField("_1", IntegerType), StructField("_2", IntegerType))))`
+    // to compare as `equal()`, so we need to disable this feature explicitly (`valueCompare=false`).
+    // Please note that cache by reference is still enabled depending on `needConversion`.
+    val pickle = new Pickler(/* useMemo = */ needConversion,
+      /* valueCompare = */ false)
+    // Input iterator to Python: input rows are grouped into batches of 100 and each batch is
+    // pickled into a single byte array before being sent to the Python worker.
+    iter.map { row =>
+      if (needConversion) {
+        EvaluatePython.toJava(row, schema)
+      } else {
+        // fast path for types that do not need conversion in Python
+        val fields = new Array[Any](row.numFields)
+        var i = 0
+        while (i < row.numFields) {
+          val dt = dataTypes(i)
+          fields(i) = EvaluatePython.toJava(row.get(i, dt), dt)
+          i += 1
+        }
+        fields
+      }
+    }.grouped(100).map(x => pickle.dumps(x.toArray))
+  }
+}
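
A note on the refactoring above: the pickling logic is unchanged and has only been moved into BatchEvalPythonExec.getInputIterator so the new UDTF operator below can reuse it. A minimal sketch of the call shape, assuming the caller has already built the projected input iterator and its schema:

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.python.BatchEvalPythonExec
import org.apache.spark.sql.types.StructType

object InputIteratorSketch {
  // Converts rows to Java objects when needed, groups them into batches of 100, and pickles
  // each batch into one byte array destined for the Python worker.
  def pickleForPython(rows: Iterator[InternalRow], schema: StructType): Iterator[Array[Byte]] =
    BatchEvalPythonExec.getInputIterator(rows, schema)
}
```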
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala
new file mode 100644
index 00000000000..a7fdfb9d173
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala
@@ -0,0 +1,175 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.python
+
+import java.io.File
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
+
+import net.razorvine.pickle.Unpickler
+
+import org.apache.spark.{ContextAwareIterator, SparkEnv, TaskContext}
+import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.util.GenericArrayData
+import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
+import org.apache.spark.sql.types.{DataType, StructField, StructType}
+import org.apache.spark.util.Utils
+
+/**
+ * A physical plan that evaluates a [[PythonUDTF]]. This is similar to [[BatchEvalPythonExec]].
+ *
+ * @param udtf the user-defined Python table function
+ * @param requiredChildOutput the required output of the child plan. It is used to omit
+ *                            generating data that a subsequent projection would discard.
+ * @param resultAttrs the output schema of the Python UDTF.
+ * @param child the child plan
+ */
+case class BatchEvalPythonUDTFExec(
+    udtf: PythonUDTF,
+    requiredChildOutput: Seq[Attribute],
+    resultAttrs: Seq[Attribute],
+    child: SparkPlan)
+  extends UnaryExecNode with PythonSQLMetrics {
+
+  override def output: Seq[Attribute] = requiredChildOutput ++ resultAttrs
+
+  override def producedAttributes: AttributeSet = AttributeSet(resultAttrs)
+
+  protected override def doExecute(): RDD[InternalRow] = {
+    val inputRDD = child.execute().map(_.copy())
+
+    inputRDD.mapPartitions { iter =>
+      val context = TaskContext.get()
+      val contextAwareIterator = new ContextAwareIterator(context, iter)
+
+      // The queue used to buffer input rows so we can drain it to
+      // combine input with output from Python.
+      val queue = HybridRowQueue(context.taskMemoryManager(),
+        new File(Utils.getLocalDir(SparkEnv.get.conf)), child.output.length)
+      context.addTaskCompletionListener[Unit] { ctx =>
+        queue.close()
+      }
+
+      val inputs = Seq(udtf.children)
+
+      // flatten all the arguments
+      val allInputs = new ArrayBuffer[Expression]
+      val dataTypes = new ArrayBuffer[DataType]
+      val argOffsets = inputs.map { input =>
+        input.map { e =>
+          if (allInputs.exists(_.semanticEquals(e))) {
+            allInputs.indexWhere(_.semanticEquals(e))
+          } else {
+            allInputs += e
+            dataTypes += e.dataType
+            allInputs.length - 1
+          }
+        }.toArray
+      }.toArray
+      val projection = MutableProjection.create(allInputs.toSeq, child.output)
+      projection.initialize(context.partitionId())
+      val schema = StructType(dataTypes.zipWithIndex.map { case (dt, i) =>
+        StructField(s"_$i", dt)
+      }.toArray)
+
+      // Add rows to the queue to join later with the result.
+      // Also keep track of the number of rows added to the queue.
+      // This is needed to process extra output rows from the `terminate()` call of the UDTF.
+      var count = 0L
+      val projectedRowIter = contextAwareIterator.map { inputRow =>
+        queue.add(inputRow.asInstanceOf[UnsafeRow])
+        count += 1
+        projection(inputRow)
+      }
+
+      val outputRowIterator = evaluate(udtf, argOffsets, projectedRowIter, schema, context)
+
+      val pruneChildForResult: InternalRow => InternalRow =
+        if (child.outputSet == AttributeSet(requiredChildOutput)) {
+          identity
+        } else {
+          UnsafeProjection.create(requiredChildOutput, child.output)
+        }
+
+      val joined = new JoinedRow
+      val resultProj = UnsafeProjection.create(output, output)
+
+      outputRowIterator.flatMap { outputRows =>
+        // If `count` is greater than zero, it means there are remaining input rows in the queue.
+        // In this case, the output rows of the UDTF are joined with the corresponding input row
+        // in the queue.
+        if (count > 0) {
+          val left = queue.remove()
+          count -= 1
+          joined.withLeft(pruneChildForResult(left))
+        }
+        // If `count` is zero, it means all input rows have been consumed. Any additional rows
+        // produced by the UDTF come from the `terminate()` call. We leave the left side set to
+        // the last input row from the child, to stay consistent with the Generate implementation
+        // and Hive UDTFs.
+        outputRows.map(r => resultProj(joined.withRight(r)))
+      }
+    }
+  }
+
+  /**
+   * Evaluates a Python UDTF. It computes the results using the PythonUDFRunner and returns,
+   * for each input row, an iterator of the internal rows it produces.
+   */
+  private def evaluate(
+      udtf: PythonUDTF,
+      argOffsets: Array[Array[Int]],
+      iter: Iterator[InternalRow],
+      schema: StructType,
+      context: TaskContext): Iterator[Iterator[InternalRow]] = {
+    EvaluatePython.registerPicklers()  // register pickler for Row
+
+    // Input iterator to Python.
+    val inputIterator = BatchEvalPythonExec.getInputIterator(iter, schema)
+
+    // Output iterator for results from Python.
+    val funcs = Seq(ChainedPythonFunctions(Seq(udtf.func)))
+    val outputIterator =
+      new PythonUDFRunner(funcs, PythonEvalType.SQL_TABLE_UDF, argOffsets, pythonMetrics)
+        .compute(inputIterator, context.partitionId(), context)
+
+    val unpickle = new Unpickler
+
+    // The return type of a UDTF is an array of struct.
+    val resultType = udtf.dataType
+    val fromJava = EvaluatePython.makeFromJava(resultType)
+
+    outputIterator.flatMap { pickledResult =>
+      val unpickledBatch = unpickle.loads(pickledResult)
+      unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala
+    }.map { results =>
+      assert(results.getClass.isArray)
+      val res = results.asInstanceOf[Array[_]]
+      pythonMetrics("pythonNumRowsReceived") += res.length
+      fromJava(results).asInstanceOf[GenericArrayData]
+        .array.map(_.asInstanceOf[InternalRow]).toIterator
+    }
+  }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): BatchEvalPythonUDTFExec =
+    copy(child = newChild)
+}
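
To make the joining behavior above concrete, here is a small, self-contained sketch using plain Scala collections (no Spark types): rows emitted while eval() processes input row i are joined with input row i, and any extra rows produced by terminate() reuse the last input row as their left side, matching the Generate implementation and Hive UDTFs. The UDTF below is hypothetical.

```scala
object UDTFJoinSemanticsSketch {
  def main(args: Array[String]): Unit = {
    // Buffered input rows (the role of HybridRowQueue), here just (a, b) pairs.
    val input = Seq((1, 2), (3, 4))

    // Hypothetical UDTF: for each input row, eval() yields a + b and a - b;
    // terminate() yields one extra value at the end.
    def eval(a: Int, b: Int): Seq[Int] = Seq(a + b, a - b)
    val terminateRows = Seq(0)

    // Rows produced by eval() are joined with the input row they came from ...
    val joinedEvalRows = input.flatMap { case (a, b) => eval(a, b).map(c => (a, b, c)) }
    // ... and rows produced by terminate() keep the last input row as the left side.
    val (lastA, lastB) = input.last
    val joinedTerminateRows = terminateRows.map(c => (lastA, lastB, c))

    (joinedEvalRows ++ joinedTerminateRows).foreach(println)
    // (1,2,3), (1,2,-1), (3,4,7), (3,4,-1), (3,4,0)
  }
}
```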
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
index 57c3e1ad88e..f6993eaf6bd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
@@ -305,3 +305,19 @@ object ExtractPythonUDFs extends Rule[LogicalPlan] {
     }
   }
 }
+
+/**
+ * Rewrites Generate nodes that evaluate a [[PythonUDTF]] into [[BatchEvalPythonUDTF]] nodes
+ * so that the UDTF can be evaluated by a dedicated physical operator.
+ */
+object ExtractPythonUDTFs extends Rule[LogicalPlan] {
+  def apply(plan: LogicalPlan): LogicalPlan = plan match {
+    // A correlated subquery will be rewritten into a join later, and will go through this rule
+    // eventually. Here we skip subqueries, as Python UDTFs only need to be extracted once.
+    case s: Subquery if s.correlated => plan
+
+    case _ => plan.transformUpWithPruning(_.containsPattern(GENERATE)) {
+      case g @ Generate(func: PythonUDTF, _, _, _, _, child) =>
+        BatchEvalPythonUDTF(func, g.requiredChildOutput, g.generatorOutput, child)
+    }
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
index bc76eaed04b..08c7c8730d3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
@@ -18,9 +18,10 @@
 package org.apache.spark.sql.execution.python
 
 import org.apache.spark.api.python.{PythonEvalType, PythonFunction}
-import org.apache.spark.sql.Column
-import org.apache.spark.sql.catalyst.expressions.{Expression, PythonUDAF, PythonUDF}
-import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.{Column, DataFrame, Dataset, SparkSession}
+import org.apache.spark.sql.catalyst.expressions.{Expression, PythonUDAF, PythonUDF, PythonUDTF}
+import org.apache.spark.sql.catalyst.plans.logical.{Generate, LogicalPlan, OneRowRelation}
+import org.apache.spark.sql.types.{DataType, StructType}
 
 /**
  * A user-defined Python function. This is used by the Python API.
@@ -55,3 +56,36 @@ case class UserDefinedPythonFunction(
     }
   }
 }
+
+/**
+ * A user-defined Python table function. This is used by the Python API.
+ */
+case class UserDefinedPythonTableFunction(
+    name: String,
+    func: PythonFunction,
+    returnType: StructType,
+    udfDeterministic: Boolean) {
+
+  def builder(e: Seq[Expression]): LogicalPlan = {
+    val udtf = PythonUDTF(
+      name = name,
+      func = func,
+      elementSchema = returnType,
+      children = e,
+      udfDeterministic = udfDeterministic)
+    Generate(
+      udtf,
+      unrequiredChildIndex = Nil,
+      outer = false,
+      qualifier = None,
+      generatorOutput = Nil,
+      child = OneRowRelation()
+    )
+  }
+
+  /** Returns a [[DataFrame]] that evaluates this UDTF with the given input arguments. */
+  def apply(session: SparkSession, exprs: Column*): DataFrame = {
+    val udtf = builder(exprs.map(_.expr))
+    Dataset.ofRows(session, udtf)
+  }
+}
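
For direct invocation without SQL registration, apply() wraps the UDTF in a Generate over OneRowRelation and returns a DataFrame. Below is a minimal sketch mirroring the "Simple PythonUDTF" test further down, assuming spark and pythonUDTF (built via IntegratedUDFTestUtils.createUserDefinedPythonTableFunction) are provided:

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.execution.python.UserDefinedPythonTableFunction
import org.apache.spark.sql.functions.lit

object UDTFDirectInvocationSketch {
  def invoke(spark: SparkSession, pythonUDTF: UserDefinedPythonTableFunction): DataFrame = {
    // Evaluates the UDTF once with literal arguments (1, 2); ExtractPythonUDTFs later
    // rewrites the underlying Generate node into BatchEvalPythonUDTF for execution.
    pythonUDTF(spark, lit(1), lit(2))
  }
}
```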
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
index 460e5b68ff8..bc4b308f75d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
@@ -175,6 +175,8 @@ abstract class BaseSessionStateBuilder(
    */
   protected def udfRegistration: UDFRegistration = new UDFRegistration(functionRegistry)
 
+  protected def udtfRegistration: UDTFRegistration = new UDTFRegistration(tableFunctionRegistry)
+
   /**
    * Logical query plan analyzer for resolving unresolved attributes and relations.
    *
@@ -365,6 +367,7 @@ abstract class BaseSessionStateBuilder(
       functionRegistry,
       tableFunctionRegistry,
       udfRegistration,
+      udtfRegistration,
       () => catalog,
       sqlParser,
       () => analyzer,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
index eb0b71d155b..177a25b45fc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
@@ -46,6 +46,8 @@ import org.apache.spark.util.{DependencyUtils, Utils}
  * @param experimentalMethods Interface to add custom planning strategies and optimizers.
  * @param functionRegistry Internal catalog for managing functions registered by the user.
  * @param udfRegistration Interface exposed to the user for registering user-defined functions.
+ * @param udtfRegistration Interface exposed to the user for registering user-defined
+ *                         table functions.
  * @param catalogBuilder a function to create an internal catalog for managing table and database
  *                       states.
  * @param sqlParser Parser that extracts expressions, plans, table identifiers etc. from SQL texts.
@@ -69,6 +71,7 @@ private[sql] class SessionState(
     val functionRegistry: FunctionRegistry,
     val tableFunctionRegistry: TableFunctionRegistry,
     val udfRegistration: UDFRegistration,
+    val udtfRegistration: UDTFRegistration,
     catalogBuilder: () => SessionCatalog,
     val sqlParser: ParserInterface,
     analyzerBuilder: () => Analyzer,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala
index c95b5f7c27f..f3d8e883f7f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala
@@ -30,9 +30,9 @@ import org.apache.spark.api.python.{PythonBroadcast, PythonEvalType, PythonFunct
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, ExprId, PythonUDF}
 import org.apache.spark.sql.catalyst.plans.SQLHelper
-import org.apache.spark.sql.execution.python.UserDefinedPythonFunction
+import org.apache.spark.sql.execution.python.{UserDefinedPythonFunction, UserDefinedPythonTableFunction}
 import org.apache.spark.sql.expressions.SparkUserDefinedFunction
-import org.apache.spark.sql.types.{DataType, IntegerType, NullType, StringType}
+import org.apache.spark.sql.types.{DataType, IntegerType, NullType, StringType, StructType}
 
 /**
  * This object targets to integrate various UDF test cases so that Scalar UDF, Python UDF,
@@ -191,6 +191,32 @@ object IntegratedUDFTestUtils extends SQLHelper {
     throw new RuntimeException(s"Python executable [$pythonExec] and/or pyspark are unavailable.")
   }
 
+  private def createPythonUDTF(funcName: String, pythonScript: String): Array[Byte] = {
+    if (!shouldTestPythonUDFs) {
+      throw new RuntimeException(s"Python executable [$pythonExec] and/or pyspark are unavailable.")
+    }
+    var binaryPythonUDTF: Array[Byte] = null
+    withTempPath { codePath =>
+      Files.write(codePath.toPath, pythonScript.getBytes(StandardCharsets.UTF_8))
+      withTempPath { path =>
+        Process(
+          Seq(
+            pythonExec,
+            "-c",
+            "from pyspark.serializers import CloudPickleSerializer; " +
+              s"f = open('$path', 'wb');" +
+              s"exec(open('$codePath', 'r').read());" +
+              "f.write(CloudPickleSerializer().dumps(" +
+              s"($funcName, returnType)))"),
+          None,
+          "PYTHONPATH" -> s"$pysparkPythonPath:$pythonPath").!!
+        binaryPythonUDTF = Files.readAllBytes(path.toPath)
+      }
+    }
+    assert(binaryPythonUDTF != null)
+    binaryPythonUDTF
+  }
+
   private lazy val pandasFunc: Array[Byte] = if (shouldTestPandasUDFs) {
     var binaryPandasFunc: Array[Byte] = null
     withTempPath { path =>
@@ -360,6 +386,25 @@ object IntegratedUDFTestUtils extends SQLHelper {
     val prettyName: String = "Regular Python UDF"
   }
 
+  def createUserDefinedPythonTableFunction(
+      name: String,
+      pythonScript: String,
+      returnType: StructType,
+      deterministic: Boolean = true): UserDefinedPythonTableFunction = {
+    UserDefinedPythonTableFunction(
+      name = name,
+      func = SimplePythonFunction(
+        command = createPythonUDTF(name, pythonScript),
+        envVars = workerEnv.clone().asInstanceOf[java.util.Map[String, String]],
+        pythonIncludes = List.empty[String].asJava,
+        pythonExec = pythonExec,
+        pythonVer = pythonVer,
+        broadcastVars = List.empty[Broadcast[PythonBroadcast]].asJava,
+        accumulator = null),
+      returnType = returnType,
+      udfDeterministic = deterministic)
+  }
+
   /**
    * A Scalar Pandas UDF that takes one column, casts into string, executes the
    * Python native function, and casts back to the type of input column.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDTFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDTFSuite.scala
new file mode 100644
index 00000000000..fa4d80c331a
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDTFSuite.scala
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.python
+
+import org.apache.spark.sql.{IntegratedUDFTestUtils, QueryTest, Row}
+import org.apache.spark.sql.functions.lit
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.StructType
+
+class PythonUDTFSuite extends QueryTest with SharedSparkSession {
+
+  import testImplicits._
+
+  import IntegratedUDFTestUtils._
+
+  private val pythonScript: String =
+    """
+      |from pyspark.sql.types import StructType, StructField, IntegerType
+      |returnType = StructType([
+      |  StructField("a", IntegerType()),
+      |  StructField("b", IntegerType()),
+      |  StructField("c", IntegerType()),
+      |])
+      |class SimpleUDTF:
+      |    def eval(self, a: int, b: int):
+      |        yield a, b, a + b
+      |        yield a, b, a - b
+      |        yield a, b, b - a
+      |""".stripMargin
+
+  private val returnType: StructType = StructType.fromDDL("a int, b int, c int")
+
+  private val pythonUDTF: UserDefinedPythonTableFunction =
+    createUserDefinedPythonTableFunction("SimpleUDTF", pythonScript, returnType)
+
+  test("Simple PythonUDTF") {
+    assume(shouldTestPythonUDFs)
+    val df = pythonUDTF(spark, lit(1), lit(2))
+    checkAnswer(df, Seq(Row(1, 2, -1), Row(1, 2, 1), Row(1, 2, 3)))
+  }
+
+  test("PythonUDTF with lateral join") {
+    assume(shouldTestPythonUDFs)
+    withTempView("t") {
+      spark.udtf.registerPython("testUDTF", pythonUDTF)
+      Seq((0, 1), (1, 2)).toDF("a", "b").createOrReplaceTempView("t")
+      checkAnswer(
+        sql("SELECT f.* FROM t, LATERAL testUDTF(a, b) f"),
+        sql("SELECT * FROM t, LATERAL explode(array(a + b, a - b, b - a)) t(c)"))
+    }
+  }
+
+  test("PythonUDTF in correlated subquery") {
+    assume(shouldTestPythonUDFs)
+    withTempView("t") {
+      spark.udtf.registerPython("testUDTF", pythonUDTF)
+      Seq((0, 1), (1, 2)).toDF("a", "b").createOrReplaceTempView("t")
+      checkAnswer(
+        sql("SELECT (SELECT sum(f.b) AS r FROM testUDTF(1, 2) f WHERE f.a = t.a) FROM t"),
+        Seq(Row(6), Row(null)))
+    }
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org